@@ -168,11 +168,12 @@ static FailureOr<DtypeLegalizeResult>
168168matmulDtypeLegalize (RewriterBase &rewriter, Operation *op,
169169 bool needCopyInit = true , bool needFurtherFuse = false ) {
170170 linalg::LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
171- Location loc = linalgOp->getLoc ();
172- DtypeLegalizeResult result;
173171 if (!linalgOp)
174172 return failure ();
175173
174+ Location loc = linalgOp->getLoc ();
175+ DtypeLegalizeResult result;
176+
176177 if (needToLegalizeDtype (linalgOp)) {
177178 rewriter.setInsertionPoint (linalgOp);
178179 IRMapping mapping;
@@ -449,15 +450,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
449450 }
450451 }
451452 } else {
452- TilingInterface tilingInterface =
453- cast<TilingInterface>(currentOp. getOperation () );
454- FailureOr<linalg::ForallTilingResult> tilingResult =
455- linalg::tileToForallOpUsingTileSizes (b, tilingInterface, tileSizes,
456- std:: nullopt );
453+ scf::SCFTilingOptions tileOption;
454+ tileOption. setTileSizes (tileSizes );
455+ tileOption. setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
456+ FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF (
457+ b, cast<TilingInterface>(currentOp. getOperation ()), tileOption );
457458 if (failed (tilingResult))
458459 return failure ();
459- b.replaceOp (currentOp, tilingResult->tileOp );
460- currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOp );
460+ b.replaceOp (currentOp, tilingResult->replacements );
461+ currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps . back () );
461462 }
462463 }
463464 }
@@ -499,8 +500,8 @@ NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
499500 for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) {
500501 ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) *
501502KBlock]
502- BSlice2 = BSlice[0, om * MBlock : (om + 1) * MBlock, ok * KBlock : (ok +
503- 1) * KBlock ]
503+ BSlice2 = BSlice[0, ok * KBlock : (ok + 1) * KBlock, on * NBlock : (on +
504+ 1) * NBlock ]
504505 CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock:
505506(on + 1) * NBlock] (init with 0 when ok == 0)
506507 MNumInnerBlock = MBlock / iim_block_
@@ -539,11 +540,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
539540 size_t NFirstDim = *getConstantIntValue (loopRange[NDimPos[0 ]].size );
540541
541542 size_t KParallelBlockSize =
542- KDimPos.size () > 1
543- ? llvm::divideCeil (KFirstDim, cfg.KThreads )
544- : llvm::divideCeil (llvm::divideCeil (KFirstDim, cfg.KBlock ),
545- cfg.KThreads ) *
546- cfg.KBlock ;
543+ cfg.KThreads == 1
544+ ? 0
545+ : (KDimPos.size () > 1
546+ ? llvm::divideCeil (KFirstDim, cfg.KThreads )
547+ : llvm::divideCeil (llvm::divideCeil (KFirstDim, cfg.KBlock ),
548+ cfg.KThreads ) *
549+ cfg.KBlock );
547550 size_t MParallelBlockSize =
548551 MDimPos.size () > 1
549552 ? llvm::divideCeil (MFirstDim, cfg.MThreads )
0 commit comments