Skip to content

Commit da852fd

Browse files
committed
support partial reduction
1 parent ed0d7e6 commit da852fd

File tree

6 files changed

+137
-83
lines changed

6 files changed

+137
-83
lines changed

include/gc/Dialect/Arith/Utils/EasyBuild.h

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
//===- EasyBuild.h - Easy Arith IR Builder utilities ------------*- C++ -*-===//
1+
//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// This header file defines the easy-build utilities for arith dialects. It
10-
// provides the utility functions, classes and operators to make it easir to
11-
// program arith dialect operations in C++
12-
//
13-
//===----------------------------------------------------------------------===//
14-
158
#ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H
169
#define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H
1710
#include "gc/IR/EasyBuild.h"
@@ -28,12 +21,8 @@ namespace impl {
2821

2922
template <std::size_t size> struct ToFloatType {};
3023

31-
template <> struct ToFloatType<4> {
32-
using type = Float32Type;
33-
};
34-
template <> struct ToFloatType<8> {
35-
using type = Float64Type;
36-
};
24+
template <> struct ToFloatType<4> { using type = Float32Type; };
25+
template <> struct ToFloatType<8> { using type = Float64Type; };
3726

3827
inline Type getElementType(Value v) {
3928
auto type = v.getType();

include/gc/IR/EasyBuild.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===//
1+
//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// This header file defines the easy-build utilities core data structures for
10-
// building IR.
11-
//
12-
//===----------------------------------------------------------------------===//
13-
148
#ifndef MLIR_IR_EASYBUILD_H
159
#define MLIR_IR_EASYBUILD_H
1610
#include "mlir/IR/Builders.h"

include/gc/IR/EasyBuildSCF.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
//===- EasyBuildSCF.h - Easy IR Builder for general control flow *- C++ -*-===//
1+
//===-- EasyBuildSCF.h - DESC -----------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
89
//
910
// This header file defines the helper classes, functions and macros to help to
1011
// build general structured control flow. Developers can use the utilities in

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
//===-- DeepTileContractionNamedOp.cpp - DESC -------------------*- C++ -*-===//
2-
//
2+
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
6+
//
77
//===----------------------------------------------------------------------===//
88

99
#include "./Tiling.hpp"
@@ -273,9 +273,19 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
273273
b.setInsertionPoint(currentOp);
274274
if (auto partialInterface =
275275
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
276+
for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
277+
if (isConstantIntValue(tile, 0)) {
278+
tileSizes[idx] = loopRanges[idx].size;
279+
}
280+
}
281+
282+
SmallVector<OpFoldResult> newParallelDims;
283+
for (auto i = 0UL; i < reductionDims.size(); i++) {
284+
newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
285+
}
276286
auto tilingResult = linalgX::tileAllUsingForall(
277-
b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
278-
numThreads, tileSizes, std::nullopt);
287+
b, cast<PartialReductionOpInterface>(currentOp.getOperation()), {},
288+
tileSizes, newParallelDims, std::nullopt);
279289
if (failed(tilingResult))
280290
return failure();
281291
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp);

lib/gc/Transforms/Tiling.cpp

Lines changed: 110 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -794,11 +794,10 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
794794
return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
795795
}
796796

