Skip to content

Commit c6a674d

Browse files
author
ZhangYan
committed
draft update
1 parent fa07870 commit c6a674d

File tree

4 files changed

+324
-25
lines changed

4 files changed

+324
-25
lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,29 @@ struct SystemDesc {
2828
// get runtime OMP_NUM_THREADS
2929
uint32_t getNumThreads() {
3030
char *numThreads = getenv("OMP_NUM_THREADS");
31-
if (numThreads) {
31+
if (!threads_limited && numThreads) {
3232
return std::stoi(numThreads);
3333
}
34+
return curThreads;
35+
}
36+
37+
// Limit the expected thread count to a single NUMA node (reads the NUMA_THREADS env var).
// NOTE(review): the local variable holding NUMA_THREADS below is misnamed 'cacheSize',
// and the 'numa_node' parameter is currently unused — confirm intent.
38+
void limitOnSingleNode(uint32_t numa_node) {
39+
char *cacheSize = getenv("NUMA_THREADS");
40+
if (cacheSize) {
41+
curThreads = std::stoi(cacheSize);
42+
threads_limited = true;
43+
}
44+
}
45+
46+
uint32_t getNumNodes() {
47+
char *numThreads = getenv("OMP_NUM_THREADS");
48+
if (threads_limited && numThreads) {
49+
return std::stoi(numThreads) / curThreads;
50+
}
3451
return 1;
3552
}
53+
3654
// get cache size by cacheLevel
3755
size_t getCacheSize(uint8_t cacheLevel) {
3856
if (cacheLevel == 1) {
@@ -57,6 +75,10 @@ struct SystemDesc {
5775
SmallVector<size_t> getContractionOperationMaxVectorLength() {
5876
return {512UL, 512UL};
5977
}
78+
79+
private:
80+
uint32_t curThreads = 1;
81+
bool threads_limited = false;
6082
};
6183

6284
struct MatmulConfig {

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,12 @@ previous matmul
345345
MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
346346
SystemDesc sysDesc;
347347
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
348+
// Check if the operation has an attribute named 'splited'
// (NOTE(review): 'splited' is a misspelling of 'split'; the string must stay
// in sync with whatever pass sets this attribute, so fix both together.)
349+
auto splitedAttr = linalgOp->getAttrOfType<IntegerAttr>("splited");
350+
if (splitedAttr) {
351+
sysDesc.limitOnSingleNode(splitedAttr.getInt());
352+
llvm::outs() << "splited mm, and should be allocated on numa node 0.\n";
353+
}
348354
auto oprandDimType = *getOprandDimType(linalgOp);
349355
// get the origin M,N,K size
350356
auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M);

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326326
Operation *op, bool isExtract,
327327
SmallVector<int64_t> size,
328328
int shrinDimNum = 0) {
329+
llvm::outs() << "^^^^^^^^^^^^^^setStaticSizeForExtractSliceOp^^^^^^^^^^\n";
329330
OpBuilder::InsertionGuard guard(rewriter);
330331
rewriter.setInsertionPoint(op);
331332
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -335,6 +336,23 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
335336
for (auto i = 0UL; i < mixedSizes.size(); i++) {
336337
mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]);
337338
}
339+
llvm::outs() << "mixedOffsets: ";
340+
for (auto t : mixedOffsets) {
341+
llvm::outs() << t << ", ";
342+
}
343+
llvm::outs() << "\n";
344+
345+
llvm::outs() << "mixedSizes: ";
346+
for (auto t : mixedSizes) {
347+
llvm::outs() << t << ", ";
348+
}
349+
llvm::outs() << "\n";
350+
351+
llvm::outs() << "mixedStrides: ";
352+
for (auto t : mixedStrides) {
353+
llvm::outs() << t << ", ";
354+
}
355+
llvm::outs() << "\n";
338356
if (shrinDimNum > 0) {
339357
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
340358
extractSlice,
@@ -348,6 +366,7 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348366
mixedStrides);
349367
}
350368
}
369+
llvm::outs() << "^^^^^^^^^^^^^^setStaticSizeForExtractSliceOp^^^^^^^^^^\n";
351370
}
352371

