
Commit b3bf8dc

fix comments
1 parent 510932e commit b3bf8dc

3 files changed, +100 -87 lines changed

lib/gc/Analysis/MatmulConfigAnalysis.cpp (+38 -19)
```diff
@@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
 
 template <typename T>
 static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
-                                     std::vector<T> arry) {
+                                     std::vector<T> array) {
   ss << "[";
-  for (auto [idx, a] : llvm::enumerate(arry)) {
-    if (idx != 0) {
-      ss << ", ";
-    }
-    ss << a;
-  }
+  llvm::interleaveComma(array, ss);
   ss << "]";
   return ss;
 }
```
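The removed hand-written loop existed only to avoid a trailing comma; `llvm::interleaveComma` from `llvm/ADT/STLExtras.h` does exactly that. A minimal standalone sketch (the `vals` data is hypothetical, not from this patch):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

int main() {
  std::vector<int> vals = {2, 4, 8};
  llvm::outs() << "[";
  // Prints each element with ", " between elements but not after the
  // last one -- the behavior the deleted loop implemented by hand.
  llvm::interleaveComma(vals, llvm::outs());
  llvm::outs() << "]\n"; // output: [2, 4, 8]
}
```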
```diff
@@ -174,24 +169,23 @@ std::vector<MatmulConfig>
 filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
                         linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
                         SystemDesc &sysDesc, const CostModelFn &costModel,
-                        float eliminationRatio = 0.5, float threshold = -1) {
+                        float preserveRatio = 0.5, float threshold = -1) {
   std::vector<MatmulConfig> result;
   std::vector<float> costs;
   std::vector<size_t> idx;
-  for (auto [i, config] : llvm::enumerate(configs)) {
+  for (auto &&[i, config] : llvm::enumerate(configs)) {
     costs.push_back(costModel(linalgOp, shape, config, sysDesc));
     idx.push_back(i);
   }
   std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) {
     return costs[i1] < costs[i2];
   });
-  double thresholdCost =
-      costs[idx[(size_t)(eliminationRatio * configs.size())]];
+  double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
   thresholdCost =
       threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
-  for (size_t i = 0; i < configs.size(); i++) {
-    if (costs[idx[i]] <= thresholdCost) {
-      result.push_back(configs[idx[i]]);
+  for (const auto &i : idx) {
+    if (costs[i] <= thresholdCost) {
+      result.push_back(configs[i]);
     }
   }
   LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost
```
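The rename to `preserveRatio` matches what the index actually computes: the cost at the cut position becomes the threshold, and every candidate at or below it survives. A standalone sketch of the arithmetic with plain containers (hypothetical values, no project types):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// With 8 costs and preserveRatio = 0.5, thresholdCost is the cost at
// sorted position 4, so roughly the cheapest half of the candidates is
// preserved (preserveRatio is expected to be in [0, 1)).
std::vector<size_t> filterIdxByCost(const std::vector<float> &costs,
                                    float preserveRatio = 0.5f) {
  if (costs.empty())
    return {};
  std::vector<size_t> idx(costs.size());
  for (size_t i = 0; i < idx.size(); ++i)
    idx[i] = i;
  std::stable_sort(idx.begin(), idx.end(),
                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
  float thresholdCost = costs[idx[(size_t)(preserveRatio * costs.size())]];
  std::vector<size_t> kept;
  for (size_t i : idx)
    if (costs[i] <= thresholdCost)
      kept.push_back(i); // indices of surviving configs, cheapest first
  return kept;
}
```

As in the original, the comparison is `<=`, so the candidate sitting exactly at the cut cost is kept as well.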
```diff
@@ -210,6 +204,11 @@ std::vector<MatmulConfig>
 prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
                         ArrayRef<uint32_t> shape,
                         ArrayRef<uint32_t> givenInnermostBlock) {
+  if (shape.size() < 3) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "The shape is invalid, no candidate is generated\n");
+    return {};
+  }
   std::vector<MatmulConfig> configs;
   uint32_t threads = sysDesc.getNumThreads();
   std::vector<uint32_t> MThreadsCandidates =
```
```diff
@@ -290,10 +289,25 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
   return configs;
 }
 
+bool validateConfig(const MatmulConfig &cfg) {
+  if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
+      cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
+      cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
+      cfg.innerMostKBlock <= 0) {
+    return false;
+  }
+  if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
+      cfg.NBlock % cfg.innerMostNBlock != 0 ||
+      cfg.KBlock % cfg.innerMostKBlock != 0) {
+    return false;
+  }
+  return true;
+}
+
 // read the config from the attributes for tuning
 bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
   size_t cfgItemCnt = 0;
-  for (auto &attr : attrs) {
+  for (const auto &attr : attrs) {
     if (attr.getName() == "KBlock") {
       config.KBlock = cast<IntegerAttr>(attr.getValue()).getInt();
       cfgItemCnt++;
```
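A hypothetical illustration of what the new `validateConfig` accepts and rejects, assuming the `MatmulConfig` fields named above: every value must be positive, and each outer block must be a multiple of its innermost block.

```cpp
MatmulConfig cfg;
cfg.MThreads = cfg.NThreads = cfg.KThreads = 1;
cfg.MBlock = 64; cfg.innerMostMBlock = 16; // 64 % 16 == 0: fine
cfg.NBlock = 64; cfg.innerMostNBlock = 48; // 64 % 48 != 0: rejected
cfg.KBlock = 32; cfg.innerMostKBlock = 32; // equal blocks are fine
bool ok = validateConfig(cfg); // false: NBlock is not a multiple of
                               // innerMostNBlock
```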
```diff
@@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
       cfgItemCnt++;
     }
   }
-  return cfgItemCnt == 9;
+  if (validateConfig(config)) {
+    return cfgItemCnt == 9;
+  } else {
+    LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
+    return false;
+  }
 }
 
 // Analyze the workload and system description to generate the default config
```
```diff
@@ -350,14 +369,14 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
     SmallVector<unsigned> NDimTypeIdx =
         extractDimTypeIdx(oprandDimType[1], DimType::N);
     uint32_t M = 1U, N = 1U, K = 1U;
-    for (auto [s, dimType] :
+    for (auto &&[s, dimType] :
          llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)),
                    oprandDimType[0])) {
       if (dimType == DimType::M) {
         M *= s;
       }
     }
-    for (auto [s, dimType] :
+    for (auto &&[s, dimType] :
          llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)),
                    oprandDimType[1])) {
       if (dimType == DimType::N) {
```
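The recurring `auto [..]` to `auto &&[..]` change in this commit matters because `llvm::zip` and `llvm::enumerate` dereference to tuple-like proxy objects returned by value; `auto &&` binds the proxy without copying it on every iteration. A minimal sketch of the pattern (hypothetical shapes and helper, not project data):

```cpp
#include "llvm/ADT/STLExtras.h"
#include <cstdint>
#include <vector>

enum class DimType { M, N, K };

uint32_t productOfDims(const std::vector<int64_t> &shape,
                       const std::vector<DimType> &dimTypes, DimType want) {
  uint32_t acc = 1;
  // auto && binds the proxy tuple yielded by zip without copying it.
  for (auto &&[s, dimType] : llvm::zip(shape, dimTypes))
    if (dimType == want)
      acc *= s;
  return acc;
}
// productOfDims({2, 16, 32}, {DimType::M, DimType::M, DimType::K},
//               DimType::M) == 32
```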
```diff
@@ -425,7 +444,7 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
     SmallVector<uint32_t> shape = {M, N, K};
     std::vector<MatmulConfig> configCandidates =
         prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock);
-    for (auto [fn, name, threshold] : costModelList) {
+    for (auto &&[fn, name, threshold] : costModelList) {
       configCandidates = filterConfigByCostModel(
           configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold);
     }
```

lib/gc/Transforms/DeepTileContractionNamedOp.cpp (+25 -25)
```diff
@@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) {
 // Check if the linalgOp need to be legalized to f32 accumulation type
 static bool needToLegalizeDtype(linalg::LinalgOp linalgOp) {
   mlir::Type dataType =
-      dyn_cast<mlir::RankedTensorType>(linalgOp.getDpsInputs()[0].getType())
+      dyn_cast<mlir::ShapedType>(linalgOp.getDpsInputs()[0].getType())
           .getElementType();
   mlir::Type resultType =
-      dyn_cast<mlir::RankedTensorType>(linalgOp.getDpsInits()[0].getType())
+      dyn_cast<mlir::ShapedType>(linalgOp.getDpsInits()[0].getType())
           .getElementType();
   return (dataType.isBF16() || dataType.isF16()) && dataType == resultType;
 }
```
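Widening the cast from `RankedTensorType` to `ShapedType` makes the element-type query work for any shaped container: `mlir::ShapedType` is the common interface behind ranked and unranked tensors as well as memrefs, so the same check keeps working once bufferization has turned an operand into a memref. A sketch of the pattern with a defensive guard (the helper name is hypothetical):

```cpp
// Returns the element type of a tensor or memref value, or a null Type
// if the value is not shaped at all. dyn_cast yields a null ShapedType
// for non-shaped types, so guard before calling getElementType on it.
static mlir::Type getElementTypeOrNull(mlir::Value v) {
  if (auto shaped = mlir::dyn_cast<mlir::ShapedType>(v.getType()))
    return shaped.getElementType();
  return {};
}
```

The diff itself still calls `.getElementType()` directly on the `dyn_cast` result, which assumes the operand is always shaped; the guard above is only the defensive variant of the same idea.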
```diff
@@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
   linalg::LinalgOp currentOp = linalgOp;
 
   bool hasFullResult = !option.isPartialResult;
-  for (auto [i, loopType] : llvm::enumerate(loopType)) {
+  for (auto &&[i, loopType] : llvm::enumerate(loopType)) {
     ArrayRef<size_t> currentDim = loopDim[i];
     ArrayRef<size_t> currentTileSize = nestedTileSizes[i];
     if (loopType == OuterLoopGenerationOption::LoopType::ForOp) {
```
```diff
@@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
       cast<TilingInterface>(currentOp.getOperation()).getIterationDomain(b);
   currentOp.getReductionDims(reductionDims);
   bool tileOnReduction = false;
-  for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) {
+  for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) {
    if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 &&
        (!getConstantIntValue(loopRanges[d].size) ||
         tile != static_cast<size_t>(
```
```diff
@@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
       OpBuilder::InsertionGuard guard(b);
       b.setInsertionPoint(currentOp);
       if (tileOnReduction) {
-        for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
+        for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) {
          if (isConstantIntValue(tile, 0) &&
              llvm::find(reductionDims, idx) != reductionDims.end()) {
            tileSizes[idx] = loopRanges[idx].size;
          }
        }
        SmallVector<OpFoldResult> newParallelDims;
-        for (size_t i = 0UL; i < reductionDims.size(); i++) {
-          newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
+        for (auto iter : llvm::enumerate(reductionDims)) {
+          newParallelDims.push_back(
+              getAsIndexOpFoldResult(b.getContext(), iter.index()));
        }
        FailureOr<linalg::ForallReductionTilingResult> tilingResult =
            linalgX::tileReductionUsingForall(
                b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
                {}, tileSizes, newParallelDims, std::nullopt);
        if (failed(tilingResult) &&
-            tilingResult->parallelTiledOps.size() == 1UL)
+            llvm::hasSingleElement(tilingResult->parallelTiledOps))
          return failure();
        currentOp =
            dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps.back());
```
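Both idioms adopted in this hunk come from `llvm/ADT/STLExtras.h`: `llvm::hasSingleElement` checks "exactly one element" using only `begin()`/`end()`, so it also works on ranges with no O(1) `size()`, and `llvm::enumerate(...).index()` replaces a manual counter. A standalone sketch:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <forward_list>
#include <vector>

int main() {
  // hasSingleElement needs only begin()/end(), so it even works on
  // containers such as forward_list that have no size() at all.
  std::forward_list<int> one = {42};
  llvm::outs() << llvm::hasSingleElement(one) << "\n"; // 1

  // enumerate(...).index() replaces a hand-maintained loop counter.
  std::vector<int> dims = {7, 9};
  for (auto iter : llvm::enumerate(dims))
    llvm::outs() << iter.index() << ": " << iter.value() << "\n";
}
```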
```diff
@@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                        : cfg.NBlock;
 
     // Outer loop tile size
-    for (auto [tile, dim] :
+    for (auto &&[tile, dim] :
          llvm::zip(SmallVector<size_t>{KParallelBlockSize, MParallelBlockSize,
                                        NParallelBlockSize},
                    SmallVector<size_t>{KDimPos[0], MDimPos[0], NDimPos[0]})) {
```
```diff
@@ -596,27 +597,27 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     }
 
     // Middle loop tile size
-    for (auto [tile, dim] :
+    for (auto &&[tile, dim] :
          llvm::zip(SmallVector<size_t>{MOuterBlockSize, NOuterBlockSize,
                                        KOuterBlockSize},
                    SmallVector<size_t>{MDimPos[0], NDimPos[0], KDimPos[0]})) {
       option.nestedTileSizes.emplace_back(SmallVector<size_t>{tile});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<size_t>{dim});
     }
-    if (KDimPos.size() == 1) {
+    if (llvm::hasSingleElement(KDimPos)) {
       option.nestedTileSizes.emplace_back(SmallVector<size_t>{cfg.KBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<size_t>{KDimPos.back()});
     }
     // Inner loop tile size
-    if (MDimPos.size() == 1) {
+    if (llvm::hasSingleElement(MDimPos)) {
       option.nestedTileSizes.emplace_back(
           SmallVector<size_t>{cfg.innerMostMBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<size_t>{MDimPos.back()});
     }
-    if (NDimPos.size() == 1) {
+    if (llvm::hasSingleElement(NDimPos)) {
       option.nestedTileSizes.emplace_back(
           SmallVector<size_t>{cfg.innerMostNBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
```
```diff
@@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                        const linalg::ForallReductionTilingResult &result)
         -> FailureOr<linalg::LinalgOp> {
       ArrayRef<Value> initValue = result.initialValues;
-      if (initValue.size() == 1 &&
+      if (llvm::hasSingleElement(initValue) &&
           isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
         rewriter.replaceOp(initValue[0].getDefiningOp(),
                            dyn_cast<DestinationStyleOpInterface>(
```
```diff
@@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
     bool firstM = true, firstK = true, firstN = true;
     if (MDimNum > 1) {
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
+      for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
         if (iter == DimType::M && firstM) {
           AInnermostDims.push_back(1);
           firstM = false;
```
```diff
@@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       }
       firstM = true;
       firstN = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
+      for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
         if (iter == DimType::M && firstM) {
           CInnermostDims.push_back(1);
           firstM = false;
```
```diff
@@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     if (NDimNum > 1) {
       firstN = true;
       firstK = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+      for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
         if (iter == DimType::N && firstN) {
           BInnermostDims.push_back(1);
           firstN = false;
```
```diff
@@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(currentOp);
     mlir::Type dataType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+        dyn_cast<mlir::ShapedType>(currentOp.getDpsInputs()[0].getType())
             .getElementType();
     mlir::Type weightType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+        dyn_cast<mlir::ShapedType>(currentOp.getDpsInputs()[1].getType())
             .getElementType();
     mlir::Type resultType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+        dyn_cast<mlir::ShapedType>(currentOp.getDpsInits()[0].getType())
             .getElementType();
 
     // update the extractSlice to static size, replace it with
```
```diff
@@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                                 currentOp.getDpsInits()[0]);
     // Create the brgemm op and replace the origin linalg op
     linalg::LinalgOp matmul;
-    if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
-            .getShape()
-            .size() == 3) {
+    if (dyn_cast<mlir::ShapedType>(weightOprand.getType()).getShape().size() ==
+        3) {
       matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
           loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
           resultOprand);
```
```diff
@@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     // fuse the low precision cast to the innermost body
     rewriter.setInsertionPointAfter(currentOp);
     Value cond;
-    for (LoopLikeOpInterface loop : option.KLoopHandles) {
+    for (LoopLikeOpInterface &loop : option.KLoopHandles) {
       Value induceVar = turnOpFoldResultIntoValue(
           rewriter, loc, *loop.getSingleInductionVar());
       Value upBound = turnOpFoldResultIntoValue(rewriter, loc,
```
```diff
@@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     Value cond;
     arith::ConstantIndexOp zeroConst =
         rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    for (LoopLikeOpInterface loop : option.KLoopHandles) {
+    for (LoopLikeOpInterface &loop : option.KLoopHandles) {
       Value induceVar = loop.getLoopRegions().front()->front().getArgument(0);
       Value currentCond = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, induceVar, zeroConst);
```
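For context, both of these loops walk the K-loop handles to build a "first K iteration" predicate: each loop contributes an `inductionVar == 0` comparison, and the conjunction guards code that must run exactly once per reduction. The combining step is not shown in this diff, so the sketch below is an assumption about the surrounding code:

```cpp
// Hypothetical continuation of the loop above: AND the per-loop
// comparisons together so `cond` is true only on the very first
// iteration of every K loop.
Value cond;
for (LoopLikeOpInterface &loop : option.KLoopHandles) {
  Value induceVar = loop.getLoopRegions().front()->front().getArgument(0);
  Value currentCond = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, induceVar, zeroConst);
  cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, currentCond)
                    .getResult()
              : currentCond;
}
```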