797-
FailureOr<linalg::ForallReductionTilingResult>
798-
tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
799-
ArrayRef<OpFoldResult> numThreads,
800-
ArrayRef<OpFoldResult> tileSizes,
801-
std::optional<ArrayAttr> mapping) {
797+
FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
798+
RewriterBase &b, PartialReductionOpInterface op,
799+
ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
800+
ArrayRef<OpFoldResult> newParallelDims, std::optional<ArrayAttr> mapping) {
802801
Location loc = op.getLoc();
803802
OpBuilder::InsertionGuard g(b);
804803

@@ -834,6 +833,24 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
834833
if (iteratorType == utils::IteratorType::reduction)
835834
redDims.push_back(idx);
836835
}
836+
837+
SmallVector<OpFoldResult> numThreads(threadNums.begin(), threadNums.end());
838+
if (numThreads.empty()) {
839+
SmallVector<Range> loopRanges = tilingInterfaceOp.getIterationDomain(b);
840+
unsigned nLoops = loopRanges.size();
841+
numThreads.reserve(nLoops);
842+
AffineExpr s0, s1;
843+
bindSymbols(b.getContext(), s0, s1);
844+
AffineExpr divExpr = s0.ceilDiv(s1);
845+
for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
846+
OpFoldResult numTiles = std::get<0>(it);
847+
if (!isConstantIntValue(numTiles, 0))
848+
numTiles = makeComposedFoldedAffineApply(
849+
b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
850+
numThreads.push_back(numTiles);
851+
}
852+
}
853+
837854
bool hasReductionThreads = false;
838855
for (auto dim : redDims) {
839856
if (!isConstantIntValue(numThreads[dim], 0) &&
@@ -850,13 +867,24 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
850867
if ((unsigned)redDims.front() >= numThreads.size())
851868
return b.notifyMatchFailure(
852869
op, "reduction dimension must be mapped to threads");
853-
870+
SmallVector<int> constantNewParallelDims;
871+
for (auto dim : newParallelDims) {
872+
if (getConstantIntValue(dim) == std::nullopt)
873+
return b.notifyMatchFailure(
874+
op, "Expected new parallel dims to be constant integers.");
875+
constantNewParallelDims.push_back(*getConstantIntValue(dim));
876+
}
877+
if (newParallelDims.empty())
878+
constantNewParallelDims = redDims;
879+
if (constantNewParallelDims.size() != redDims.size())
880+
return b.notifyMatchFailure(
881+
op, "reduction dimension must be mapped to new parallel dims");
854882
// 1. Create the inital tensor value.
855883
FailureOr<Operation *> identityTensor = nullptr;
856884
if (hasReductionThreads) {
857885
identityTensor = LinalgOpPartialReductionInterface::
858-
generateInitialTensorForPartialReduction(op, b, loc, numThreads,
859-
redDims, {});
886+
generateInitialTensorForPartialReduction(
887+
op, b, loc, numThreads, redDims, constantNewParallelDims);
860888
}
861889
if (failed(identityTensor))
862890
return b.notifyMatchFailure(op,
@@ -866,7 +894,6 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
866894
SmallVector<Value> dest;
867895
if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
868896
return b.notifyMatchFailure(op, "failed to get destination tensors");
869-
870897
Operation *tiledOp = nullptr;
871898

872899
SmallVector<OpFoldResult> nonZeroNumThreads =
@@ -875,20 +902,21 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
875902
}));
876903
SmallVector<Value> materializedNonZeroNumThreads =
877904
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
878-
879905
// 2. Create the ForallOp with an empty region.
880906
scf::ForallOp forallOp = b.create<scf::ForallOp>(
881907
loc, getAsOpFoldResult(materializedNonZeroNumThreads),
882908
hasReductionThreads ? (*identityTensor)->getResults() : dest, mapping);
883-
884909
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
885910
// be nested under `forallOp`.
886911
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
912+
std::optional<ArrayRef<OpFoldResult>> nominalTileSizes = std::nullopt;
913+
if (!tileSizes.empty() && threadNums.empty()) {
914+
nominalTileSizes = tileSizes;
915+
}
887916
calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
888917
/*omitTileOffsetBoundsCheck =*/false,
889-
/*nominalTileSizes=*/tileSizes, tiledOffsets,
890-
tiledSizes);
891-
918+
/*nominalTileSizes=*/nominalTileSizes,
919+
tiledOffsets, tiledSizes);
892920
// 4. Clone the tileable op and update its destination operands to use the
893921
// output bbArgs of the ForallOp.
894922
SmallVector<Value> tilingResults;
@@ -907,20 +935,26 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
907935
SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
908936
SmallVector<OpFoldResult> outOffsets(numThreads.size(),
909937
b.getIndexAttr(0));
910-
SmallVector<OpFoldResult> sizes;
911-
for (auto s :
912-
cast<RankedTensorType>(destBbArgs[destNum].getType()).getShape()) {
913-
sizes.emplace_back(getAsIndexOpFoldResult(b.getContext(), (int)s));
914-
}
915-
for (auto dim : redDims) {
916-
sizes[dim] = b.getIndexAttr(1);
938+
SmallVector<OpFoldResult> sizes = tiledSizes;
939+
for (const auto &iteratorType : llvm::enumerate(
940+
cast<RankedTensorType>(destBbArgs[destNum].getType())
941+
.getShape())) {
942+
sizes[iteratorType.index()] =
943+
getAsIndexOpFoldResult(b.getContext(), iteratorType.value());
944+
if (llvm::find(constantNewParallelDims, iteratorType.index()) !=
945+
constantNewParallelDims.end()) {
946+
sizes[iteratorType.index()] = b.getIndexAttr(1);
947+
}
917948
}
918949

919950
auto nonZeroDimIdx = 0;
920-
for (auto dim = 0UL; dim < numThreads.size(); dim++) {
921-
if (!isConstantIntValue(numThreads[dim], 0)) {
922-
if (llvm::find(redDims, dim) != redDims.end())
923-
outOffsets[dim] = forallOp.getInductionVars()[nonZeroDimIdx];
951+
auto currentReductionIdx = 0;
952+
for (const auto &iteratorType : llvm::enumerate(numThreads)) {
953+
if (!isConstantIntValue(iteratorType.value(), 0)) {
954+
if (llvm::find(redDims, iteratorType.index()) != redDims.end()) {
955+
outOffsets[constantNewParallelDims[currentReductionIdx++]] =
956+
forallOp.getInductionVars()[nonZeroDimIdx];
957+
}
924958
nonZeroDimIdx++;
925959
}
926960
}
@@ -929,7 +963,10 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
929963
loc, cast<RankedTensorType>(initOperand.getType()),
930964
destBbArgs[destNum], outOffsets, sizes, strides));
931965
} else {
932-
tiledDpsInitOperands.push_back(initOperand);
966+
auto *it = llvm::find(dest, initOperand);
967+
assert(it != dest.end() && "dest operand not found in dest");
968+
unsigned destNum = std::distance(dest.begin(), it);
969+
tiledDpsInitOperands.push_back(destBbArgs[destNum]);
933970
}
934971
}
935972

