@@ -464,30 +464,36 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
464464 currentOp.getNumLoops (), getAsIndexOpFoldResult (b.getContext (), 0 ));
465465 SmallVector<unsigned > reductionDims;
466466 currentOp.getReductionDims (reductionDims);
467+ bool tileOnReduction = false ;
467468 for (auto [d, tile] : llvm::zip (currentDim, currentTileSize)) {
469+ if (llvm::find (reductionDims, d) != reductionDims.end ()) {
470+ tileOnReduction = true ;
471+ }
468472 if (llvm::find (reductionDims, d) != reductionDims.end () &&
469- !dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ()))
473+ !dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ())) {
470474 tileSizes[d] = getAsIndexOpFoldResult (b.getContext (), 0 );
471- else
475+ tileOnReduction = false ;
476+ } else
472477 tileSizes[d] = getAsIndexOpFoldResult (b.getContext (), tile);
473478 }
474479 SmallVector<Range> loopRanges =
475480 cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
476481 OpBuilder::InsertionGuard guard (b);
477482 b.setInsertionPoint (currentOp);
478- if (auto partialInterface =
479- dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ())) {
483+ if (tileOnReduction) {
484+ auto partialInterface =
485+ dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ());
480486 for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
481- if (isConstantIntValue (tile, 0 )) {
487+ if (isConstantIntValue (tile, 0 ) &&
488+ llvm::find (reductionDims, d) != reductionDims.end ()) {
482489 tileSizes[idx] = loopRanges[idx].size ;
483490 }
484491 }
485-
486492 SmallVector<OpFoldResult> newParallelDims;
487493 for (auto i = 0UL ; i < reductionDims.size (); i++) {
488494 newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
489495 }
490- auto tilingResult = linalgX::tileAllUsingForall (
496+ auto tilingResult = linalgX::tileReductionUsingForall (
491497 b, cast<PartialReductionOpInterface>(currentOp.getOperation ()), {},
492498 tileSizes, newParallelDims, std::nullopt );
493499 if (failed (tilingResult) &&
@@ -503,8 +509,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
503509 }
504510 }
505511 }
506- } else if ( auto tilingInterface =
507- cast<TilingInterface>(currentOp.getOperation ())) {
512+ } else {
513+ auto tilingInterface = cast<TilingInterface>(currentOp.getOperation ());
508514 auto tilingResult = linalg::tileToForallOpUsingTileSizes (
509515 b, tilingInterface, tileSizes, std::nullopt );
510516 if (failed (tilingResult))
@@ -597,11 +603,15 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
597603 ? (cfg.NBlock - 1 ) / cfg.innerMostNBlock + 1
598604 : cfg.NBlock ;
599605 // Outer
600- option.nestedTileSizes .emplace_back (SmallVector<size_t >{
601- MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
602- option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForallOp);
603- option.loopDim .emplace_back (
604- SmallVector<size_t >{MDimPos[0 ], NDimPos[0 ], KDimPos[0 ]});
606+ for (auto [tile, dim] :
607+ llvm::zip (SmallVector<size_t >{KParallelBlockSize, MParallelBlockSize,
608+ NParallelBlockSize},
609+ SmallVector<size_t >{KDimPos[0 ], MDimPos[0 ], NDimPos[0 ]})) {
610+ option.nestedTileSizes .emplace_back (SmallVector<size_t >{tile});
611+ option.loopType .emplace_back (
612+ OuterLoopGenerationOption::LoopType::ForallOp);
613+ option.loopDim .emplace_back (SmallVector<size_t >{dim});
614+ }
605615 // Middle
606616 for (auto [tile, dim] :
607617 llvm::zip (SmallVector<size_t >{MOuterBlockSize, NOuterBlockSize,
0 commit comments