Skip to content

Commit 53e5c4b

Browse files
committed
sync to latest upstream PR
1 parent 6150833 commit 53e5c4b

File tree

3 files changed

+712
-517
lines changed

3 files changed

+712
-517
lines changed

lib/gc/Transforms/AnyTilableFusion.cpp

+115-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
2424
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25+
#include "mlir/Transforms/RegionUtils.h"
2526

2627
#include "gc/Transforms/Passes.h"
2728

@@ -256,7 +257,7 @@ class ConsumerFusionAnchor
256257
* Collected insertExtractOp List during walk including targetSliceOp:
257258
* %4 = extract %args2 and %2 = extract %arg1
258259
*/
259-
FailureOr<std::pair<Value, SmallVector<tensor::ExtractSliceOp>>>
260+
static FailureOr<std::pair<Value, SmallVector<tensor::ExtractSliceOp>>>
260261
getRootSourceOfExtractSliceOp(tensor::ExtractSliceOp targetSliceOp,
261262
int curDepth = 0) {
262263
// Bound the recursion depth to avoid stack overflow.
@@ -294,6 +295,14 @@ getRootSourceOfExtractSliceOp(tensor::ExtractSliceOp targetSliceOp,
294295

295296
static FailureOr<tensor::ExtractSliceOp>
296297
getFirstExtractSliceOpOfOperand(OpOperand &operand) {
298+
if (auto iterArg = dyn_cast<BlockArgument>(operand.get())) {
299+
if (auto loop =
300+
dyn_cast<LoopLikeOpInterface>(iterArg.getOwner()->getParentOp())) {
301+
return getFirstExtractSliceOpOfOperand(*loop.getTiedLoopInit(iterArg));
302+
}
303+
return failure();
304+
}
305+
297306
Operation *defineOp = operand.get().getDefiningOp();
298307
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
299308
return sliceOp;
@@ -347,6 +356,87 @@ getProducerFusionAnchorFromOpOperand(RewriterBase &rewriter,
347356
return failure();
348357
}
349358

359+
// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
360+
// ```
361+
// %1 = scf.for
362+
// %2 = scf.for
363+
// %3 = scf.for
364+
// ...
365+
// %4 = insert
366+
// yield %4
367+
// %5 = insert %3
368+
// yield %5
369+
// yield %2
370+
// ```
371+
// @param targetSliceOp: %4 = insert
372+
// @return Result Value: %1
373+
// Collected insertSliceOp List during walk including targetSliceOp:
374+
// %4 = insert and %5 = insert %3
375+
static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
376+
getResultOfTopLevelLoopYieldInsertSliceOp(
377+
OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0) {
378+
// control recursive time in avoid of stack overflow
379+
if (curDepth > MAX_DEPTH)
380+
return failure();
381+
382+
SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
383+
candidateSliceOpList.push_back(targetSliceOp);
384+
Value resultOfLoop;
385+
if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
386+
targetSliceOp.getOperation())) {
387+
Value destValue = sliceOp.getDest();
388+
auto iterArg = cast<BlockArgument>(destValue);
389+
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
390+
if (!forallOp)
391+
return failure();
392+
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
393+
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
394+
targetSliceOp.getOperation())) {
395+
Value resultValue = sliceOp.getResult();
396+
for (auto &useOperand : resultValue.getUses()) {
397+
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
398+
if (llvm::detail::isPresent(resultOfLoop))
399+
return failure();
400+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
401+
if (!forOp)
402+
return failure();
403+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
404+
}
405+
}
406+
}
407+
408+
if (!llvm::detail::isPresent(resultOfLoop))
409+
return failure();
410+
411+
while (true) {
412+
bool walkThroughOuterLoop = false;
413+
for (auto &useOperand : resultOfLoop.getUses()) {
414+
if (auto sliceOp =
415+
dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
416+
auto resultAndSliceOpsPair =
417+
getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
418+
if (failed(resultAndSliceOpsPair))
419+
return failure();
420+
candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
421+
(*resultAndSliceOpsPair).second.end());
422+
return std::make_pair((*resultAndSliceOpsPair).first,
423+
candidateSliceOpList);
424+
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
425+
// walk through outer loop
426+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
427+
if (!forOp)
428+
return failure();
429+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
430+
walkThroughOuterLoop = true;
431+
break;
432+
}
433+
}
434+
if (!walkThroughOuterLoop)
435+
break;
436+
}
437+
return std::make_pair(resultOfLoop, candidateSliceOpList);
438+
}
439+
350440
/**
351441
* Find the untiled Consumer op based on given OpResult of Tiled Op, E.g.
352442
*
@@ -380,14 +470,19 @@ getConsumerFusionAnchorFromOpResult(RewriterBase &rewriter,
380470
return failure();
381471
sliceOp =
382472
dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner());
473+
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner())) {
474+
if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp())) {
475+
return getConsumerFusionAnchorFromOpResult(
476+
rewriter, loop->getResult(useOfResult.getOperandNumber()));
477+
}
383478
}
384479
}
385480

386481
if (!llvm::detail::isPresent(sliceOp))
387482
return failure();
388483

389484
auto resultAndSliceOpsPair =
390-
scfX::getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp);
485+
getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp);
391486
if (failed(resultAndSliceOpsPair))
392487
return failure();
393488

@@ -424,9 +519,8 @@ static Operation *preOpFuseProducerOfOpOperand(
424519
if (failed(candidateSliceOp)) {
425520
return nullptr;
426521
}
427-
auto outerLoops = scfX::getOuterLoopsOfSliceOp(*candidateSliceOp);
428522
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
429-
scfX::tileAndFuseProducerOfSlice(rewriter, *candidateSliceOp, outerLoops);
523+
scfX::tileAndFuseProducerOfSlice(rewriter, *candidateSliceOp);
430524

431525
if (!fusedResult)
432526
return nullptr;
@@ -456,8 +550,23 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
456550
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedResult =
457551
scfX::tileAndFuseConsumerOfSlice(rewriter, *candidateSliceOp);
458552
if (fusedResult) {
459-
tiledConsumerList.push_back(fusedResult.value().tiledOps[0]);
460-
rewriter.eraseOp(fusedResult.value().origConsumerOperand->getOwner());
553+
auto tiledOp = fusedResult.value().tiledOps[0];
554+
tiledConsumerList.push_back(tiledOp);
555+
auto whileProducerOutOfBlock =
556+
[&tiledOp](LoopLikeOpInterface loop) -> LogicalResult {
557+
Block &body = loop->getRegion(0).front();
558+
return (tiledOp->getBlock() == &body) ? failure() : success();
559+
};
560+
SmallVector<LoopLikeOpInterface> outerLoops =
561+
scfX::getOuterNestLoopsWhile(
562+
(*candidateSliceOp)->getParentOfType<LoopLikeOpInterface>(),
563+
whileProducerOutOfBlock);
564+
// Manually run DCE on the region which contains the top-level loop of the
565+
// candidate slice, to avoid conflicts when a subsequent
566+
// `tileAndFuseConsumerOfSlice` gets the nest loops between the next candidate sliceOp and the tiled producer.
567+
auto region = outerLoops.front()->getParentRegion();
568+
(void)mlir::eraseUnreachableBlocks(rewriter, {*region});
569+
(void)mlir::runRegionDCE(rewriter, {*region});
461570
}
462571
}
463572

0 commit comments

Comments
 (0)