
CompileTime: fix reshape elementwise #680


Open · wants to merge 20 commits into main
247 changes: 216 additions & 31 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -3746,6 +3746,78 @@ struct BroadcastReshape final
}
};

// Returns {legal, before}: whether the relative order of `reshaped` and `op`
// could be determined cheaply, and if so whether `reshaped` comes before `op`.
std::pair<bool, bool> fastDoesADominateB(Operation *reshaped, Operation *op,
Value v) {
assert(reshaped);
assert(op);
size_t limit = 200;
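// If the block's op order is known-valid, isBeforeInBlock is O(1); otherwise
// fall back to bounded forward scans so we never trigger an O(N) order
// recomputation.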
if (reshaped->getBlock() == op->getBlock()) {

// TODO: we could additionally require
// op->getBlock()->getOperations().size() <= limit, but size() here is O(N).
if (op->getBlock()->isOpOrderValid()) {
return std::make_pair(true, reshaped->isBeforeInBlock(op));
}
if (v)
if (auto pred = v.getDefiningOp()) {
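// Scan forward from the op defining v: whichever of `reshaped` / `op`
// appears first determines the order.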
bool seenReshape = false;
bool seenUser = false;
Operation *cur = pred->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// TODO: this could be an isAncestor query (cur->isAncestor(reshaped)),
// but that would cost compile time.
if (cur == reshaped)
seenReshape = true;
// if (cur->isAncestor(op))
if (cur == op) {
seenUser = true;
}
if (seenReshape || seenUser)
break;
cur = cur->getNextNode();
}
if (seenReshape && !seenUser) {
return std::make_pair(true, true);
}
if (!seenReshape && seenUser) {
return std::make_pair(true, false);
}
}
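// Scan forward from `reshaped`: finding `op` proves `reshaped` comes first;
// hitting the end of the block without finding it proves the opposite.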
{
bool seenUser = false;
Operation *cur = reshaped->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// if (cur->isAncestor(op))
if (cur == op) {
seenUser = true;
return std::make_pair(true, true);
}
cur = cur->getNextNode();
}
if (!cur) {
return std::make_pair(true, false);
}
}
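// Symmetric scan forward from `op`: finding `reshaped` proves `op` comes
// first; hitting the block end proves `reshaped` does.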
{
bool seenReshape = false;
Operation *cur = op->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// if (cur->isAncestor(reshaped))
if (cur == reshaped) {
seenReshape = true;
return std::make_pair(true, false);
}
cur = cur->getNextNode();
}
if (!cur) {
return std::make_pair(true, true);
}
}
}
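// Ran out of budget without resolving the order; report the query as not
// legal so callers do not rely on a result.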
return std::make_pair(false, false);
}

struct BroadcastToReshape final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
@@ -3787,9 +3859,39 @@ struct BroadcastToReshape final
// replace with reshape
if (op.getType() == op.getOperand().getType())
rewriter.replaceOp(op, op.getOperand());
else
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
op.getOperand());
else {
auto NT = op.getType();
stablehlo::ReshapeOp reshaped = nullptr;
bool before = false;
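// Look for an existing reshape of the operand to the broadcast's result type
// so we can reuse it instead of creating a duplicate.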
for (auto u : op.getOperand().getUsers()) {
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
if (!re)
continue;
if (re.getType() != NT)
continue;
auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand());
if (!legal)
continue;
before = before2;
reshaped = re;
break;
}
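// No reusable reshape found: create one right after the operand's defining
// op (or at the start of the block for a block argument).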
if (!reshaped) {
if (auto rop = op.getOperand().getDefiningOp()) {
rewriter.setInsertionPointAfter(rop);
} else if (auto ba = dyn_cast<BlockArgument>(op.getOperand())) {
rewriter.setInsertionPointToStart(ba.getOwner());
}
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
op.getOperand());
} else {
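// Reuse the existing reshape; if it currently sits after the broadcast,
// hoist it so it dominates the broadcast's users once we replace the op.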
if (before) {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveBefore(op); });
}
rewriter.replaceOp(op, reshaped);
}
}
return success();
}
};
@@ -7495,29 +7597,96 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {

LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
PatternRewriter &rewriter) const override {
if (op.getType() == op.getOperand().getType()) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
auto elem = op.getOperand().getDefiningOp();
if (!elem)
return failure();

if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers()))
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
return failure();

if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
bool singleUse = true;
SmallVector<stablehlo::ReshapeOp> toReplace;
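// Collect every user of the elementwise op that is a reshape to the target
// type; if those are its only users we can rewrite it in place.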
for (auto U : elem->getUsers()) {
if (auto re = dyn_cast<stablehlo::ReshapeOp>(U)) {
if (re.getType() == op.getType()) {
toReplace.push_back(re);
continue;
}
}
singleUse = false;
break;
}

if (onlySingleUser && !singleUse)
return failure();

if (singleUse) {
auto pt = rewriter.getInsertionPoint();
pt--;
rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt);
}

SmallVector<Value> ops;
for (auto v : elem->getOperands()) {
ops.push_back(rewriter.create<stablehlo::ReshapeOp>(
op.getLoc(),
RankedTensorType::get(
op.getType().getShape(),
cast<RankedTensorType>(v.getType()).getElementType()),
v));
auto NT = RankedTensorType::get(
op.getType().getShape(),
cast<RankedTensorType>(v.getType()).getElementType());
stablehlo::ReshapeOp reshaped = nullptr;
bool before = false;
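// Prefer reusing an existing reshape of this operand to the target shape
// over creating a new one.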
for (auto u : v.getUsers()) {
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
if (!re)
continue;
if (re.getType() != NT)
continue;
auto &&[legal, before2] = fastDoesADominateB(elem, re, v);
if (!legal) {
continue;
}
before = before2;
reshaped = re;
break;
}
if (!reshaped) {
if (auto rop = v.getDefiningOp()) {
rewriter.setInsertionPointAfter(rop);
} else if (auto ba = dyn_cast<BlockArgument>(v)) {
rewriter.setInsertionPointToStart(ba.getOwner());
}
reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
} else {
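// Reusing an existing reshape; if it currently comes after `elem`, move it
// up so it dominates the op that will consume it.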
if (before) {
if (auto rop = v.getDefiningOp()) {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveAfter(rop); });
} else {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveBefore(elem); });
}
}
}
ops.push_back(reshaped);
}

if (singleUse) {
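// All users were reshapes to the target type: retype the elementwise op in
// place and forward it to each of them.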
rewriter.modifyOpInPlace(elem, [&]() {
elem->setOperands(ops);
elem->getResult(0).setType(op.getType());
});
for (auto re : toReplace)
rewriter.replaceOp(re, elem);
} else {
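// Mixed users: build a new elementwise op at the reshaped type and only
// replace the matching reshape users.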
rewriter.setInsertionPointAfter(elem);
auto newOp = rewriter.create(
elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
TypeRange(op.getType()), elem->getAttrs(), {}, {});
for (auto re : toReplace)
rewriter.replaceOp(re, newOp);
}
auto newOp = rewriter.create(
elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
TypeRange(op.getType()), elem->getAttrs(), {}, {});
rewriter.replaceOp(op, newOp);
return success();
}
};
@@ -7720,12 +7889,15 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
continue;
if (nop->getBlock() != op->getBlock())
continue;
if (nop->isBeforeInBlock(op)) {
rewriter.replaceOp(op, nop);
return success();
} else {
rewriter.replaceOp(nop, op);
return success();
auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr);
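// Only rewrite when the bounded query could determine the order; this avoids
// isBeforeInBlock's potential O(N) op-order recomputation.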
if (legal) {
if (before) {
rewriter.replaceOp(op, nop);
return success();
} else {
rewriter.replaceOp(nop, op);
return success();
}
}
}
return failure();
@@ -12421,16 +12593,19 @@ struct CommonCompareExpressionRewrite
continue;

if (userCompareOp.getLhs() == lhs && userCompareOp.getRhs() == rhs) {
if (user->isBeforeInBlock(op)) {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
op.getLoc(), userCompareOp.getResult());
rewriter.replaceOp(op, negatedCondition);
return success();
} else {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
userCompareOp.getLoc(), op.getResult());
rewriter.replaceOp(user, negatedCondition);
return success();
auto &&[legal, before] = fastDoesADominateB(user, op, opOperand);
if (legal) {
if (before) {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
op.getLoc(), userCompareOp.getResult());
rewriter.replaceOp(op, negatedCondition);
return success();
} else {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
userCompareOp.getLoc(), op.getResult());
rewriter.replaceOp(user, negatedCondition);
return success();
}
}
}
}
@@ -14328,6 +14503,16 @@ struct EnzymeHLOOptPass
GreedyRewriteConfig config;
config.maxIterations = max_iterations;
config.useTopDownTraversal = top_down;
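// Ensure every block's operation order is valid before running the patterns,
// so ordering queries during the rewrite stay cheap.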
getOperation()->walk([](Operation *op) {
for (auto &region : op->getRegions()) {
for (auto &blk : region.getBlocks()) {

if (!blk.isOpOrderValid()) {
blk.recomputeOpOrder();
}
}
}
});
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
signalPassFailure();
6 changes: 3 additions & 3 deletions test/lit_tests/raising/affine_to_stablehlo13.mlir
@@ -94,9 +94,9 @@ module {
}
}
// CHECK: func.func private @repeat_iv_raised(%arg0: tensor<10xi64>, %arg1: tensor<10xi64>, %arg2: tensor<10x10xf64>, %arg3: tensor<10xf64>) -> (tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>) {
// CHECK-NEXT: %0 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %1 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64>
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64>
// CHECK-NEXT: %3 = "stablehlo.gather"(%arg2, %2) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<10x10xf64>, tensor<10x2xi64>) -> tensor<10xf64>
// CHECK-NEXT: return %arg0, %arg1, %arg2, %3 : tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>
// CHECK-NEXT: }
6 changes: 3 additions & 3 deletions test/lit_tests/raising/affine_to_stablehlo15.mlir
@@ -25,9 +25,9 @@ module {
// CHECK-NEXT: %1 = stablehlo.dynamic_slice %arg0, %iterArg, %c_1, sizes = [1, 10] : (tensor<4x10xf32>, tensor<i64>, tensor<i64>) -> tensor<1x10xf32>
// CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<1x10xf32>) -> tensor<10xf32>
// CHECK-NEXT: %3 = arith.mulf %2, %2 : tensor<10xf32>
// CHECK-NEXT: %4 = stablehlo.multiply %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32>
// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %5, %4, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor<i64>, tensor<i64>) -> tensor<16x10xf32>
// CHECK-NEXT: %4 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32>
// CHECK-NEXT: %5 = stablehlo.multiply %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %4, %5, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor<i64>, tensor<i64>) -> tensor<16x10xf32>
// CHECK-NEXT: %7 = stablehlo.add %iterArg, %c : tensor<i64>
// CHECK-NEXT: stablehlo.return %7, %6 : tensor<i64>, tensor<16x10xf32>
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions test/lit_tests/raising/affine_to_stablehlo_pforred.mlir
@@ -56,10 +56,10 @@ module @"reactant_loop!" attributes {mhlo.num_partitions = 1 : i64, mhlo.num_rep
// CHECK-NEXT: %15 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
// CHECK-NEXT: %16 = stablehlo.reduce(%14 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
// CHECK-NEXT: %17 = arith.addf %5, %15 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
// CHECK-NEXT: %18 = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
// CHECK-NEXT: %19 = stablehlo.reshape %18 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %20 = stablehlo.dynamic_update_slice %arg1, %19, %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x35x59xf64>
// CHECK-NEXT: %21 = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %21, %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x34x59xf64>
// CHECK-NEXT: return %22, %20, %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64>
// CHECK-NEXT: %[[i21:.+]] = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %[[i18:.+]] = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
// CHECK-NEXT: %[[i19:.+]] = stablehlo.reshape %[[i18]] : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %[[i20:.+]] = stablehlo.dynamic_update_slice %arg1, %[[i19]], %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x35x59xf64>
// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %[[i21]], %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x34x59xf64>
// CHECK-NEXT: return %22, %[[i20]], %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64>
// CHECK-NEXT: }
13 changes: 6 additions & 7 deletions test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir
@@ -119,11 +119,10 @@ func.func private @"##call__Z29gpu__compute_barotropic_mode_16CompilerMetadataI1
// CHECK-NEXT: %50 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
// CHECK-NEXT: %51 = stablehlo.reduce(%49 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
// CHECK-NEXT: %52 = arith.addf %33, %50 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
// CHECK-NEXT: %53 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
// CHECK-NEXT: %54 = stablehlo.reshape %53 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %55 = stablehlo.dynamic_update_slice %arg1, %54, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: %56 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %56, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: return %57, %55, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64>
// CHECK-NEXT: %53 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %54 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
// CHECK-NEXT: %55 = stablehlo.reshape %54 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %56 = stablehlo.dynamic_update_slice %arg1, %55, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %53, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: return %57, %56, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64>
// CHECK-NEXT: }

6 changes: 3 additions & 3 deletions test/lit_tests/reshapeelementwise.mlir
@@ -14,9 +14,9 @@ module {
}

// CHECK: func.func @main(%arg0: tensor<100x200x300xbf16>, %arg1: tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> {
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-NEXT: %2 = stablehlo.subtract %0, %1 : tensor<20000x300xbf16>
// CHECK-DAG: %[[a0:.+]] = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-DAG: %[[a1:.+]] = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-NEXT: %2 = stablehlo.subtract %[[a0]], %[[a1]] : tensor<20000x300xbf16>
// CHECK-NEXT: return %2 : tensor<20000x300xbf16>
// CHECK-NEXT: }
// CHECK: func.func @main2(%arg0: tensor<100x200x300xbf16>) -> tensor<20000x300xf32> {