From 7de51c1b36779b5dfc6a47d77a0b36121e627700 Mon Sep 17 00:00:00 2001 From: Andrew Lenharth Date: Mon, 5 Dec 2022 15:10:37 -0600 Subject: [PATCH] [COMB] Factor Mux of Mux with common child conditions and common child values (#4403) --- lib/Dialect/Comb/CombFolds.cpp | 26 +++++++++++++++++++++++++ test/Dialect/Comb/canonicalization.mlir | 20 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 42ecdb82a199..c2bbe78506be 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -2381,6 +2381,32 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, return success(); } + // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c)) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && + trueMux.getTrueValue() == falseMux.getTrueValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc(trueMux.getLoc(), falseMux.getLoc()), op.getCond(), + trueMux.getFalseValue(), falseMux.getFalseValue()); + replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), + trueMux.getTrueValue(), subMux); + return success(); + } + + // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && + trueMux.getFalseValue() == falseMux.getFalseValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc(trueMux.getLoc(), falseMux.getLoc()), op.getCond(), + trueMux.getTrueValue(), falseMux.getTrueValue()); + replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), + subMux, trueMux.getFalseValue()); + return success(); + } + // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a if (foldCommonMuxValue(op, false, rewriter)) return success(); diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index a78db7157c50..68ef52cacd9f 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1415,3 +1415,23 @@ hw.module @ArrayConcatFlatten(%a: !hw.array<3xi1>) -> (b: i3) { %4 = comb.concat %2, %3 : i1, i2 hw.output %4 : i3 } + +// CHECK-LABEL: hw.module @MuxSimplify +hw.module @MuxSimplify(%index: i1, %a: i1, %foo_0: i2, %foo_1: i2) -> (r_0: i2, r_1: i2) { + %true = hw.constant true + %c-2_i2 = hw.constant -2 : i2 + %c1_i2 = hw.constant 1 : i2 + %0 = comb.xor bin %index, %true : i1 + %1 = comb.mux bin %0, %c1_i2, %foo_0 : i2 + %2 = comb.mux bin %index, %c1_i2, %foo_1 : i2 + %3 = comb.mux bin %0, %c-2_i2, %foo_0 : i2 + %4 = comb.mux bin %a, %1, %3 : i2 + %5 = comb.mux bin %index, %c-2_i2, %foo_1 : i2 + %6 = comb.mux bin %a, %2, %5 : i2 + hw.output %4, %6 : i2, i2 +} +// CHECK: %0 = comb.mux %a, %c1_i2, %c-2_i2 : i2 +// CHECK-NEXT: %1 = comb.mux %index, %foo_0, %0 : i2 +// CHECK-NEXT: %2 = comb.mux %a, %c1_i2, %c-2_i2 : i2 +// CHECK-NEXT: %3 = comb.mux %index, %2, %foo_1 : i2 +// CHECK-NEXT: hw.output %1, %3 : i2, i2