diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index 5294f3903eda..f832e636ceec 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -19,16 +19,16 @@ include "FIRRTLStatements.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" -/// Constraint that matches a ConstantOp, SpecialConstantOp, AggregateConstantOp, or InvalidValueOp. +/// Constraint that matches a ConstantOp, SpecialConstantOp, or AggregateConstantOp. def AnyConstantOp : Constraint + isa_and_nonnull ($0.getDefiningOp()) }]>>; /// Constraint that matches non-constant operations. Used to ensure that the /// const-on-LHS rewriting patterns converge in case both operands are constant. def NonConstantOp : Constraint + !isa_and_nonnull ($0.getDefiningOp()) }]>>; @@ -138,7 +138,7 @@ def DShlOfConstant : Pat< (PadPrimOp (MoveNameHint $old, (ShlPrimOp $x, (LimitConstant32 $cst))), (TypeWidth32 $old)), [(KnownWidth $x)]>; -// dshr(a, const) -> shl(a, const) +// dshr(a, const) -> shr(a, const) def DShrOfConstant : Pat< (DShrPrimOp:$old $x, (ConstantOp $cst)), (PadPrimOp (MoveNameHint $old, (ShrPrimOp $x, (LimitConstant32 $cst))), (TypeWidth32 $old)), @@ -206,6 +206,18 @@ def AddOfSelf : Pat < (MoveNameHint $old, (ShlPrimOp $x, (NativeCodeCall<"1">))), [(KnownWidth $x)]>; +// add((pad a, n), b) -> pad(add(a, b), n) +def AddOfPadL : Pat < + (AddPrimOp:$old (PadPrimOp $a, $m), $b), + (PadPrimOp (MoveNameHint $old, (AddPrimOp $a, $b)), (TypeWidth32 $old)), + [(KnownWidth $a), (KnownWidth $b)]>; + +// add(b, (pad a, n)) -> pad(add(b, a), n) +def AddOfPadR : Pat < + (AddPrimOp:$old $b, (PadPrimOp $a, $m)), + (PadPrimOp (MoveNameHint $old, (AddPrimOp $b, $a)), (TypeWidth32 $old)), + [(KnownWidth $a), (KnownWidth $b)]>; + // sub(a, 0) -> a def SubOfZero : Pat < (SubPrimOp:$old $x, (ConstantOp:$zcst $cst)), @@ -230,6 +242,18 @@ def SubOfSelf : Pat < (NativeCodeCall<"$_builder.create($0.getLoc(), $0.getType().cast(), getIntZerosAttr($0.getType()))"> $old), [(KnownWidth $x)]>; +// sub((pad a, n), b) -> pad(sub(a, b), n) +def SubOfPadL : Pat < + (SubPrimOp:$old (PadPrimOp $a, $m), $b), + (PadPrimOp (MoveNameHint $old, (SubPrimOp $a, $b)), (TypeWidth32 $old)), + [(KnownWidth $a), (KnownWidth $b)]>; + +// sub(b, (pad a, n)) -> pad(sub(b, a), n) +def SubOfPadR : Pat < + (SubPrimOp:$old $b, (PadPrimOp $a, $m)), + (PadPrimOp (MoveNameHint $old, (SubPrimOp $b, $a)), (TypeWidth32 $old)), + [(KnownWidth $a), (KnownWidth $b)]>; + // and(x, 0) -> 0, fold can't handle all cases def AndOfZero : Pat < (AndPrimOp:$old $x, (ConstantOp:$zcst $cst)), diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index e7b1ac836d35..d47fd567175f 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -370,9 +370,9 @@ OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) { void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .insert( - context); + results.insert( + context); } OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) { @@ -384,7 +384,8 @@ OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) { void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.insert(context); + patterns::SubFromZeroUnsigned, patterns::SubOfSelf, + patterns::SubOfPadL, patterns::SubOfPadR>(context); } OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) { diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index 506cf3b3e82b..43badd92af59 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -1754,6 +1754,90 @@ firrtl.module @add_double(out %out: !firrtl.uint<5>, in %in: !firrtl.uint<4>) { firrtl.connect %out, %add : !firrtl.uint<5>, !firrtl.uint<5> } +// CHECK-LABEL: @add_narrow +// CHECK-NEXT: %[[add1:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: %[[add2:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: %[[add3:.+]] = firrtl.add %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]] +// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]] +// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]] +firrtl.module @add_narrow(out %out1: !firrtl.uint<7>, out %out2: !firrtl.uint<7>, out %out3: !firrtl.uint<7>, in %in1: !firrtl.uint<4>, in %in2: !firrtl.uint<2>) { + %t1 = firrtl.pad %in1, 6 : (!firrtl.uint<4>) -> !firrtl.uint<6> + %t2 = firrtl.pad %in2, 6 : (!firrtl.uint<2>) -> !firrtl.uint<6> + %add1 = firrtl.add %t1, %t2 : (!firrtl.uint<6>, !firrtl.uint<6>) -> !firrtl.uint<7> + %add2 = firrtl.add %in1, %t2 : (!firrtl.uint<4>, !firrtl.uint<6>) -> !firrtl.uint<7> + %add3 = firrtl.add %t1, %in2 : (!firrtl.uint<6>, !firrtl.uint<2>) -> !firrtl.uint<7> + firrtl.strictconnect %out1, %add1 : !firrtl.uint<7> + firrtl.strictconnect %out2, %add2 : !firrtl.uint<7> + firrtl.strictconnect %out3, %add3 : !firrtl.uint<7> +} + +// CHECK-LABEL: @adds_narrow +// CHECK-NEXT: %[[add1:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: %[[add2:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: %[[add3:.+]] = firrtl.add %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]] +// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]] +// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]] +firrtl.module @adds_narrow(out %out1: !firrtl.sint<7>, out %out2: !firrtl.sint<7>, out %out3: !firrtl.sint<7>, in %in1: !firrtl.sint<4>, in %in2: !firrtl.sint<2>) { + %t1 = firrtl.pad %in1, 6 : (!firrtl.sint<4>) -> !firrtl.sint<6> + %t2 = firrtl.pad %in2, 6 : (!firrtl.sint<2>) -> !firrtl.sint<6> + %add1 = firrtl.add %t1, %t2 : (!firrtl.sint<6>, !firrtl.sint<6>) -> !firrtl.sint<7> + %add2 = firrtl.add %in1, %t2 : (!firrtl.sint<4>, !firrtl.sint<6>) -> !firrtl.sint<7> + %add3 = firrtl.add %t1, %in2 : (!firrtl.sint<6>, !firrtl.sint<2>) -> !firrtl.sint<7> + firrtl.strictconnect %out1, %add1 : !firrtl.sint<7> + firrtl.strictconnect %out2, %add2 : !firrtl.sint<7> + firrtl.strictconnect %out3, %add3 : !firrtl.sint<7> +} + +// CHECK-LABEL: @sub_narrow +// CHECK-NEXT: %[[add1:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: %[[add2:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: %[[add3:.+]] = firrtl.sub %in1, %in2 : (!firrtl.uint<4>, !firrtl.uint<2>) -> !firrtl.uint<5> +// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.uint<5>) -> !firrtl.uint<7> +// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]] +// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]] +// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]] +firrtl.module @sub_narrow(out %out1: !firrtl.uint<7>, out %out2: !firrtl.uint<7>, out %out3: !firrtl.uint<7>, in %in1: !firrtl.uint<4>, in %in2: !firrtl.uint<2>) { + %t1 = firrtl.pad %in1, 6 : (!firrtl.uint<4>) -> !firrtl.uint<6> + %t2 = firrtl.pad %in2, 6 : (!firrtl.uint<2>) -> !firrtl.uint<6> + %add1 = firrtl.sub %t1, %t2 : (!firrtl.uint<6>, !firrtl.uint<6>) -> !firrtl.uint<7> + %add2 = firrtl.sub %in1, %t2 : (!firrtl.uint<4>, !firrtl.uint<6>) -> !firrtl.uint<7> + %add3 = firrtl.sub %t1, %in2 : (!firrtl.uint<6>, !firrtl.uint<2>) -> !firrtl.uint<7> + firrtl.strictconnect %out1, %add1 : !firrtl.uint<7> + firrtl.strictconnect %out2, %add2 : !firrtl.uint<7> + firrtl.strictconnect %out3, %add3 : !firrtl.uint<7> +} + +// CHECK-LABEL: @subs_narrow +// CHECK-NEXT: %[[add1:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad1:.+]] = firrtl.pad %[[add1]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: %[[add2:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad2:.+]] = firrtl.pad %[[add2]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: %[[add3:.+]] = firrtl.sub %in1, %in2 : (!firrtl.sint<4>, !firrtl.sint<2>) -> !firrtl.sint<5> +// CHECK-NEXT: %[[pad3:.+]] = firrtl.pad %[[add3]], 7 : (!firrtl.sint<5>) -> !firrtl.sint<7> +// CHECK-NEXT: firrtl.strictconnect %out1, %[[pad1]] +// CHECK-NEXT: firrtl.strictconnect %out2, %[[pad2]] +// CHECK-NEXT: firrtl.strictconnect %out3, %[[pad3]] +firrtl.module @subs_narrow(out %out1: !firrtl.sint<7>, out %out2: !firrtl.sint<7>, out %out3: !firrtl.sint<7>, in %in1: !firrtl.sint<4>, in %in2: !firrtl.sint<2>) { + %t1 = firrtl.pad %in1, 6 : (!firrtl.sint<4>) -> !firrtl.sint<6> + %t2 = firrtl.pad %in2, 6 : (!firrtl.sint<2>) -> !firrtl.sint<6> + %add1 = firrtl.sub %t1, %t2 : (!firrtl.sint<6>, !firrtl.sint<6>) -> !firrtl.sint<7> + %add2 = firrtl.sub %in1, %t2 : (!firrtl.sint<4>, !firrtl.sint<6>) -> !firrtl.sint<7> + %add3 = firrtl.sub %t1, %in2 : (!firrtl.sint<6>, !firrtl.sint<2>) -> !firrtl.sint<7> + firrtl.strictconnect %out1, %add1 : !firrtl.sint<7> + firrtl.strictconnect %out2, %add2 : !firrtl.sint<7> + firrtl.strictconnect %out3, %add3 : !firrtl.sint<7> +} + // CHECK-LABEL: @sub_cst_prop1 // CHECK-NEXT: %c1_ui9 = firrtl.constant 1 : !firrtl.uint<9> // CHECK-NEXT: firrtl.strictconnect %out_b, %c1_ui9 : !firrtl.uint<9>