File tree Expand file tree Collapse file tree 1 file changed +11
-7
lines changed Expand file tree Collapse file tree 1 file changed +11
-7
lines changed Original file line number Diff line number Diff line change @@ -7507,7 +7507,7 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
7507
7507
auto NT = RankedTensorType::get(
7508
7508
op.getType().getShape(),
7509
7509
cast<RankedTensorType>(v.getType()).getElementType());
7510
- Value reshaped = nullptr;
7510
+ stablehlo::ReshapeOp reshaped = nullptr;
7511
7511
for (auto u : v.getUsers()) {
7512
7512
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
7513
7513
if (!re)
@@ -7518,12 +7518,16 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
7518
7518
break;
7519
7519
}
7520
7520
if (!reshaped) {
7521
- reshaped = rewriter.create<stablehlo::ReshapeOp>(
7522
- op.getLoc(),
7523
- RankedTensorType::get(
7524
- op.getType().getShape(),
7525
- cast<RankedTensorType>(v.getType()).getElementType()),
7526
- v);
7521
+ reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
7522
+ } else {
7523
+ if (reshaped->getBlock() == op->getBlock()) {
7524
+ if (op->isBeforeInBlock(reshaped)) {
7525
+ rewriter.modifyOpInPlace(reshaped,
7526
+ [&]() { reshaped->moveBefore(op); });
7527
+ }
7528
+ } else {
7529
+ reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
7530
+ }
7527
7531
}
7528
7532
ops.push_back(reshaped);
7529
7533
}
You can’t perform that action at this time.
0 commit comments