Skip to content

Commit

Permalink
[FIRRLT] Narrow adds (#4869)
Browse files Browse the repository at this point in the history
Narrow add and sub operations.
  • Loading branch information
darthscsi authored Mar 22, 2023
1 parent 64e66c8 commit f1532b0
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 8 deletions.
32 changes: 28 additions & 4 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ include "FIRRTLStatements.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

/// Constraint that matches a ConstantOp, SpecialConstantOp, AggregateConstantOp, or InvalidValueOp.
/// Constraint that matches a ConstantOp, SpecialConstantOp, or AggregateConstantOp.
def AnyConstantOp : Constraint<CPred<[{
isa_and_nonnull<ConstantOp, SpecialConstantOp, InvalidValueOp, AggregateConstantOp>
isa_and_nonnull<ConstantOp, SpecialConstantOp, AggregateConstantOp>
($0.getDefiningOp())
}]>>;

/// Constraint that matches non-constant operations. Used to ensure that the
/// const-on-LHS rewriting patterns converge in case both operands are constant.
def NonConstantOp : Constraint<CPred<[{
!isa_and_nonnull<ConstantOp, SpecialConstantOp, InvalidValueOp, AggregateConstantOp>
!isa_and_nonnull<ConstantOp, SpecialConstantOp, AggregateConstantOp>
($0.getDefiningOp())
}]>>;

Expand Down Expand Up @@ -138,7 +138,7 @@ def DShlOfConstant : Pat<
(PadPrimOp (MoveNameHint $old, (ShlPrimOp $x, (LimitConstant32 $cst))), (TypeWidth32 $old)),
[(KnownWidth $x)]>;

// dshr(a, const) -> shl(a, const)
// dshr(a, const) -> shr(a, const)
def DShrOfConstant : Pat<
(DShrPrimOp:$old $x, (ConstantOp $cst)),
(PadPrimOp (MoveNameHint $old, (ShrPrimOp $x, (LimitConstant32 $cst))), (TypeWidth32 $old)),
Expand Down Expand Up @@ -206,6 +206,18 @@ def AddOfSelf : Pat <
(MoveNameHint $old, (ShlPrimOp $x, (NativeCodeCall<"1">))),
[(KnownWidth $x)]>;

// add((pad a, n), b) -> pad(add(a, b), n)
def AddOfPadL : Pat <
(AddPrimOp:$old (PadPrimOp $a, $m), $b),
(PadPrimOp (MoveNameHint $old, (AddPrimOp $a, $b)), (TypeWidth32 $old)),
[(KnownWidth $a), (KnownWidth $b)]>;

// add(b, (pad a, n)) -> pad(add(b, a), n)
def AddOfPadR : Pat <
(AddPrimOp:$old $b, (PadPrimOp $a, $m)),
(PadPrimOp (MoveNameHint $old, (AddPrimOp $b, $a)), (TypeWidth32 $old)),
[(KnownWidth $a), (KnownWidth $b)]>;

// sub(a, 0) -> a
def SubOfZero : Pat <
(SubPrimOp:$old $x, (ConstantOp:$zcst $cst)),
Expand All @@ -230,6 +242,18 @@ def SubOfSelf : Pat <
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), $0.getType().cast<IntType>(), getIntZerosAttr($0.getType()))"> $old),
[(KnownWidth $x)]>;

// sub((pad a, n), b) -> pad(sub(a, b), n)
def SubOfPadL : Pat <
(SubPrimOp:$old (PadPrimOp $a, $m), $b),
(PadPrimOp (MoveNameHint $old, (SubPrimOp $a, $b)), (TypeWidth32 $old)),
[(KnownWidth $a), (KnownWidth $b)]>;

// sub(b, (pad a, n)) -> pad(sub(b, a), n)
def SubOfPadR : Pat <
(SubPrimOp:$old $b, (PadPrimOp $a, $m)),
(PadPrimOp (MoveNameHint $old, (SubPrimOp $b, $a)), (TypeWidth32 $old)),
[(KnownWidth $a), (KnownWidth $b)]>;

// and(x, 0) -> 0, fold can't handle all cases
def AndOfZero : Pat <
(AndPrimOp:$old $x, (ConstantOp:$zcst $cst)),
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,9 @@ OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {

void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.insert<patterns::moveConstAdd, patterns::AddOfZero, patterns::AddOfSelf>(
context);
results.insert<patterns::moveConstAdd, patterns::AddOfZero,
patterns::AddOfSelf, patterns::AddOfPadL, patterns::AddOfPadR>(
context);
}

OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
Expand All @@ -384,7 +384,8 @@ OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
patterns::SubFromZeroUnsigned, patterns::SubOfSelf>(context);
patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
patterns::SubOfPadL, patterns::SubOfPadR>(context);
}

OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
Expand Down
84 changes: 84 additions & 0 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,90 @@ firrtl.module @add_double(out %out: !firrtl.uint<5>, in %in: !firrtl.uint<4>) {
firrtl.connect %out, %add : !firrtl.uint<5>, !firrtl.uint<5>
}

