|
22 | 22 | #include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
23 | 23 | #include "mlir/Interfaces/TilingInterface.h"
|
24 | 24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
| 25 | +#include "mlir/Transforms/RegionUtils.h" |
25 | 26 |
|
26 | 27 | #include "gc/Transforms/Passes.h"
|
27 | 28 |
|
@@ -256,7 +257,7 @@ class ConsumerFusionAnchor
|
256 | 257 | * Collected insertExtractOp List during walk including targetSliceOp:
|
257 | 258 | * %4 = extract %args2 and %2 = extract %arg1
|
258 | 259 | */
|
259 |
| -FailureOr<std::pair<Value, SmallVector<tensor::ExtractSliceOp>>> |
| 260 | +static FailureOr<std::pair<Value, SmallVector<tensor::ExtractSliceOp>>> |
260 | 261 | getRootSourceOfExtractSliceOp(tensor::ExtractSliceOp targetSliceOp,
|
261 | 262 | int curDepth = 0) {
|
262 | 263 | // control recursive time in avoid of stack overflow
|
@@ -294,6 +295,14 @@ getRootSourceOfExtractSliceOp(tensor::ExtractSliceOp targetSliceOp,
|
294 | 295 |
|
295 | 296 | static FailureOr<tensor::ExtractSliceOp>
|
296 | 297 | getFirstExtractSliceOpOfOperand(OpOperand &operand) {
|
| 298 | + if (auto iterArg = dyn_cast<BlockArgument>(operand.get())) { |
| 299 | + if (auto loop = |
| 300 | + dyn_cast<LoopLikeOpInterface>(iterArg.getOwner()->getParentOp())) { |
| 301 | + return getFirstExtractSliceOpOfOperand(*loop.getTiedLoopInit(iterArg)); |
| 302 | + } |
| 303 | + return failure(); |
| 304 | + } |
| 305 | + |
297 | 306 | Operation *defineOp = operand.get().getDefiningOp();
|
298 | 307 | if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
|
299 | 308 | return sliceOp;
|
@@ -347,6 +356,87 @@ getProducerFusionAnchorFromOpOperand(RewriterBase &rewriter,
|
347 | 356 | return failure();
|
348 | 357 | }
|
349 | 358 |
|
| 359 | +// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g |
| 360 | +// ``` |
| 361 | +// %1 = scf.for |
| 362 | +// %2 = scf.for |
| 363 | +// %3 = scf.for |
| 364 | +// ... |
| 365 | +// %4 = insert |
| 366 | +// yield %4 |
| 367 | +// %5 = insert %3 |
| 368 | +// yield %5 |
| 369 | +// yield %2 |
| 370 | +// ``` |
| 371 | +// @param targetSliceOp: %4 = insert |
| 372 | +// @return Result Value: %1 |
| 373 | +// Collected insertSliceOp List during walk including targetSliceOp: |
| 374 | +// %4 = insert and %5 = insert %3 |
| 375 | +static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>> |
| 376 | +getResultOfTopLevelLoopYieldInsertSliceOp( |
| 377 | + OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0) { |
| 378 | + // control recursive time in avoid of stack overflow |
| 379 | + if (curDepth > MAX_DEPTH) |
| 380 | + return failure(); |
| 381 | + |
| 382 | + SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList; |
| 383 | + candidateSliceOpList.push_back(targetSliceOp); |
| 384 | + Value resultOfLoop; |
| 385 | + if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>( |
| 386 | + targetSliceOp.getOperation())) { |
| 387 | + Value destValue = sliceOp.getDest(); |
| 388 | + auto iterArg = cast<BlockArgument>(destValue); |
| 389 | + auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp()); |
| 390 | + if (!forallOp) |
| 391 | + return failure(); |
| 392 | + resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); |
| 393 | + } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>( |
| 394 | + targetSliceOp.getOperation())) { |
| 395 | + Value resultValue = sliceOp.getResult(); |
| 396 | + for (auto &useOperand : resultValue.getUses()) { |
| 397 | + if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) { |
| 398 | + if (llvm::detail::isPresent(resultOfLoop)) |
| 399 | + return failure(); |
| 400 | + auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp()); |
| 401 | + if (!forOp) |
| 402 | + return failure(); |
| 403 | + resultOfLoop = forOp->getResult(useOperand.getOperandNumber()); |
| 404 | + } |
| 405 | + } |
| 406 | + } |
| 407 | + |
| 408 | + if (!llvm::detail::isPresent(resultOfLoop)) |
| 409 | + return failure(); |
| 410 | + |
| 411 | + while (true) { |
| 412 | + bool walkThroughOuterLoop = false; |
| 413 | + for (auto &useOperand : resultOfLoop.getUses()) { |
| 414 | + if (auto sliceOp = |
| 415 | + dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) { |
| 416 | + auto resultAndSliceOpsPair = |
| 417 | + getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1); |
| 418 | + if (failed(resultAndSliceOpsPair)) |
| 419 | + return failure(); |
| 420 | + candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(), |
| 421 | + (*resultAndSliceOpsPair).second.end()); |
| 422 | + return std::make_pair((*resultAndSliceOpsPair).first, |
| 423 | + candidateSliceOpList); |
| 424 | + } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) { |
| 425 | + // walk through outer loop |
| 426 | + auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp()); |
| 427 | + if (!forOp) |
| 428 | + return failure(); |
| 429 | + resultOfLoop = forOp->getResult(useOperand.getOperandNumber()); |
| 430 | + walkThroughOuterLoop = true; |
| 431 | + break; |
| 432 | + } |
| 433 | + } |
| 434 | + if (!walkThroughOuterLoop) |
| 435 | + break; |
| 436 | + } |
| 437 | + return std::make_pair(resultOfLoop, candidateSliceOpList); |
| 438 | +} |
| 439 | + |
350 | 440 | /**
|
351 | 441 | * Find the untiled Consumer op based on given OpResult of Tiled Op, E.g.
|
352 | 442 | *
|
@@ -380,14 +470,19 @@ getConsumerFusionAnchorFromOpResult(RewriterBase &rewriter,
|
380 | 470 | return failure();
|
381 | 471 | sliceOp =
|
382 | 472 | dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner());
|
| 473 | + } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner())) { |
| 474 | + if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp())) { |
| 475 | + return getConsumerFusionAnchorFromOpResult( |
| 476 | + rewriter, loop->getResult(useOfResult.getOperandNumber())); |
| 477 | + } |
383 | 478 | }
|
384 | 479 | }
|
385 | 480 |
|
386 | 481 | if (!llvm::detail::isPresent(sliceOp))
|
387 | 482 | return failure();
|
388 | 483 |
|
389 | 484 | auto resultAndSliceOpsPair =
|
390 |
| - scfX::getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp); |
| 485 | + getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp); |
391 | 486 | if (failed(resultAndSliceOpsPair))
|
392 | 487 | return failure();
|
393 | 488 |
|
@@ -424,9 +519,8 @@ static Operation *preOpFuseProducerOfOpOperand(
|
424 | 519 | if (failed(candidateSliceOp)) {
|
425 | 520 | return nullptr;
|
426 | 521 | }
|
427 |
| - auto outerLoops = scfX::getOuterLoopsOfSliceOp(*candidateSliceOp); |
428 | 522 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
|
429 |
| - scfX::tileAndFuseProducerOfSlice(rewriter, *candidateSliceOp, outerLoops); |
| 523 | + scfX::tileAndFuseProducerOfSlice(rewriter, *candidateSliceOp); |
430 | 524 |
|
431 | 525 | if (!fusedResult)
|
432 | 526 | return nullptr;
|
@@ -456,8 +550,23 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
|
456 | 550 | std::optional<scf::SCFFuseConsumerOfSliceResult> fusedResult =
|
457 | 551 | scfX::tileAndFuseConsumerOfSlice(rewriter, *candidateSliceOp);
|
458 | 552 | if (fusedResult) {
|
459 |
| - tiledConsumerList.push_back(fusedResult.value().tiledOps[0]); |
460 |
| - rewriter.eraseOp(fusedResult.value().origConsumerOperand->getOwner()); |
| 553 | + auto tiledOp = fusedResult.value().tiledOps[0]; |
| 554 | + tiledConsumerList.push_back(tiledOp); |
| 555 | + auto whileProducerOutOfBlock = |
| 556 | + [&tiledOp](LoopLikeOpInterface loop) -> LogicalResult { |
| 557 | + Block &body = loop->getRegion(0).front(); |
| 558 | + return (tiledOp->getBlock() == &body) ? failure() : success(); |
| 559 | + }; |
| 560 | + SmallVector<LoopLikeOpInterface> outerLoops = |
| 561 | + scfX::getOuterNestLoopsWhile( |
| 562 | + (*candidateSliceOp)->getParentOfType<LoopLikeOpInterface>(), |
| 563 | + whileProducerOutOfBlock); |
| 564 | + // Manually run cse on region which contains top-level loop of candidate |
| 565 | + // slice in avoid of conflict with subsequent `tileAndFuseConsumerOfSlice` |
| 566 | + // get nest loops between next candidate sliceOp and tiled producer. |
| 567 | + auto region = outerLoops.front()->getParentRegion(); |
| 568 | + (void)mlir::eraseUnreachableBlocks(rewriter, {*region}); |
| 569 | + (void)mlir::runRegionDCE(rewriter, {*region}); |
461 | 570 | }
|
462 | 571 | }
|
463 | 572 |
|
|
0 commit comments