diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index c2bbe78506be..389de1f8a3cf 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -2407,6 +2407,21 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, return success(); } + // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && + trueMux.getTrueValue() == falseMux.getTrueValue() && + trueMux.getFalseValue() == falseMux.getFalseValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc( + {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}), + op.getCond(), trueMux.getCond(), falseMux.getCond()); + replaceOpWithNewOpAndCopyName( + rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue()); + return success(); + } + // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a if (foldCommonMuxValue(op, false, rewriter)) return success(); diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index 68ef52cacd9f..924735b9c2e7 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1417,7 +1417,7 @@ hw.module @ArrayConcatFlatten(%a: !hw.array<3xi1>) -> (b: i3) { } // CHECK-LABEL: hw.module @MuxSimplify -hw.module @MuxSimplify(%index: i1, %a: i1, %foo_0: i2, %foo_1: i2) -> (r_0: i2, r_1: i2) { +hw.module @MuxSimplify(%index: i1, %a: i1, %foo_0: i2, %foo_1: i2) -> (r_0: i2, r_1: i2, r_2 : i2) { %true = hw.constant true %c-2_i2 = hw.constant -2 : i2 %c1_i2 = hw.constant 1 : i2 @@ -1428,10 +1428,19 @@ hw.module @MuxSimplify(%index: i1, %a: i1, %foo_0: i2, %foo_1: i2) -> (r_0: i2, %4 = comb.mux bin %a, %1, %3 : i2 %5 = comb.mux bin %index, %c-2_i2, %foo_1 : i2 %6 = comb.mux bin %a, %2, %5 : i2 - hw.output %4, %6 : i2, i2 + + %7 = comb.mux bin %a, %foo_0, %foo_1 : i2 + %8 = comb.mux bin %index, %foo_0, %foo_1 : i2 + %9 = comb.xor %a, %index : i1 + %10 = comb.mux bin %9, %7, %8 : i2 + + hw.output %4, %6, %10 : i2, i2, i2 } // CHECK: %0 = comb.mux %a, %c1_i2, %c-2_i2 : i2 // CHECK-NEXT: %1 = comb.mux %index, %foo_0, %0 : i2 // CHECK-NEXT: %2 = comb.mux %a, %c1_i2, %c-2_i2 : i2 // CHECK-NEXT: %3 = comb.mux %index, %2, %foo_1 : i2 -// CHECK-NEXT: hw.output %1, %3 : i2, i2 +// CHECK-NEXT: %4 = comb.xor %a, %index : i1 +// CHECK-NEXT: %5 = comb.mux %4, %a, %index : i1 +// CHECK-NEXT: %6 = comb.mux %5, %foo_0, %foo_1 : i2 +// CHECK-NEXT: hw.output %1, %3, %6