@@ -168,11 +168,12 @@ static FailureOr<DtypeLegalizeResult>
168
168
matmulDtypeLegalize (RewriterBase &rewriter, Operation *op,
169
169
bool needCopyInit = true , bool needFurtherFuse = false ) {
170
170
linalg::LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
171
- Location loc = linalgOp->getLoc ();
172
- DtypeLegalizeResult result;
173
171
if (!linalgOp)
174
172
return failure ();
175
173
174
+ Location loc = linalgOp->getLoc ();
175
+ DtypeLegalizeResult result;
176
+
176
177
if (needToLegalizeDtype (linalgOp)) {
177
178
rewriter.setInsertionPoint (linalgOp);
178
179
IRMapping mapping;
@@ -449,15 +450,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
449
450
}
450
451
}
451
452
} else {
452
- TilingInterface tilingInterface =
453
- cast<TilingInterface>(currentOp. getOperation () );
454
- FailureOr<linalg::ForallTilingResult> tilingResult =
455
- linalg::tileToForallOpUsingTileSizes (b, tilingInterface, tileSizes,
456
- std::nullopt );
453
+ scf::SCFTilingOptions tileOption;
454
+ tileOption. setTileSizes (tileSizes );
455
+ tileOption. setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
456
+ FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF (
457
+ b, cast<TilingInterface>(currentOp. getOperation ()), tileOption );
457
458
if (failed (tilingResult))
458
459
return failure ();
459
- b.replaceOp (currentOp, tilingResult->tileOp );
460
- currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOp );
460
+ b.replaceOp (currentOp, tilingResult->replacements );
461
+ currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps . back () );
461
462
}
462
463
}
463
464
}
@@ -499,8 +500,8 @@ NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
499
500
for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) {
500
501
ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) *
501
502
KBlock]
502
- BSlice2 = BSlice[0, om * MBlock : (om + 1) * MBlock, ok * KBlock : (ok +
503
- 1) * KBlock ]
503
+ BSlice2 = BSlice[0, ok * KBlock : (ok + 1) * KBlock, on * NBlock : (on +
504
+ 1) * NBlock ]
504
505
CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock:
505
506
(on + 1) * NBlock] (init with 0 when ok == 0)
506
507
MNumInnerBlock = MBlock / iim_block_
@@ -539,11 +540,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
539
540
size_t NFirstDim = *getConstantIntValue (loopRange[NDimPos[0 ]].size );
540
541
541
542
size_t KParallelBlockSize =
542
- KDimPos.size () > 1
543
- ? llvm::divideCeil (KFirstDim, cfg.KThreads )
544
- : llvm::divideCeil (llvm::divideCeil (KFirstDim, cfg.KBlock ),
545
- cfg.KThreads ) *
546
- cfg.KBlock ;
543
+ cfg.KThreads == 1
544
+ ? 0
545
+ : (KDimPos.size () > 1
546
+ ? llvm::divideCeil (KFirstDim, cfg.KThreads )
547
+ : llvm::divideCeil (llvm::divideCeil (KFirstDim, cfg.KBlock ),
548
+ cfg.KThreads ) *
549
+ cfg.KBlock );
547
550
size_t MParallelBlockSize =
548
551
MDimPos.size () > 1
549
552
? llvm::divideCeil (MFirstDim, cfg.MThreads )
0 commit comments