@@ -944,19 +981,35 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
944981
initOperandPtr.set(tiledInitValue);
945982
}
946983
});
947-
948984
// 5. Tile the cloned op and delete the clone.
949-
FailureOr<TilingResult> tilingResult =
950-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
951-
tiledSizes);
952-
if (failed(tilingResult))
953-
return clonedOp->emitError("Failed to tile op: ");
954-
if (tilingResult->tiledOps.size() != 1) {
955-
return clonedOp->emitError("expected a single produced tiled op, got ")
956-
<< tilingResult->tiledOps.size();
985+
if (tileSizes.empty() || threadNums.empty()) {
986+
FailureOr<TilingResult> tilingResult =
987+
cast<TilingInterface>(clonedOp).getTiledImplementation(
988+
b, tiledOffsets, tiledSizes);
989+
if (failed(tilingResult))
990+
return clonedOp->emitError("Failed to tile op: ");
991+
if (tilingResult->tiledOps.size() != 1) {
992+
return clonedOp->emitError("expected a single produced tiled op, got ")
993+
<< tilingResult->tiledOps.size();
994+
}
995+
tiledOp = tilingResult->tiledOps.front();
996+
tilingResults = tilingResult->tiledValues;
997+
} else {
998+
LinalgTilingOptions options;
999+
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
1000+
b, cast<LinalgOp>(clonedOp), tileSizes, options);
1001+
if (failed(maybeTiled))
1002+
return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
1003+
1004+
SmallVector<Value> ids = forallOp.getInductionVars();
1005+
mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
1006+
materializedNonZeroNumThreads);
1007+
if (maybeTiled->loops.size() != 1) {
1008+
return clonedOp->emitError("expected a single produced loop");
1009+
}
1010+
tiledOp = maybeTiled->op;
1011+
tilingResults = maybeTiled->loops.front()->getResults();
9571012
}
958-
tiledOp = tilingResult->tiledOps.front();
959-
tilingResults = tilingResult->tiledValues;
9601013

