Skip to content

Commit

Permalink
[COMB] Factor Mux of Mux with common child conditions and common chil…
Browse files Browse the repository at this point in the history
…d values (#4403)
  • Loading branch information
darthscsi authored Dec 5, 2022
1 parent 5e5dfb4 commit 7de51c1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
26 changes: 26 additions & 0 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2381,6 +2381,32 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
return success();
}

// mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getTrueValue() == falseMux.getTrueValue()) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc(trueMux.getLoc(), falseMux.getLoc()), op.getCond(),
trueMux.getFalseValue(), falseMux.getFalseValue());
replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
trueMux.getTrueValue(), subMux);
return success();
}

// mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getFalseValue() == falseMux.getFalseValue()) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc(trueMux.getLoc(), falseMux.getLoc()), op.getCond(),
trueMux.getTrueValue(), falseMux.getTrueValue());
replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
subMux, trueMux.getFalseValue());
return success();
}

// mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
if (foldCommonMuxValue(op, false, rewriter))
return success();
Expand Down
20 changes: 20 additions & 0 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1415,3 +1415,23 @@ hw.module @ArrayConcatFlatten(%a: !hw.array<3xi1>) -> (b: i3) {
%4 = comb.concat %2, %3 : i1, i2
hw.output %4 : i3
}

// CHECK-LABEL: hw.module @MuxSimplify
hw.module @MuxSimplify(%index: i1, %a: i1, %foo_0: i2, %foo_1: i2) -> (r_0: i2, r_1: i2) {
%true = hw.constant true
%c-2_i2 = hw.constant -2 : i2
%c1_i2 = hw.constant 1 : i2
%0 = comb.xor bin %index, %true : i1
%1 = comb.mux bin %0, %c1_i2, %foo_0 : i2
%2 = comb.mux bin %index, %c1_i2, %foo_1 : i2
%3 = comb.mux bin %0, %c-2_i2, %foo_0 : i2
%4 = comb.mux bin %a, %1, %3 : i2
%5 = comb.mux bin %index, %c-2_i2, %foo_1 : i2
%6 = comb.mux bin %a, %2, %5 : i2
hw.output %4, %6 : i2, i2
}
// CHECK: %0 = comb.mux %a, %c1_i2, %c-2_i2 : i2
// CHECK-NEXT: %1 = comb.mux %index, %foo_0, %0 : i2
// CHECK-NEXT: %2 = comb.mux %a, %c1_i2, %c-2_i2 : i2
// CHECK-NEXT: %3 = comb.mux %index, %2, %foo_1 : i2
// CHECK-NEXT: hw.output %1, %3 : i2, i2

0 comments on commit 7de51c1

Please sign in to comment.