Skip to content

Commit 2c1f5f7

Browse files
committed
merge all parallel into one forall
1 parent 33f6207 commit 2c1f5f7

File tree

3 files changed

+1091
-13
lines changed

3 files changed

+1091
-13
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "./Tiling.hpp"
1314
#include "mlir/AsmParser/AsmParser.h"
1415
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -262,19 +263,16 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
262263
OpBuilder::InsertionGuard guard(b);
263264
b.setInsertionPoint(currentOp);
264265
// TODO: add split reduction support here
265-
// if (auto partialInterface =
266-
// dyn_cast<PartialReductionOpInterface>(currentOp.getOperation()))
267-
// {
268-
// auto tilingResult = linalg::tileReductionUsingForall(
269-
// b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
270-
// numThreads, tileSizes, std::nullopt);
271-
// if (failed(tilingResult))
272-
// return failure();
273-
// currentOp =
274-
// dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp);
275-
// } else
276-
if (auto tilingInterface =
277-
cast<TilingInterface>(currentOp.getOperation())) {
266+
if (auto partialInterface =
267+
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
268+
auto tilingResult = linalgX::tileAllUsingForall(
269+
b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
270+
numThreads, tileSizes, std::nullopt);
271+
if (failed(tilingResult))
272+
return failure();
273+
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp);
274+
} else if (auto tilingInterface =
275+
cast<TilingInterface>(currentOp.getOperation())) {
278276
auto tilingResult = linalg::tileToForallOpUsingTileSizes(
279277
b, tilingInterface, tileSizes, std::nullopt);
280278
if (failed(tilingResult))

0 commit comments

Comments
 (0)