|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include "./Tiling.hpp" |
13 | 14 | #include "mlir/AsmParser/AsmParser.h"
|
14 | 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
15 | 16 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
@@ -262,19 +263,16 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
|
262 | 263 | OpBuilder::InsertionGuard guard(b);
|
263 | 264 | b.setInsertionPoint(currentOp);
|
264 | 265 | // TODO: add split reduction support here
|
265 |
| - // if (auto partialInterface = |
266 |
| - // dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) |
267 |
| - // { |
268 |
| - // auto tilingResult = linalg::tileReductionUsingForall( |
269 |
| - // b, cast<PartialReductionOpInterface>(currentOp.getOperation()), |
270 |
| - // numThreads, tileSizes, std::nullopt); |
271 |
| - // if (failed(tilingResult)) |
272 |
| - // return failure(); |
273 |
| - // currentOp = |
274 |
| - // dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp); |
275 |
| - // } else |
276 |
| - if (auto tilingInterface = |
277 |
| - cast<TilingInterface>(currentOp.getOperation())) { |
| 266 | + if (auto partialInterface = |
| 267 | + dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) { |
| 268 | + auto tilingResult = linalgX::tileAllUsingForall( |
| 269 | + b, cast<PartialReductionOpInterface>(currentOp.getOperation()), |
| 270 | + numThreads, tileSizes, std::nullopt); |
| 271 | + if (failed(tilingResult)) |
| 272 | + return failure(); |
| 273 | + currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp); |
| 274 | + } else if (auto tilingInterface = |
| 275 | + cast<TilingInterface>(currentOp.getOperation())) { |
278 | 276 | auto tilingResult = linalg::tileToForallOpUsingTileSizes(
|
279 | 277 | b, tilingInterface, tileSizes, std::nullopt);
|
280 | 278 | if (failed(tilingResult))
|
|
0 commit comments