353372
static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
@@ -398,6 +417,7 @@ struct OuterLoopGenerationResult {
398417
static FailureOr<OuterLoopGenerationResult>
399418
generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
400419
const OuterLoopGenerationOption &option) {
420+
llvm::outs() << "======================================\n";
401421
// TODO: handle the return value
402422
OuterLoopGenerationResult result;
403423
auto nestedTileSizes = option.nestedTileSizes;
@@ -471,40 +491,82 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
471491
else
472492
tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile);
473493
}
494+
495+
llvm::outs() << "tileSizes: ";
496+
for (auto t : tileSizes) {
497+
llvm::outs() << t << ", ";
498+
}
499+
llvm::outs() << "\n";
500+
501+
llvm::outs() << "threads: ";
502+
for (auto t : threads) {
503+
llvm::outs() << t << ", ";
504+
}
505+
llvm::outs() << "\n";
506+
474507
SmallVector<Range> loopRanges =
475508
cast<TilingInterface>(currentOp.getOperation()).getIterationDomain(b);
476509
OpBuilder::InsertionGuard guard(b);
477510
b.setInsertionPoint(currentOp);
478511
if (auto partialInterface =
479512
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
513+
llvm::outs() << "PartialReductionOpInterface\n";
480514
for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
481515
if (isConstantIntValue(tile, 0)) {
482516
tileSizes[idx] = loopRanges[idx].size;
483517
}
484518
}
485-
519+
llvm::outs() << "updated tileSizes: ";
520+
for (auto t : tileSizes) {
521+
llvm::outs() << t << ", ";
522+
}
523+
llvm::outs() << "\n";
486524
SmallVector<OpFoldResult> newParallelDims;
487525
for (auto i = 0UL; i < reductionDims.size(); i++) {
488526
newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
489527
}
490-
auto tilingResult = linalgX::tileAllUsingForall(
491-
b, cast<PartialReductionOpInterface>(currentOp.getOperation()), {},
492-
tileSizes, newParallelDims, std::nullopt);
493-
if (failed(tilingResult) &&
494-
tilingResult->parallelTiledOps.size() == 1UL)
495-
return failure();
496-
currentOp =
497-
dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps.back());
498-
if (!tilingResult->mergeOps.empty()) {
499-
for (const auto &fn : option.finalReduceCallBacks) {
500-
auto result = fn(b, currentOp.getLoc(), *tilingResult);
501-
if (succeeded(result)) {
502-
currentOp = *result;
528+
if (currentTileSize.front() != 16 || true) {
529+
auto tilingResult = linalgX::tileAllUsingForall(
530+
b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
531+
{}, tileSizes, newParallelDims, std::nullopt);
532+
if (failed(tilingResult) &&
533+
tilingResult->parallelTiledOps.size() == 1UL)
534+
return failure();
535+
currentOp =
536+
dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps.back());
537+
if (!tilingResult->mergeOps.empty()) {
538+
llvm::outs() << "has merge ops\n";
539+
for (const auto &fn : option.finalReduceCallBacks) {
540+
auto result = fn(b, currentOp.getLoc(), *tilingResult);
541+
if (succeeded(result)) {
542+
currentOp = *result;
543+
}
503544
}
504545
}
546+
} else {
547+
llvm::outs() << "handle special cases\n";
548+
OpBuilder::InsertionGuard g(b);
549+
550+
Location loc = currentOp.getLoc();
551+
SmallVector<Value> dest;
552+
if (failed(tensor::getOrCreateDestinations(b, loc, currentOp, dest)))
553+
return b.notifyMatchFailure(currentOp,
554+
"failed to get destination tensors");
555+
arith::ConstantIndexOp lb = b.create<arith::ConstantIndexOp>(loc, 0);
556+
arith::ConstantIndexOp ub = b.create<arith::ConstantIndexOp>(loc, 2);
557+
arith::ConstantIndexOp step =
558+
b.create<arith::ConstantIndexOp>(loc, 1);
559+
560+
Operation *forallOp = b.create<scf::ForallOp>(
561+
loc, ArrayRef<OpFoldResult>(lb->getResult(0)),
562+
ArrayRef<OpFoldResult>(ub->getResult(0)),
563+
ArrayRef<OpFoldResult>(step->getResult(0)), dest, std::nullopt);
564+
currentOp = dyn_cast<linalg::LinalgOp>(forallOp);
505565
}
566+
506567
} else if (auto tilingInterface =
507568
cast<TilingInterface>(currentOp.getOperation())) {
569+
llvm::outs() << "TilingInterface\n";
508570
auto tilingResult = linalg::tileToForallOpUsingTileSizes(
509571
b, tilingInterface, tileSizes, std::nullopt);
510572
if (failed(tilingResult))
@@ -515,6 +577,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
515577
}
516578
}
517579
result.tiledOps.emplace_back(currentOp);
580+
llvm::outs() << "======================================\n";
518581
return result;
519582
}
520583

@@ -595,6 +658,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
595658
auto NOuterBlockSize = NDimPos.size() > 1
596659
? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
597660
: cfg.NBlock;
661+
// Outermost NUMA loop
662+
option.nestedTileSizes.emplace_back(
663+
SmallVector<size_t>{uint32_t(MFirstDim / 2)});
664+
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
665+
option.loopDim.emplace_back(SmallVector<size_t>{MDimPos[0]});
598666
// Outer
599667
option.nestedTileSizes.emplace_back(SmallVector<size_t>{
600668
MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});

0 commit comments

Comments
 (0)