Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,38 @@ def MuxNEQ : Pat<
(MoveNameHint $old, (MuxPrimOp (EQPrimOp $a, $b), $y, $x)),
[(EqualTypes $x, $y), (KnownWidth $x)]>;

// mux(cond, 0, b) -> and(~cond, b)
def MuxLhsZero : Pat<
(MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b),
(MoveNameHint $old, (AndPrimOp (NotPrimOp $cond), $b)),
[(ZeroConstantOp $a),
(EqualTypes $cond, $a),
(EqualTypes $cond, $b)]>;

// mux(cond, 1, b) -> or(cond, b)
def MuxLhsOne : Pat<
(MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b),
(MoveNameHint $old, (OrPrimOp $cond, $b)),
[(OneConstantOp $a),
(EqualTypes $cond, $a),
(EqualTypes $cond, $b)]>;

// mux(cond, a, 0) -> and(cond, a)
def MuxRhsZero : Pat<
(MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)),
(MoveNameHint $old, (AndPrimOp $cond, $a)),
[(ZeroConstantOp $b),
(EqualTypes $cond, $a),
(EqualTypes $cond, $b)]>;

// mux(cond, a, 1) -> or(~cond, a)
def MuxRhsOne : Pat<
(MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)),
(MoveNameHint $old, (OrPrimOp (NotPrimOp $cond), $a)),
[(OneConstantOp $b),
(EqualTypes $cond, $a),
(EqualTypes $cond, $b)]>;

// mux(cond : u0, a, b) -> mux(0 : u1, a, b)
def MuxPadSel : Pat<
(MuxPrimOp:$old $cond, $a, $b),
Expand Down
149 changes: 139 additions & 10 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1472,12 +1472,12 @@ class MuxSharedCond : public mlir::RewritePattern {

void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
context);
results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS,
patterns::MuxPadSel, patterns::MuxLhsZero, patterns::MuxLhsOne,
patterns::MuxRhsZero, patterns::MuxRhsOne>(context);
}

void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
Expand Down Expand Up @@ -2256,6 +2256,77 @@ struct FoldResetMux : public mlir::RewritePattern {
};
} // namespace

