From 638681ae3ab6147ae2c08fea8c7bc743e5a90003 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 5 Mar 2025 08:23:46 -0800 Subject: [PATCH 1/2] Add mux-of-const canonicalizations to FIRRTL These canonicalizations are already present in comb, but I am adding them to the firrtl dialect so that they can run before lower layers. We generally pool constants at the top of a module. If constants are used within a layerblock, they can be captured as ports during lower layers, which in turn prevents comb canonicalizers from running. Ideally, layer-sink would aggressively clone or sink constants into layerblocks, but we're not there yet. These canonicalizations are important because they can drastically simplify the if/else branches which drive registers. When lowering a seq.firreg to sv, we incorporate any muxes feeding into the register as if-else branches, which conditionally drive the register. These canonicalizations can transform a deeply nested mux chain into a tree of and/or expressions, which reduces the depth of if/else branches --- .../Dialect/FIRRTL/FIRRTLCanonicalization.td | 32 +++++++++++++++++++ lib/Dialect/FIRRTL/FIRRTLFolds.cpp | 12 +++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index a0157c0508e1..0615d72e5fcf 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -656,6 +656,38 @@ def MuxNEQ : Pat< (MoveNameHint $old, (MuxPrimOp (EQPrimOp $a, $b), $y, $x)), [(EqualTypes $x, $y), (KnownWidth $x)]>; +// mux(cond, 0, b) -> and(~cond, b) +def MuxLHS0 : Pat< + (MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b), + (MoveNameHint $old, (AndPrimOp (NotPrimOp $cond), $b)), + [(ZeroConstantOp $a), + (EqualTypes $cond, $a), + (EqualTypes $cond, $b)]>; + +// mux(cond, 1, b) -> or(cond, b) +def MuxLHS1 : Pat< + (MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b), + (MoveNameHint $old, (OrPrimOp $cond, $b)), + [(OneConstantOp $a), + (EqualTypes $cond, $a), + (EqualTypes $cond, $b)]>; + +// mux(cond, a, 0) -> and(cond, a) +def MuxRHS0 : Pat< + (MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)), + (MoveNameHint $old, (AndPrimOp $cond, $a)), + [(ZeroConstantOp $b), + (EqualTypes $cond, $a), + (EqualTypes $cond, $b)]>; + +// mux(cond, a, 1) -> or(~cond, a) +def MuxRHS1 : Pat< + (MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)), + (MoveNameHint $old, (OrPrimOp (NotPrimOp $cond), $a)), + [(OneConstantOp $b), + (EqualTypes $cond, $a), + (EqualTypes $cond, $b)]>; + // mux(cond : u0, a, b) -> mux(0 : u1, a, b) def MuxPadSel : Pat< (MuxPrimOp:$old $cond, $a, $b), diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index 822e521699b7..983db91690f0 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -1472,12 +1472,12 @@ class MuxSharedCond : public mlir::RewritePattern { void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add( - context); + results.add(context); } void Mux2CellIntrinsicOp::getCanonicalizationPatterns( From ac0bead9434a3830b29553b3dd533e6f6c7330d6 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 7 Mar 2025 14:18:28 -0800 Subject: [PATCH 2/2] Add some canonicalizers for muxes and registers This commit adds four new canonicalization patterns to FIRRTL. mux(cond, 0, b) -> and(not(cond), b) mux(cond, 1, b) -> or(cond, b) mux(cond, a, 0) -> and(cond, a) mux(cond, a, 1) -> or(not(cond), a) These canonicalizations are already present for the comb dialect, but we want to run these canonicalizers before lowering layers, which can obscure constant ops behind ports. The problem with these mux canonicalizers is, they conflict with a register canonicalizer. This register canonicalizer converts a register to a constant, if the register's next-value is a mux of either the register itself, or a constant. For example: connect(reg, mux(reset, 0, reg)) ==> reg -> 0 These new canonicalizers would transform the connect to: connect(reg, and(not(reset), reg)) ...which prevents the register canonicalizer from running. To get this behaviour back, this PR adds four additional canonicalizations for both registers and reg resets: For registers, the canonicalizers are: connect(reg, and(reg, x)) ==> reg -> 0 connect(reg, and(x, reg)) ==> reg -> 0 connect(reg, or(reg, x)) ==> reg -> 1 connect(reg, or(x, reg)) ==> reg -> 1 For regresets, we have the same canonicalizers, but with an additional check: the reset value must be a constant zero or one. reset(reg) = 0 ==> connect(reg, and(reg, x)) ==> reg -> 0 reset(reg) = 0 ==> connect(reg, and(x, reg)) ==> reg -> 0 reset(reg) = 1 ==> connect(reg, or(reg, x)) ==> reg -> 1 reset(reg) = 1 ==> connect(reg, or(x, reg)) ==> reg -> 1 --- .../Dialect/FIRRTL/FIRRTLCanonicalization.td | 8 +- lib/Dialect/FIRRTL/FIRRTLFolds.cpp | 141 +++++++++++++++++- test/Dialect/FIRRTL/canonicalization.mlir | 78 +++++++++- 3 files changed, 216 insertions(+), 11 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index 0615d72e5fcf..5b918070ac87 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -657,7 +657,7 @@ def MuxNEQ : Pat< [(EqualTypes $x, $y), (KnownWidth $x)]>; // mux(cond, 0, b) -> and(~cond, b) -def MuxLHS0 : Pat< +def MuxLhsZero : Pat< (MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b), (MoveNameHint $old, (AndPrimOp (NotPrimOp $cond), $b)), [(ZeroConstantOp $a), @@ -665,7 +665,7 @@ def MuxLHS0 : Pat< (EqualTypes $cond, $b)]>; // mux(cond, 1, b) -> or(cond, b) -def MuxLHS1 : Pat< +def MuxLhsOne : Pat< (MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b), (MoveNameHint $old, (OrPrimOp $cond, $b)), [(OneConstantOp $a), @@ -673,7 +673,7 @@ def MuxLHS1 : Pat< (EqualTypes $cond, $b)]>; // mux(cond, a, 0) -> and(cond, a) -def MuxRHS0 : Pat< +def MuxRhsZero : Pat< (MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)), (MoveNameHint $old, (AndPrimOp $cond, $a)), [(ZeroConstantOp $b), @@ -681,7 +681,7 @@ def MuxRHS0 : Pat< (EqualTypes $cond, $b)]>; // mux(cond, a, 1) -> or(~cond, a) -def MuxRHS1 : Pat< +def MuxRhsOne : Pat< (MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)), (MoveNameHint $old, (OrPrimOp (NotPrimOp $cond), $a)), [(OneConstantOp $b), diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index 983db91690f0..01f479cf2fe2 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -1476,8 +1476,8 @@ void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse, patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, - patterns::MuxPadSel, patterns::MuxLHS0, patterns::MuxLHS1, - patterns::MuxRHS0, patterns::MuxRHS1>(context); + patterns::MuxPadSel, patterns::MuxLhsZero, patterns::MuxLhsOne, + patterns::MuxRhsZero, patterns::MuxRhsOne>(context); } void Mux2CellIntrinsicOp::getCanonicalizationPatterns( @@ -2256,6 +2256,77 @@ struct FoldResetMux : public mlir::RewritePattern { }; } // namespace +namespace { +/// This canonicalizer provides the following patterns: +/// reset(reg) = 0 ==> connect(reg, and(reg, x)) ==> reg -> 0 +/// reset(reg) = 0 ==> connect(reg, and(x, reg)) ==> reg -> 0 +/// reset(reg) = 1 ==> connect(reg, or(reg, x)) ==> reg -> 1 +/// reset(reg) = 1 ==> connect(reg, or(x, reg)) ==> reg -> 1 +/// +/// Justification: The initial value of a register is indeterminant, which means +/// we are free to choose any initial value when optimizing the circuit. For the +/// AND patterns, if the reset is zero, and we assume the initial value is zero, +/// then the register will always be zero. For the OR patterns, if the reset is +/// one, and we assume the initial value is one, then the register will always +/// be one. +struct RegResetAndOrOfSelf : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RegResetOp op, + PatternRewriter &rewriter) const override { + // Do not fold the register away if it is important. + if (hasDontTouch(op.getOperation()) || !AnnotationSet(op).empty() || + op.isForceable()) + return failure(); + + // This canonicalization only applies when the register holds 1 bit. + auto type = dyn_cast(op.getResult().getType()); + if (!type || type.getWidthOrSentinel() != 1) + return failure(); + + // This canonicalization only applies when the reset is a constant. + auto reset = + dyn_cast_or_null(op.getResetValue().getDefiningOp()); + if (!reset) + return failure(); + + auto value = reset.getValue(); + + // Find the one true connect, or bail. + auto connect = getSingleConnectUserOf(op.getResult()); + if (!connect) + return failure(); + + auto *src = connect.getSrc().getDefiningOp(); + if (!src) + return failure(); + + if (value == 0) { + if (auto srcAnd = dyn_cast(src)) { + if (srcAnd.getLhs().getDefiningOp() == op || + srcAnd.getRhs().getDefiningOp() == op) { + rewriter.eraseOp(connect); + replaceOpAndCopyName(rewriter, op, reset.getResult()); + return success(); + } + } + } + + if (value == 1) { + if (auto srcOr = dyn_cast(src)) { + if (srcOr.getLhs().getDefiningOp() == op || + srcOr.getRhs().getDefiningOp() == op) { + rewriter.eraseOp(connect); + replaceOpAndCopyName(rewriter, op, reset.getResult()); + return success(); + } + } + } + + return failure(); + } +}; +} // namespace + static bool isDefinedByOneConstantOp(Value v) { if (auto c = v.getDefiningOp()) return c.getValue().isOne(); @@ -2279,7 +2350,9 @@ canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) { void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); results.add(canonicalizeRegResetWithOneReset); results.add(demoteForceableIfUnused); } @@ -3156,10 +3229,66 @@ static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) { return success(); } +/// This canonicalizer provides the following patterns: +/// connect(reg, and(reg, x)) ==> reg -> 0 +/// connect(reg, and(x, reg)) ==> reg -> 0 +/// connect(reg, or(reg, x)) ==> reg -> 1 +/// connect(reg, or(x, reg)) ==> reg -> 1 +/// +/// Justification: The initial value of a register is indeterminant, which means +/// We are free to choose any initial value when optimizing the circuit. For the +/// AND patterns, if we assume the initial value is zero, then the register will +/// always be zero. For the OR patterns, if we assume the initial value is one, +/// then the register will always be one. +static LogicalResult foldRegAndOrOfSelf(RegOp reg, PatternRewriter &rewriter) { + // This canonicalization only applies when the register holds 1 bit. + auto type = dyn_cast(reg.getResult().getType()); + if (!type || type.getWidthOrSentinel() != 1) + return failure(); + + // Find the one true connect, or bail. + auto connect = getSingleConnectUserOf(reg.getResult()); + if (!connect) + return failure(); + + auto *src = connect.getSrc().getDefiningOp(); + if (!src) + return failure(); + + // connect(reg, and(reg, x)) ==> reg -> 0 + // connect(reg, and(x, reg)) ==> reg -> 0 + if (auto srcAnd = dyn_cast(src)) { + if (srcAnd.getLhs().getDefiningOp() == reg || + srcAnd.getRhs().getDefiningOp() == reg) { + auto attr = getIntAttr(type, APInt(1, 0)); + replaceOpWithNewOpAndCopyName(rewriter, reg, type, attr); + rewriter.eraseOp(connect); + return success(); + } + } + + // connect(reg, or(reg, x)) ==> reg -> 1 + // connect(reg, or(x, reg)) ==> reg -> 1 + if (auto srcOr = dyn_cast(src)) { + if (srcOr.getLhs().getDefiningOp() == reg || + srcOr.getRhs().getDefiningOp() == reg) { + auto attr = getIntAttr(type, APInt(1, 1)); + replaceOpWithNewOpAndCopyName(rewriter, reg, type, attr); + rewriter.eraseOp(connect); + return success(); + } + } + + return failure(); +} + LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) { - if (!hasDontTouch(op.getOperation()) && !op.isForceable() && - succeeded(foldHiddenReset(op, rewriter))) - return success(); + if (!hasDontTouch(op.getOperation()) && !op.isForceable()) { + if (succeeded(foldHiddenReset(op, rewriter))) + return success(); + if (succeeded(foldRegAndOrOfSelf(op, rewriter))) + return success(); + } if (succeeded(demoteForceableIfUnused(op, rewriter))) return success(); diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index 55916cdaba55..b91e921565c3 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -571,7 +571,11 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>, out %out3: !firrtl.uint<1>, out %out4: !firrtl.uint<4>, out %out5: !firrtl.uint<1>, - out %out6: !firrtl.uint<1>) { + out %out6: !firrtl.uint<1>, + out %out7: !firrtl.uint<1>, + out %out8: !firrtl.uint<1>, + out %out9: !firrtl.uint<1>, + out %out10: !firrtl.uint<1>) { // CHECK: firrtl.matchingconnect %out, %in %0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4> @@ -634,6 +638,32 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>, // CHECK-NEXT: mux4cell(%[[SEL]], %17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> firrtl.matchingconnect %out6, %17 : !firrtl.uint<1> + + // mux(cond, 0, x) -> and(~cond, x) + // CHECK: [[V1:%.+]] = firrtl.not %cond + // CHECK-NEXT: [[V2:%.+]] = firrtl.and [[V1]], %val1 + // CHECK-NEXT: firrtl.matchingconnect %out7, [[V2]] + %18 = firrtl.mux (%cond, %c0_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %out7, %18 : !firrtl.uint<1>, !firrtl.uint<1> + + // mux(cond, 1, x) -> or(cond, x) + // CHECK: [[V:%.+]] = firrtl.or %cond, %val1 + // CHECK-NEXT: firrtl.matchingconnect %out8, [[V]] + %19 = firrtl.mux (%cond, %c1_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %out8, %19 : !firrtl.uint<1>, !firrtl.uint<1> + + // mux(cond, x, 0) -> and(cond, x) + // CHECK: [[V:%.+]] = firrtl.and %cond, %val1 + // CHECK-NEXT: firrtl.matchingconnect %out9, [[V]] + %20 = firrtl.mux (%cond, %val1, %c0_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %out9, %20 : !firrtl.uint<1>, !firrtl.uint<1> + + // mux(cond, x, 1) -> or(~cond, x) + // CHECK: [[V1:%.+]] = firrtl.not %cond + // CHECK-NEXT: [[V2:%.+]] = firrtl.or [[V1]], %val1 + // CHECK-NEXT: firrtl.matchingconnect %out10, [[V2]] + %21 = firrtl.mux (%cond, %val1, %c1_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %out10, %21 : !firrtl.uint<1>, !firrtl.uint<1> } // CHECK-LABEL: firrtl.module @Pad @@ -2618,6 +2648,52 @@ firrtl.module @constReg9(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, i firrtl.matchingconnect %out, %r : !firrtl.uint<1> } +// Check that a register driven by an and(en, reg) is folded to a constant zero. +// CHECK-LABEL: @constRegAnd +firrtl.module @constRegAnd(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.reg + // CHECK: firrtl.matchingconnect %out, %c0_ui1 + %r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1> + %0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.matchingconnect %out, %r : !firrtl.uint<1> +} + +// Check that a register driven by an or(en, reg) is folded to a constant one. +// CHECK-LABEL: @constRegOr +firrtl.module @constRegOr(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.reg + // CHECK: firrtl.matchingconnect %out, %c1_ui1 + %r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1> + %0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.matchingconnect %out, %r : !firrtl.uint<1> +} + +// Check that a regreset driven by an and(en, reg) is folded to a constant zero, when the reset is zero. +// CHECK-LABEL: @constRegResetAnd +firrtl.module @constRegResetAnd(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.reg + // CHECK: firrtl.matchingconnect %out, %c0_ui1 + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %r = firrtl.regreset %clock, %reset, %c0_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + %0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.matchingconnect %out, %r : !firrtl.uint<1> +} + +// Check that a regreset driven by an or(en, reg) is folded to a constant one, when the reset is one. +// CHECK-LABEL: @constRegResetOr +firrtl.module @constRegResetOr(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.reg + // CHECK: firrtl.matchingconnect %out, %c1_ui1 + %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> + %r = firrtl.regreset %clock, %reset, %c1_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + %0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.matchingconnect %out, %r : !firrtl.uint<1> +} + firrtl.module @BitCast(out %o:!firrtl.bundle, ready: uint<1>, data: uint<1>> ) { %a = firrtl.wire : !firrtl.bundle, ready: uint<1>, data: uint<1>> %b = firrtl.bitcast %a : (!firrtl.bundle, ready: uint<1>, data: uint<1>>) -> (!firrtl.uint<3>)