// CHECK-LABEL: @add_narrow
// CHECK-NEXT: %[[add1:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: %[[add2:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: %[[add3:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]]
// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]]
// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]]
firrtl.module @add_narrow(out %out1: !firrtl.uint<7>, out %out2: !firrtl.uint<7>, out %out3: !firrtl.uint<7>, in %in1: !firrtl.uint<4>, in %in2: !firrtl.uint<2>) {
%t1 = firrtl.pad %in1, 6 : (!firrtl.uint<4>) -> !firrtl.uint<6>
%t2 = firrtl.pad %in2, 6 : (!firrtl.uint<2>) -> !firrtl.uint<6>
%add1 = firrtl.add %t1, %t2 : (!firrtl.uint<6>, !firrtl.uint<6>) -> !firrtl.uint<7>
%add2 = firrtl.add %in1, %t2 : (!firrtl.uint<4>, !firrtl.uint<6>) -> !firrtl.uint<7>
%add3 = firrtl.add %t1, %in2 : (!firrtl.uint<6>, !firrtl.uint<2>) -> !firrtl.uint<7>
firrtl.strictconnect %out1, %add1 : !firrtl.uint<7>
firrtl.strictconnect %out2, %add2 : !firrtl.uint<7>
firrtl.strictconnect %out3, %add3 : !firrtl.uint<7>
}

// CHECK-LABEL: @adds_narrow
// CHECK-NEXT: %[[add1:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: %[[add2:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: %[[add3:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]]
// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]]
// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]]
firrtl.module @adds_narrow(out %out1: !firrtl.sint<7>, out %out2: !firrtl.sint<7>, out %out3: !firrtl.sint<7>, in %in1: !firrtl.sint<4>, in %in2: !firrtl.sint<2>) {
%t1 = firrtl.pad %in1, 6 : (!firrtl.sint<4>) -> !firrtl.sint<6>
%t2 = firrtl.pad %in2, 6 : (!firrtl.sint<2>) -> !firrtl.sint<6>
%add1 = firrtl.add %t1, %t2 : (!firrtl.sint<6>, !firrtl.sint<6>) -> !firrtl.sint<7>
%add2 = firrtl.add %in1, %t2 : (!firrtl.sint<4>, !firrtl.sint<6>) -> !firrtl.sint<7>
%add3 = firrtl.add %t1, %in2 : (!firrtl.sint<6>, !firrtl.sint<2>) -> !firrtl.sint<7>
firrtl.strictconnect %out1, %add1 : !firrtl.sint<7>
firrtl.strictconnect %out2, %add2 : !firrtl.sint<7>
firrtl.strictconnect %out3, %add3 : !firrtl.sint<7>
}

// CHECK-LABEL: @sub_narrow
// CHECK-NEXT: %[[add1:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: %[[add2:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: %[[add3:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5>
// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7>
// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]]
// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]]
// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]]
firrtl.module @sub_narrow(out %out1: !firrtl.uint<7>, out %out2: !firrtl.uint<7>, out %out3: !firrtl.uint<7>, in %in1: !firrtl.uint<4>, in %in2: !firrtl.uint<2>) {
%t1 = firrtl.pad %in1, 6 : (!firrtl.uint<4>) -> !firrtl.uint<6>
%t2 = firrtl.pad %in2, 6 : (!firrtl.uint<2>) -> !firrtl.uint<6>
%add1 = firrtl.sub %t1, %t2 : (!firrtl.uint<6>, !firrtl.uint<6>) -> !firrtl.uint<7>
%add2 = firrtl.sub %in1, %t2 : (!firrtl.uint<4>, !firrtl.uint<6>) -> !firrtl.uint<7>
%add3 = firrtl.sub %t1, %in2 : (!firrtl.uint<6>, !firrtl.uint<2>) -> !firrtl.uint<7>
firrtl.strictconnect %out1, %add1 : !firrtl.uint<7>
firrtl.strictconnect %out2, %add2 : !firrtl.uint<7>
firrtl.strictconnect %out3, %add3 : !firrtl.uint<7>
}

// CHECK-LABEL: @subs_narrow
// CHECK-NEXT: %[[add1:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: %[[add2:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: %[[add3:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5>
// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7>
// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]]
// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]]
// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]]
firrtl.module @subs_narrow(out %out1: !firrtl.sint<7>, out %out2: !firrtl.sint<7>, out %out3: !firrtl.sint<7>, in %in1: !firrtl.sint<4>, in %in2: !firrtl.sint<2>) {
%t1 = firrtl.pad %in1, 6 : (!firrtl.sint<4>) -> !firrtl.sint<6>
%t2 = firrtl.pad %in2, 6 : (!firrtl.sint<2>) -> !firrtl.sint<6>
%add1 = firrtl.sub %t1, %t2 : (!firrtl.sint<6>, !firrtl.sint<6>) -> !firrtl.sint<7>
%add2 = firrtl.sub %in1, %t2 : (!firrtl.sint<4>, !firrtl.sint<6>) -> !firrtl.sint<7>
%add3 = firrtl.sub %t1, %in2 : (!firrtl.sint<6>, !firrtl.sint<2>) -> !firrtl.sint<7>
firrtl.strictconnect %out1, %add1 : !firrtl.sint<7>
firrtl.strictconnect %out2, %add2 : !firrtl.sint<7>
firrtl.strictconnect %out3, %add3 : !firrtl.sint<7>
}

// CHECK-LABEL: @sub_cst_prop1
// CHECK-NEXT: %c1_ui9 = firrtl.constant 1 : !firrtl.uint<9>
// CHECK-NEXT: firrtl.strictconnect %out_b, %c1_ui9 : !firrtl.uint<9>
Expand Down

0 comments on commit f1532b0

Please sign in to comment.