File tree 1 file changed +27
-3
lines changed 1 file changed +27
-3
lines changed Original file line number Diff line number Diff line change @@ -3784,9 +3784,33 @@ struct BroadcastToReshape final
3784
3784
return failure();
3785
3785
}
3786
3786
3787
- // replace with reshape
3788
- rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3789
- op.getOperand());
3787
+ auto NT = op.getType();
3788
+ stablehlo::ReshapeOp reshaped = nullptr;
3789
+ for (auto u : op.getOperand().getUsers()) {
3790
+ auto re = dyn_cast<stablehlo::ReshapeOp>(u);
3791
+ if (!re)
3792
+ continue;
3793
+ if (re.getType() != NT)
3794
+ continue;
3795
+ reshaped = re;
3796
+ break;
3797
+ }
3798
+ if (!reshaped) {
3799
+ rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3800
+ op.getOperand());
3801
+ } else {
3802
+ if (reshaped->getBlock() == op->getBlock()) {
3803
+ if (op->isBeforeInBlock(reshaped)) {
3804
+ rewriter.modifyOpInPlace(reshaped,
3805
+ [&]() { reshaped->moveBefore(op); });
3806
+ }
3807
+ rewriter.replaceOp(op, reshaped);
3808
+ } else {
3809
+ rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
3810
+ op.getOperand());
3811
+ }
3812
+ }
3813
+
3790
3814
return success();
3791
3815
}
3792
3816
};
You can’t perform that action at this time.
0 commit comments