9611014
b.eraseOp(clonedOp);
9621015
}
@@ -974,23 +1027,33 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
9741027
return op->emitOpError("output offsets couldn't be calculated");
9751028
SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
9761029
int64_t offIdx = 0;
977-
int64_t sizeIdx = 0;
9781030
int64_t nonZeroDimIdx = 0;
1031+
SmallVector<Value> reductionInductionVars;
9791032
for (auto i = 0UL; i < numThreads.size(); ++i) {
980-
if (llvm::find(redDims, i) != redDims.end()) {
1033+
if (llvm::find(constantNewParallelDims, i) !=
1034+
constantNewParallelDims.end()) {
9811035
if (hasReductionThreads) {
982-
resultOffsetsRank.push_back(
983-
forallOp.getInductionVars()[nonZeroDimIdx]);
1036+
resultOffsetsRank.push_back(b.getIndexAttr(1));
9841037
resultSizesRank.push_back(b.getIndexAttr(1));
9851038
}
986-
nonZeroDimIdx++;
987-
continue;
1039+
} else {
1040+
resultOffsetsRank.push_back(resultOffsets[offIdx]);
1041+
resultSizesRank.push_back(resultSizes[offIdx++]);
1042+
}
1043+
if (llvm::find(redDims, i) != redDims.end()) {
1044+
reductionInductionVars.push_back(
1045+
forallOp.getInductionVars()[nonZeroDimIdx]);
9881046
}
9891047
if (!isConstantIntValue(numThreads[i], 0)) {
9901048
nonZeroDimIdx++;
9911049
}
992-
resultOffsetsRank.push_back(resultOffsets[offIdx++]);
993-
resultSizesRank.push_back(resultSizes[sizeIdx++]);
1050+
}
1051+
if (hasReductionThreads) {
1052+
for (auto [parallelDims, redVar] :
1053+
llvm::zip(constantNewParallelDims, reductionInductionVars)) {
1054+
resultOffsetsRank[parallelDims] = redVar;
1055+
resultSizesRank[parallelDims] = b.getIndexAttr(1);
1056+
}
9941057
}
9951058
SmallVector<OpFoldResult> strides(resultSizesRank.size(),
9961059
b.getIndexAttr(1));
@@ -1001,18 +1064,16 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
10011064
b.create<tensor::ParallelInsertSliceOp>(
10021065
loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
10031066
}
1004-
10051067
// 7. Merge the partial reductions.
10061068
Operation *mergeOp = nullptr;
10071069
b.setInsertionPointAfter(forallOp);
10081070
if (hasReductionThreads) {
1009-
Operation *mergeOp =
1010-
op.mergeReductions(b, loc, forallOp->getResults(), redDims);
1071+
Operation *mergeOp = op.mergeReductions(b, loc, forallOp->getResults(),
1072+
constantNewParallelDims);
10111073
b.replaceOp(op, mergeOp->getResults());
10121074
} else {
10131075
b.replaceOp(op, forallOp->getResults());
10141076
}
1015-
10161077
// 8. Return.
10171078
ForallReductionTilingResult results;
10181079
results.initialOp = *identityTensor;

lib/gc/Transforms/Tiling.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ FailureOr<linalg::ForallReductionTilingResult> tileReductionUsingForall(
4444
ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
4545
ArrayRef<OpFoldResult> newParallelDims, std::optional<ArrayAttr> mapping);
4646

47-
FailureOr<linalg::ForallReductionTilingResult>
48-
tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
49-
ArrayRef<OpFoldResult> numThreads,
50-
ArrayRef<OpFoldResult> tileSizes,
51-
std::optional<ArrayAttr> mapping);
47+
FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
48+
RewriterBase &b, PartialReductionOpInterface op,
49+
ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
50+
ArrayRef<OpFoldResult> newParallelDims, std::optional<ArrayAttr> mapping);
5251

5352
} // namespace linalgX
5453
} // namespace mlir

0 commit comments

Comments
 (0)