Skip to content

Commit db618e3

Browse files
committed
fix
1 parent 0fbfb8c commit db618e3

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7507,7 +7507,7 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
75077507
auto NT = RankedTensorType::get(
75087508
op.getType().getShape(),
75097509
cast<RankedTensorType>(v.getType()).getElementType());
7510-
Value reshaped = nullptr;
7510+
stablehlo::ReshapeOp reshaped = nullptr;
75117511
for (auto u : v.getUsers()) {
75127512
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
75137513
if (!re)
@@ -7518,12 +7518,16 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
75187518
break;
75197519
}
75207520
if (!reshaped) {
7521-
reshaped = rewriter.create<stablehlo::ReshapeOp>(
7522-
op.getLoc(),
7523-
RankedTensorType::get(
7524-
op.getType().getShape(),
7525-
cast<RankedTensorType>(v.getType()).getElementType()),
7526-
v);
7521+
reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
7522+
} else {
7523+
if (reshaped->getBlock() == op->getBlock()) {
7524+
if (op->isBeforeInBlock(reshaped)) {
7525+
rewriter.modifyOpInPlace(reshaped,
7526+
[&]() { reshaped->moveBefore(op); });
7527+
}
7528+
} else {
7529+
reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
7530+
}
75277531
}
75287532
ops.push_back(reshaped);
75297533
}

0 commit comments

Comments
 (0)