diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index d48ad98ea..5913d9ad2 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -3746,6 +3746,78 @@ struct BroadcastReshape final
   }
 };
 
+// Returns {legal, before}: whether the query could be decided, and if so,
+// whether `reshaped` comes before `op`.
+std::pair<bool, bool> fastDoesADominateB(Operation *reshaped, Operation *op,
+                                         Value v) {
+  assert(reshaped);
+  assert(op);
+  size_t limit = 200;
+  if (reshaped->getBlock() == op->getBlock()) {
+
+    // TODO we could check the following, if size() weren't O(N) =/
+    // op->getBlock()->getOperations().size() <= limit
+    if (op->getBlock()->isOpOrderValid()) {
+      return std::make_pair(true, reshaped->isBeforeInBlock(op));
+    }
+    if (v)
+      if (auto pred = v.getDefiningOp()) {
+        bool seenReshape = false;
+        bool seenUser = false;
+        Operation *cur = pred->getNextNode();
+        for (int i = 0; cur && i < limit; i++) {
+          // TODO we could make these isAncestor queries, at some compile-
+          // time cost: if (cur->isAncestor(reshaped))
+          if (cur == reshaped)
+            seenReshape = true;
+          // if (cur->isAncestor(op))
+          if (cur == op) {
+            seenUser = true;
+          }
+          if (seenReshape || seenUser)
+            break;
+          cur = cur->getNextNode();
+        }
+        if (seenReshape && !seenUser) {
+          return std::make_pair(true, true);
+        }
+        if (!seenReshape && seenUser) {
+          return std::make_pair(true, false);
+        }
+      }
+    {
+      bool seenUser = false;
+      Operation *cur = reshaped->getNextNode();
+      for (int i = 0; cur && i < limit; i++) {
+        // if (cur->isAncestor(op))
+        if (cur == op) {
+          seenUser = true;
+          return std::make_pair(true, true);
+        }
+        cur = cur->getNextNode();
+      }
+      if (!cur) {
+        return std::make_pair(true, false);
+      }
+    }
+    {
+      bool seenReshape = false;
+      Operation *cur = op->getNextNode();
+      for (int i = 0; cur && i < limit; i++) {
+        // if (cur->isAncestor(reshaped))
+        if (cur == reshaped) {
+          seenReshape = true;
+          return std::make_pair(true, false);
+        }
+        cur = cur->getNextNode();
+      }
+      if (!cur) {
+        return std::make_pair(true, true);
+      }
+    }
+  }
+  return std::make_pair(false, false);
+}
+
 struct BroadcastToReshape final
     : OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -3787,9 +3859,39 @@ struct BroadcastToReshape final
     // replace with reshape
     if (op.getType() == op.getOperand().getType())
       rewriter.replaceOp(op, op.getOperand());
-    else
-      rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
-                                                        op.getOperand());
+    else {
+      auto NT = op.getType();
+      stablehlo::ReshapeOp reshaped = nullptr;
+      bool before = false;
+      for (auto u : op.getOperand().getUsers()) {
+        auto re = dyn_cast<stablehlo::ReshapeOp>(u);
+        if (!re)
+          continue;
+        if (re.getType() != NT)
+          continue;
+        auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand());
+        if (!legal)
+          continue;
+        before = before2;
+        reshaped = re;
+        break;
+      }
+      if (!reshaped) {
+        if (auto rop = op.getOperand().getDefiningOp()) {
+          rewriter.setInsertionPointAfter(rop);
+        } else if (auto ba = dyn_cast<BlockArgument>(op.getOperand())) {
+          rewriter.setInsertionPointToStart(ba.getOwner());
+        }
+        rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
+                                                          op.getOperand());
+      } else {
+        if (before) {
+          rewriter.modifyOpInPlace(reshaped,
+                                   [&]() { reshaped->moveBefore(op); });
+        }
+        rewriter.replaceOp(op, reshaped);
+      }
+    }
     return success();
   }
 };
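Reviewer note: `fastDoesADominateB` deliberately avoids calling `isBeforeInBlock` when the block's cached order is invalid, since that call would trigger an O(N) order recomputation inside the greedy driver; instead it spends a fixed scan budget walking the op list and reports "undecidable" when the budget runs out. A minimal standalone sketch of that fallback, where `Node` stands in for `mlir::Operation` and the budget of 200 mirrors `limit` (illustrative only, not the MLIR API):

```cpp
#include <cstddef>
#include <iostream>
#include <utility>

struct Node {
  Node *next = nullptr;
};

// Returns {legal, aBeforeB}. Walking forward from `a`, meeting `b` proves
// a < b; hitting the end of the list proves b < a (both nodes are assumed
// to live in the same list). If the budget runs out, give up.
std::pair<bool, bool> boundedOrderCheck(const Node *a, const Node *b,
                                        std::size_t limit = 200) {
  const Node *cur = a->next;
  for (std::size_t i = 0; cur && i < limit; ++i) {
    if (cur == b)
      return {true, true};
    cur = cur->next;
  }
  if (!cur)
    return {true, false};
  return {false, false}; // undecided: caller must treat the pair as unordered
}

int main() {
  Node n0, n1, n2;
  n0.next = &n1;
  n1.next = &n2;
  auto [legal, before] = boundedOrderCheck(&n0, &n2);
  std::cout << legal << " " << before << "\n"; // prints "1 1": n0 precedes n2
}
```

When the result comes back `legal == false`, the patterns simply skip the candidate, trading a missed rewrite for bounded compile time.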
@@ -7495,29 +7597,96 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
   LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getType() == op.getOperand().getType()) {
+      rewriter.replaceOp(op, op.getOperand());
+      return success();
+    }
     auto elem = op.getOperand().getDefiningOp();
     if (!elem)
       return failure();
-    if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers()))
+    if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
       return failure();
-    if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
+    bool singleUse = true;
+    SmallVector<stablehlo::ReshapeOp> toReplace;
+    for (auto U : elem->getUsers()) {
+      if (auto re = dyn_cast<stablehlo::ReshapeOp>(U)) {
+        if (re.getType() == op.getType()) {
+          toReplace.push_back(re);
+          continue;
+        }
+      }
+      singleUse = false;
+      break;
+    }
+
+    if (onlySingleUser && !singleUse)
       return failure();
+    if (singleUse) {
+      auto pt = rewriter.getInsertionPoint();
+      pt--;
+      rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt);
+    }
+
     SmallVector<Value> ops;
     for (auto v : elem->getOperands()) {
-      ops.push_back(rewriter.create<stablehlo::ReshapeOp>(
-          op.getLoc(),
-          RankedTensorType::get(
-              op.getType().getShape(),
-              cast<RankedTensorType>(v.getType()).getElementType()),
-          v));
+      auto NT = RankedTensorType::get(
+          op.getType().getShape(),
+          cast<RankedTensorType>(v.getType()).getElementType());
+      stablehlo::ReshapeOp reshaped = nullptr;
+      bool before;
+      for (auto u : v.getUsers()) {
+        auto re = dyn_cast<stablehlo::ReshapeOp>(u);
+        if (!re)
+          continue;
+        if (re.getType() != NT)
+          continue;
+        auto &&[legal, before2] = fastDoesADominateB(elem, re, v);
+        if (!legal) {
+          continue;
+        }
+        before = before2;
+        reshaped = re;
+        break;
+      }
+      if (!reshaped) {
+        if (auto rop = v.getDefiningOp()) {
+          rewriter.setInsertionPointAfter(rop);
+        } else if (auto ba = dyn_cast<BlockArgument>(v)) {
+          rewriter.setInsertionPointToStart(ba.getOwner());
+        }
+        reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
+      } else {
+        if (before) {
+          if (auto rop = v.getDefiningOp()) {
+            rewriter.modifyOpInPlace(reshaped,
+                                     [&]() { reshaped->moveAfter(rop); });
+          } else {
+            rewriter.modifyOpInPlace(reshaped,
+                                     [&]() { reshaped->moveBefore(elem); });
+          }
+        }
+      }
+      ops.push_back(reshaped);
+    }
+
+    if (singleUse) {
+      rewriter.modifyOpInPlace(elem, [&]() {
+        elem->setOperands(ops);
+        elem->getResult(0).setType(op.getType());
+      });
+      for (auto re : toReplace)
+        rewriter.replaceOp(re, elem);
+    } else {
+      rewriter.setInsertionPointAfter(elem);
+      auto newOp = rewriter.create(
+          elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
+          TypeRange(op.getType()), elem->getAttrs(), {}, {});
+      for (auto re : toReplace)
+        rewriter.replaceOp(re, newOp);
     }
-    auto newOp = rewriter.create(
-        elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
-        TypeRange(op.getType()), elem->getAttrs(), {}, {});
-    rewriter.replaceOp(op, newOp);
     return success();
   }
 };
@@ -7720,12 +7889,15 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
         continue;
       if (nop->getBlock() != op->getBlock())
         continue;
-      if (nop->isBeforeInBlock(op)) {
-        rewriter.replaceOp(op, nop);
-        return success();
-      } else {
-        rewriter.replaceOp(nop, op);
-        return success();
+      auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr);
+      if (legal) {
+        if (before) {
+          rewriter.replaceOp(op, nop);
+          return success();
+        } else {
+          rewriter.replaceOp(nop, op);
+          return success();
+        }
       }
     }
     return failure();
@@ -12421,16 +12593,19 @@ struct CommonCompareExpressionRewrite
         continue;
 
       if (userCompareOp.getLhs() == lhs && userCompareOp.getRhs() == rhs) {
-        if (user->isBeforeInBlock(op)) {
-          auto negatedCondition = rewriter.create<stablehlo::NotOp>(
-              op.getLoc(), userCompareOp.getResult());
-          rewriter.replaceOp(op, negatedCondition);
-          return success();
-        } else {
-          auto negatedCondition = rewriter.create<stablehlo::NotOp>(
-              userCompareOp.getLoc(), op.getResult());
-          rewriter.replaceOp(user, negatedCondition);
-          return success();
+        auto &&[legal, before] = fastDoesADominateB(user, op, opOperand);
+        if (legal) {
+          if (before) {
+            auto negatedCondition = rewriter.create<stablehlo::NotOp>(
+                op.getLoc(), userCompareOp.getResult());
+            rewriter.replaceOp(op, negatedCondition);
+            return success();
+          } else {
+            auto negatedCondition = rewriter.create<stablehlo::NotOp>(
+                userCompareOp.getLoc(), op.getResult());
+            rewriter.replaceOp(user, negatedCondition);
+            return success();
+          }
         }
       }
     }
@@ -14328,6 +14503,16 @@ struct EnzymeHLOOptPass
     GreedyRewriteConfig config;
     config.maxIterations = max_iterations;
     config.useTopDownTraversal = top_down;
+    getOperation()->walk([](Operation *op) {
+      for (auto &region : op->getRegions()) {
+        for (auto &blk : region.getBlocks()) {
+
+          if (!blk.isOpOrderValid()) {
+            blk.recomputeOpOrder();
+          }
+        }
+      }
+    });
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                             config))) {
       signalPassFailure();
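The walk added to `EnzymeHLOOptPass` pre-validates every block's op order once, so the cheap `isOpOrderValid()` fast path in `fastDoesADominateB` is the common case on entry. A toy model of the order-cache contract this relies on (a sketch of the idea, not the actual `mlir::Block` implementation):

```cpp
#include <cstddef>
#include <vector>

struct Op {
  std::size_t orderIndex = 0; // meaningful only while the cache is valid
};

struct Block {
  std::vector<Op *> ops;
  bool orderValid = false;

  void recomputeOpOrder() {
    std::size_t i = 0;
    for (Op *op : ops)
      op->orderIndex = i++;
    orderValid = true;
  }

  // Mirrors isBeforeInBlock: O(1) while cached, O(N) on a cache miss.
  bool isBefore(const Op *a, const Op *b) {
    if (!orderValid)
      recomputeOpOrder();
    return a->orderIndex < b->orderIndex;
  }

  void push(Op *op) {
    ops.push_back(op);
    orderValid = false; // every mutation invalidates the cached order
  }
};

int main() {
  Op a, b;
  Block blk;
  blk.push(&a);
  blk.push(&b);
  return blk.isBefore(&a, &b) ? 0 : 1; // returns 0: a precedes b
}
```

Each insertion or move performed during rewriting re-invalidates the cache, which is why the patterns still need the bounded-scan fallback mid-run rather than calling `isBeforeInBlock` unconditionally.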
diff --git a/test/lit_tests/raising/affine_to_stablehlo13.mlir b/test/lit_tests/raising/affine_to_stablehlo13.mlir
index 5d9aa9bf1..d442dd316 100644
--- a/test/lit_tests/raising/affine_to_stablehlo13.mlir
+++ b/test/lit_tests/raising/affine_to_stablehlo13.mlir
@@ -94,9 +94,9 @@ module {
   }
 }
 // CHECK: func.func private @repeat_iv_raised(%arg0: tensor<10xi64>, %arg1: tensor<10xi64>, %arg2: tensor<10x10xf64>, %arg3: tensor<10xf64>) -> (tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>) {
-// CHECK-NEXT: %0 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64>
-// CHECK-NEXT: %1 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64>
-// CHECK-NEXT: %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64>
+// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64>
+// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64>
+// CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64>
 // CHECK-NEXT: %3 = "stablehlo.gather"(%arg2, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<10x10xf64>, tensor<10x2xi64>) -> tensor<10xf64>
 // CHECK-NEXT: return %arg0, %arg1, %arg2, %3 : tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>
 // CHECK-NEXT: }
diff --git a/test/lit_tests/raising/affine_to_stablehlo15.mlir b/test/lit_tests/raising/affine_to_stablehlo15.mlir
index f116fe23f..f0db5de42 100644
--- a/test/lit_tests/raising/affine_to_stablehlo15.mlir
+++ b/test/lit_tests/raising/affine_to_stablehlo15.mlir
@@ -25,9 +25,9 @@ module {
 // CHECK-NEXT: %1 = stablehlo.dynamic_slice %arg0, %iterArg, %c_1, sizes = [1, 10] : (tensor<4x10xf32>, tensor<i64>, tensor<i64>) -> tensor<1x10xf32>
 // CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<1x10xf32>) -> tensor<10xf32>
 // CHECK-NEXT: %3 = arith.mulf %2, %2 : tensor<10xf32>
-// CHECK-NEXT: %4 = stablehlo.multiply %iterArg, %c_0 : tensor<i64>
-// CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32>
-// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %5, %4, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor<i64>, tensor<i64>) -> tensor<16x10xf32>
+// CHECK-NEXT: %4 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32>
+// CHECK-NEXT: %5 = stablehlo.multiply %iterArg, %c_0 : tensor<i64>
+// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %4, %5, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor<i64>, tensor<i64>) -> tensor<16x10xf32>
 // CHECK-NEXT: %7 = stablehlo.add %iterArg, %c : tensor<i64>
 // CHECK-NEXT: stablehlo.return %7, %6 : tensor<i64>, tensor<16x10xf32>
 // CHECK-NEXT: }
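The reordered CHECK lines in these tests are the visible effect of the reuse-or-create policy above: before materializing a fresh `stablehlo.reshape`, the pattern scans the operand's users for an existing reshape of the target type that can legally be placed to dominate the new use, and only creates one (right after the operand's definition) when none qualifies. Schematically, with hypothetical `Expr`, `canDominate`, and `create` stand-ins rather than real MLIR API:

```cpp
#include <functional>
#include <vector>

struct Expr {
  int type = 0;
  std::vector<Expr *> users;
};

// Reuse an existing user with the desired type when ordering permits,
// otherwise fall back to the supplied factory.
Expr *reuseOrCreate(Expr &operand, int desiredType,
                    const std::function<bool(Expr *)> &canDominate,
                    const std::function<Expr *()> &create) {
  for (Expr *u : operand.users)
    if (u->type == desiredType && canDominate(u))
      return u;    // an equivalent op already exists: share it
  return create(); // otherwise materialize one near the definition
}

int main() {
  Expr operand;
  Expr existing{/*type=*/7, {}};
  operand.users.push_back(&existing);
  Expr *got = reuseOrCreate(
      operand, 7, [](Expr *) { return true; },
      []() -> Expr * { return nullptr; }); // factory not reached here
  return got == &existing ? 0 : 1; // returns 0: the existing op is reused
}
```

Because a shared reshape migrates next to its operand's definition, SSA value numbering shifts; the tests change ordering and captures, not semantics.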
diff --git a/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir b/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir
index 6dcad5808..b3756bca8 100644
--- a/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir
+++ b/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir
@@ -56,10 +56,10 @@ module @"reactant_loop!" attributes {mhlo.num_partitions = 1 : i64, mhlo.num_rep
 // CHECK-NEXT: %15 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
 // CHECK-NEXT: %16 = stablehlo.reduce(%14 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
 // CHECK-NEXT: %17 = arith.addf %5, %15 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
-// CHECK-NEXT: %18 = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
-// CHECK-NEXT: %19 = stablehlo.reshape %18 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
-// CHECK-NEXT: %20 = stablehlo.dynamic_update_slice %arg1, %19, %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x35x59xf64>
-// CHECK-NEXT: %21 = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
-// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %21, %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x34x59xf64>
-// CHECK-NEXT: return %22, %20, %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64>
+// CHECK-NEXT: %[[i21:.+]] = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
+// CHECK-NEXT: %[[i18:.+]] = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
+// CHECK-NEXT: %[[i19:.+]] = stablehlo.reshape %[[i18]] : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
+// CHECK-NEXT: %[[i20:.+]] = stablehlo.dynamic_update_slice %arg1, %[[i19]], %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x35x59xf64>
+// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %[[i21]], %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x34x59xf64>
+// CHECK-NEXT: return %22, %[[i20]], %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64>
 // CHECK-NEXT: }
diff --git a/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir b/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir
index bdcbfc0eb..51fb3333d 100644
--- a/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir
+++ b/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir
@@ -119,11 +119,10 @@ func.func private @"##call__Z29gpu__compute_barotropic_mode_16CompilerMetadataI1
 // CHECK-NEXT: %50 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
 // CHECK-NEXT: %51 = stablehlo.reduce(%49 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
 // CHECK-NEXT: %52 = arith.addf %33, %50 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
-// CHECK-NEXT: %53 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
-// CHECK-NEXT: %54 = stablehlo.reshape %53 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
-// CHECK-NEXT: %55 = stablehlo.dynamic_update_slice %arg1, %54, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
-// CHECK-NEXT: %56 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
-// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %56, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
-// CHECK-NEXT: return %57, %55, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64>
+// CHECK-NEXT: %53 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
+// CHECK-NEXT: %54 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
+// CHECK-NEXT: %55 = stablehlo.reshape %54 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
+// CHECK-NEXT: %56 = stablehlo.dynamic_update_slice %arg1, %55, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
+// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %53, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
+// CHECK-NEXT: return %57, %56, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64>
 // CHECK-NEXT: }
-
diff --git a/test/lit_tests/reshapeelementwise.mlir b/test/lit_tests/reshapeelementwise.mlir
index 46fffe346..851dd59a2 100644
--- a/test/lit_tests/reshapeelementwise.mlir
+++ b/test/lit_tests/reshapeelementwise.mlir
@@ -14,9 +14,9 @@ module {
 }
 
 // CHECK: func.func @main(%arg0: tensor<100x200x300xbf16>, %arg1: tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> {
-// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
-// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
-// CHECK-NEXT: %2 = stablehlo.subtract %0, %1 : tensor<20000x300xbf16>
+// CHECK-DAG: %[[a0:.+]] = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
+// CHECK-DAG: %[[a1:.+]] = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
+// CHECK-NEXT: %2 = stablehlo.subtract %[[a0]], %[[a1]] : tensor<20000x300xbf16>
 // CHECK-NEXT: return %2 : tensor<20000x300xbf16>
 // CHECK-NEXT: }
 // CHECK: func.func @main2(%arg0: tensor<100x200x300xbf16>) -> tensor<20000x300xf32> {
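One last note on `reshapeelementwise.mlir`: both operands of the subtract are block arguments, so each operand reshape is created via `setInsertionPointToStart`, and creating two ops at the start of a block in sequence leaves the later-created one first. The relative order of the two reshapes is therefore an implementation detail, which is presumably what the switch to `CHECK-DAG` accommodates. A toy illustration of the front-insertion reversal:

```cpp
// Creating A then B, each "at the start of the block", yields B before A.
#include <deque>
#include <iostream>
#include <string>

int main() {
  std::deque<std::string> block;
  block.push_front("reshape %arg0"); // created first
  block.push_front("reshape %arg1"); // created second, but now listed first
  for (const auto &op : block)
    std::cout << op << "\n"; // prints "reshape %arg1" then "reshape %arg0"
}
```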