namespace {
/// This canonicalizer provides the following patterns:
/// reset(reg) = 0 ==> connect(reg, and(reg, x)) ==> reg -> 0
/// reset(reg) = 0 ==> connect(reg, and(x, reg)) ==> reg -> 0
/// reset(reg) = 1 ==> connect(reg, or(reg, x)) ==> reg -> 1
/// reset(reg) = 1 ==> connect(reg, or(x, reg)) ==> reg -> 1
Comment on lines +2261 to +2264
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mega-nit: This syntax is bothering me and took me a while to figure it out. The ==> is an "AND" in the first use and is "is converted to" (or "IMPLIES"?) in the second.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK I can delete these comments, I don't think they're helpful. FWIW ==> is implication, -> is canonicalization.

///
/// Justification: The initial value of a register is indeterminant, which means
/// we are free to choose any initial value when optimizing the circuit. For the
/// AND patterns, if the reset is zero, and we assume the initial value is zero,
/// then the register will always be zero. For the OR patterns, if the reset is
/// one, and we assume the initial value is one, then the register will always
/// be one.
struct RegResetAndOrOfSelf : public mlir::OpRewritePattern<RegResetOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(RegResetOp op,
PatternRewriter &rewriter) const override {
Comment on lines +2274 to +2275
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Supreme nit:

Suggested change
LogicalResult matchAndRewrite(RegResetOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(RegResetOp reg,
PatternRewriter &rewriter) const override {

When reading the later rewrite pattern, I found it clearer to use the specific "reg" instead of the generic "op".

// Do not fold the register away if it is important.
if (hasDontTouch(op.getOperation()) || !AnnotationSet(op).empty() ||
op.isForceable())
return failure();

// This canonicalization only applies when the register holds 1 bit.
auto type = dyn_cast<UIntType>(op.getResult().getType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think SIntType would work, too?

if (!type || type.getWidthOrSentinel() != 1)
return failure();

// This canonicalization only applies when the reset is a constant.
auto reset =
dyn_cast_or_null<ConstantOp>(op.getResetValue().getDefiningOp());
if (!reset)
return failure();
Comment on lines +2286 to +2290
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: please be exact. This is the reset value and not the reset (or exactly, the reset signal).


auto value = reset.getValue();

// Find the one true connect, or bail.
auto connect = getSingleConnectUserOf(op.getResult());
if (!connect)
return failure();

auto *src = connect.getSrc().getDefiningOp();
if (!src)
return failure();

if (value == 0) {
if (auto srcAnd = dyn_cast<AndPrimOp>(src)) {
if (srcAnd.getLhs().getDefiningOp() == op ||
srcAnd.getRhs().getDefiningOp() == op) {
rewriter.eraseOp(connect);
replaceOpAndCopyName(rewriter, op, reset.getResult());
return success();
}
}
}

if (value == 1) {
if (auto srcOr = dyn_cast<OrPrimOp>(src)) {
if (srcOr.getLhs().getDefiningOp() == op ||
srcOr.getRhs().getDefiningOp() == op) {
rewriter.eraseOp(connect);
replaceOpAndCopyName(rewriter, op, reset.getResult());
return success();
}
}
}

return failure();
}
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Challenge 1: Can this be written in using ODS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if that makes sense, since this isn't so much a pattern on the register op, but more, a pattern on the connect driving the register. Maybe I could do it using some new constraints. I took a quick look around and it doesn't look like we do this kind of thing using ODS. I'm interested in suggestions though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, the problem isn't in generating the match, but the replacement. The match would be something like: ConnectOp(RegResetOp(...), and(RegResetOp(...), _) (where the registers are the same). It's that the replacement isn't a replacement of the ConnectOp, but of the RegResetOp. Yeah, I'm not sure exactly how to do this.

} // namespace

static bool isDefinedByOneConstantOp(Value v) {
if (auto c = v.getDefiningOp<ConstantOp>())
return c.getValue().isOne();
Expand All @@ -2279,7 +2350,9 @@ canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {

void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
results
.add<patterns::RegResetWithZeroReset, FoldResetMux, RegResetAndOrOfSelf>(
context);
results.add(canonicalizeRegResetWithOneReset);
results.add(demoteForceableIfUnused<RegResetOp>);
}
Expand Down Expand Up @@ -3156,10 +3229,66 @@ static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
return success();
}

/// This canonicalizer provides the following patterns:
/// connect(reg, and(reg, x)) ==> reg -> 0
/// connect(reg, and(x, reg)) ==> reg -> 0
/// connect(reg, or(reg, x)) ==> reg -> 1
/// connect(reg, or(x, reg)) ==> reg -> 1
///
/// Justification: The initial value of a register is indeterminant, which means
/// We are free to choose any initial value when optimizing the circuit. For the
/// AND patterns, if we assume the initial value is zero, then the register will
/// always be zero. For the OR patterns, if we assume the initial value is one,
/// then the register will always be one.
Comment on lines +3238 to +3242
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great for including this on both rewrite patterns! Providing this as a comment makes it clear not only what the pattern is doing, but why and with what justification in the spec we are using. 💯

static LogicalResult foldRegAndOrOfSelf(RegOp reg, PatternRewriter &rewriter) {
// This canonicalization only applies when the register holds 1 bit.
auto type = dyn_cast<UIntType>(reg.getResult().getType());
if (!type || type.getWidthOrSentinel() != 1)
return failure();

// Find the one true connect, or bail.
auto connect = getSingleConnectUserOf(reg.getResult());
if (!connect)
return failure();

auto *src = connect.getSrc().getDefiningOp();
if (!src)
return failure();

// connect(reg, and(reg, x)) ==> reg -> 0
// connect(reg, and(x, reg)) ==> reg -> 0
if (auto srcAnd = dyn_cast<AndPrimOp>(src)) {
if (srcAnd.getLhs().getDefiningOp() == reg ||
srcAnd.getRhs().getDefiningOp() == reg) {
auto attr = getIntAttr(type, APInt(1, 0));
replaceOpWithNewOpAndCopyName<ConstantOp>(rewriter, reg, type, attr);
rewriter.eraseOp(connect);
return success();
}
}

// connect(reg, or(reg, x)) ==> reg -> 1
// connect(reg, or(x, reg)) ==> reg -> 1
if (auto srcOr = dyn_cast<OrPrimOp>(src)) {
if (srcOr.getLhs().getDefiningOp() == reg ||
srcOr.getRhs().getDefiningOp() == reg) {
auto attr = getIntAttr(type, APInt(1, 1));
replaceOpWithNewOpAndCopyName<ConstantOp>(rewriter, reg, type, attr);
rewriter.eraseOp(connect);
return success();
}
}

return failure();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Challenge 2: Can this be written in terms of ODS?


LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
succeeded(foldHiddenReset(op, rewriter)))
return success();
if (!hasDontTouch(op.getOperation()) && !op.isForceable()) {
if (succeeded(foldHiddenReset(op, rewriter)))
return success();
if (succeeded(foldRegAndOrOfSelf(op, rewriter)))
return success();
}

if (succeeded(demoteForceableIfUnused(op, rewriter)))
return success();
Expand Down
78 changes: 77 additions & 1 deletion test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,11 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
out %out3: !firrtl.uint<1>,
out %out4: !firrtl.uint<4>,
out %out5: !firrtl.uint<1>,
out %out6: !firrtl.uint<1>) {
out %out6: !firrtl.uint<1>,
out %out7: !firrtl.uint<1>,
out %out8: !firrtl.uint<1>,
out %out9: !firrtl.uint<1>,
out %out10: !firrtl.uint<1>) {
// CHECK: firrtl.matchingconnect %out, %in
%0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4>
Expand Down Expand Up @@ -634,6 +638,32 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
// CHECK-NEXT: mux4cell(%[[SEL]],
%17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.matchingconnect %out6, %17 : !firrtl.uint<1>

// mux(cond, 0, x) -> and(~cond, x)
// CHECK: [[V1:%.+]] = firrtl.not %cond
// CHECK-NEXT: [[V2:%.+]] = firrtl.and [[V1]], %val1
// CHECK-NEXT: firrtl.matchingconnect %out7, [[V2]]
%18 = firrtl.mux (%cond, %c0_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %out7, %18 : !firrtl.uint<1>, !firrtl.uint<1>

// mux(cond, 1, x) -> or(cond, x)
// CHECK: [[V:%.+]] = firrtl.or %cond, %val1
// CHECK-NEXT: firrtl.matchingconnect %out8, [[V]]
%19 = firrtl.mux (%cond, %c1_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %out8, %19 : !firrtl.uint<1>, !firrtl.uint<1>

// mux(cond, x, 0) -> and(cond, x)
// CHECK: [[V:%.+]] = firrtl.and %cond, %val1
// CHECK-NEXT: firrtl.matchingconnect %out9, [[V]]
%20 = firrtl.mux (%cond, %val1, %c0_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %out9, %20 : !firrtl.uint<1>, !firrtl.uint<1>

// mux(cond, x, 1) -> or(~cond, x)
// CHECK: [[V1:%.+]] = firrtl.not %cond
// CHECK-NEXT: [[V2:%.+]] = firrtl.or [[V1]], %val1
// CHECK-NEXT: firrtl.matchingconnect %out10, [[V2]]
%21 = firrtl.mux (%cond, %val1, %c1_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %out10, %21 : !firrtl.uint<1>, !firrtl.uint<1>
}

// CHECK-LABEL: firrtl.module @Pad
Expand Down Expand Up @@ -2618,6 +2648,52 @@ firrtl.module @constReg9(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, i
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
}

// Check that a register driven by an and(en, reg) is folded to a constant zero.
// CHECK-LABEL: @constRegAnd
firrtl.module @constRegAnd(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
// CHECK-NOT: firrtl.reg
// CHECK: firrtl.matchingconnect %out, %c0_ui1
%r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1>
%0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
}

// Check that a register driven by an or(en, reg) is folded to a constant one.
// CHECK-LABEL: @constRegOr
firrtl.module @constRegOr(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
// CHECK-NOT: firrtl.reg
// CHECK: firrtl.matchingconnect %out, %c1_ui1
%r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1>
%0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
}

// Check that a regreset driven by an and(en, reg) is folded to a constant zero, when the reset is zero.
// CHECK-LABEL: @constRegResetAnd
firrtl.module @constRegResetAnd(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
// CHECK-NOT: firrtl.reg
// CHECK: firrtl.matchingconnect %out, %c0_ui1
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
%r = firrtl.regreset %clock, %reset, %c0_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
%0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
}

// Check that a regreset driven by an or(en, reg) is folded to a constant one, when the reset is one.
// CHECK-LABEL: @constRegResetOr
firrtl.module @constRegResetOr(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
// CHECK-NOT: firrtl.reg
// CHECK: firrtl.matchingconnect %out, %c1_ui1
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%r = firrtl.regreset %clock, %reset, %c1_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
%0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
}

firrtl.module @BitCast(out %o:!firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>> ) {
%a = firrtl.wire : !firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>>
%b = firrtl.bitcast %a : (!firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>>) -> (!firrtl.uint<3>)
Expand Down
Loading