Skip to content

Commit a205731

Browse files
committed
deprecated tileToForallUsingTileSize
1 parent 3c5567f commit a205731

File tree

3 files changed

+28
-24
lines changed

3 files changed

+28
-24
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ static FailureOr<DtypeLegalizeResult>
168168
matmulDtypeLegalize(RewriterBase &rewriter, Operation *op,
169169
bool needCopyInit = true, bool needFurtherFuse = false) {
170170
linalg::LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
171-
Location loc = linalgOp->getLoc();
172-
DtypeLegalizeResult result;
173171
if (!linalgOp)
174172
return failure();
175173

174+
Location loc = linalgOp->getLoc();
175+
DtypeLegalizeResult result;
176+
176177
if (needToLegalizeDtype(linalgOp)) {
177178
rewriter.setInsertionPoint(linalgOp);
178179
IRMapping mapping;
@@ -449,15 +450,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
449450
}
450451
}
451452
} else {
452-
TilingInterface tilingInterface =
453-
cast<TilingInterface>(currentOp.getOperation());
454-
FailureOr<linalg::ForallTilingResult> tilingResult =
455-
linalg::tileToForallOpUsingTileSizes(b, tilingInterface, tileSizes,
456-
std::nullopt);
453+
scf::SCFTilingOptions tileOption;
454+
tileOption.setTileSizes(tileSizes);
455+
tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
456+
FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF(
457+
b, cast<TilingInterface>(currentOp.getOperation()), tileOption);
457458
if (failed(tilingResult))
458459
return failure();
459-
b.replaceOp(currentOp, tilingResult->tileOp);
460-
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOp);
460+
b.replaceOp(currentOp, tilingResult->replacements);
461+
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps.back());
461462
}
462463
}
463464
}
@@ -499,8 +500,8 @@ NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
499500
for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) {
500501
ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) *
501502
KBlock]
502-
BSlice2 = BSlice[0, om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok +
503-
1) * KBlock]
503+
BSlice2 = BSlice[0, ok * KBlock: (ok + 1) * KBlock, on * NBlock: (on +
504+
1) * NBlock]
504505
CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock:
505506
(on + 1) * NBlock] (init with 0 when ok == 0)
506507
MNumInnerBlock = MBlock / iim_block_
@@ -539,11 +540,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
539540
size_t NFirstDim = *getConstantIntValue(loopRange[NDimPos[0]].size);
540541

541542
size_t KParallelBlockSize =
542-
KDimPos.size() > 1
543-
? llvm::divideCeil(KFirstDim, cfg.KThreads)
544-
: llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock),
545-
cfg.KThreads) *
546-
cfg.KBlock;
543+
cfg.KThreads == 1
544+
? 0
545+
: (KDimPos.size() > 1
546+
? llvm::divideCeil(KFirstDim, cfg.KThreads)
547+
: llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock),
548+
cfg.KThreads) *
549+
cfg.KBlock);
547550
size_t MParallelBlockSize =
548551
MDimPos.size() > 1
549552
? llvm::divideCeil(MFirstDim, cfg.MThreads)

lib/gc/Transforms/Pipeline.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
5252
// todo: layout propagation pass
5353
// todo: tensor constant propagation pass
5454
// linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
55-
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
56-
// Fine-grain fusion pass
5755
pm.addNestedPass<func::FuncOp>(createDeepTileContractionNamedOp());
56+
57+
// Fine-grain fusion pass
58+
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
5859
// todo: fine-grain fusion pass
5960
// todo: lower linalg to arith/math on virtual vector pass
6061

test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x409
77
%cst_0 = arith.constant 0.000000e+00 : f32
88
%0 = tensor.empty() : tensor<4096x4096xf32>
99
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
10-
// CHECK: scf.forall {{.*}} (4) {{.*}} (tensor<4096x4096xf32>) {
10+
// CHECK: scf.forall {{.*}} (0) to (4096) step (1024) {{.*}} (tensor<4096x4096xf32>) {
1111
// CHECK: tensor.extract_slice {{.*}} [1024, 4096] [1, 1]
12-
// CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<1024x4096xf32>)
12+
// CHECK: scf.forall {{.*}} (0) to (4096) step (2048) {{.*}} (tensor<1024x4096xf32>)
1313
// CHECK: tensor.extract_slice {{.*}} [1024, 2048] [1, 1]
1414
// CHECK: scf.for
1515
// CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1]
@@ -43,9 +43,9 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12
4343
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
4444
// CHECK-NOT: linalg.fill
4545
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
46-
// CHECK: scf.forall {{.*}} (16) {{.*}} (tensor<128x128x32x32xbf16>)
46+
// CHECK: scf.forall {{.*}} (0) to (128) step (8) {{.*}} (tensor<128x128x32x32xbf16>)
4747
// CHECK: tensor.extract_slice {{.*}} [8, 128, 32, 32] [1, 1, 1, 1]
48-
// CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<8x128x32x32xbf16>)
48+
// CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<8x128x32x32xbf16>)
4949
// CHECK: tensor.extract_slice {{.*}} [8, 64, 32, 32] [1, 1, 1, 1]
5050
// CHECK: scf.for
5151
// CHECK: tensor.extract_slice {{.*}} [8, 8, 32, 32] [1, 1, 1, 1]
@@ -80,9 +80,9 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
8080
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
8181
// CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<2x1x1x4096x4096xf32>)
8282
// CHECK: tensor.extract_slice {{.*}} [1, 1, 1, 4096, 4096] [1, 1, 1, 1, 1]
83-
// CHECK: scf.forall {{.*}} (16) {{.*}} (tensor<4096x4096xf32>)
83+
// CHECK: scf.forall {{.*}} (0) to (4096) step (256) {{.*}} (tensor<4096x4096xf32>)
8484
// CHECK: tensor.extract_slice {{.*}} [256, 4096] [1, 1]
85-
// CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<256x4096xf32>)
85+
// CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<256x4096xf32>)
8686
// CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1]
8787
// CHECK: scf.for
8888
// CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1]

0 commit comments

Comments
 (0)