Skip to content

Commit 5c7f765

Browse files
committed
Also broadcast2reshape
1 parent db618e3 commit 5c7f765

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3784,9 +3784,33 @@ struct BroadcastToReshape final
37843784
return failure();
37853785
}
37863786

3787-
// replace with reshape
3788-
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3789-
op.getOperand());
3787+
auto NT = op.getType();
3788+
stablehlo::ReshapeOp reshaped = nullptr;
3789+
for (auto u : op.getOperand().getUsers()) {
3790+
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
3791+
if (!re)
3792+
continue;
3793+
if (re.getType() != NT)
3794+
continue;
3795+
reshaped = re;
3796+
break;
3797+
}
3798+
if (!reshaped) {
3799+
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3800+
op.getOperand());
3801+
} else {
3802+
if (reshaped->getBlock() == op->getBlock()) {
3803+
if (op->isBeforeInBlock(reshaped)) {
3804+
rewriter.modifyOpInPlace(reshaped,
3805+
[&]() { reshaped->moveBefore(op); });
3806+
}
3807+
rewriter.replaceOp(op, reshaped);
3808+
} else {
3809+
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3810+
op.getOperand());
3811+
}
3812+
}
3813+
37903814
return success();
37913815
}
37923816
};

0 commit comments

Comments
 (0)