
Commit 510932e

support dlti

committed
1 parent 5bc5e37 commit 510932e

File tree

include/gc/Analysis/MatmulConfigAnalysis.h
lib/gc/Analysis/MatmulConfigAnalysis.cpp
lib/gc/Transforms/DeepTileContractionNamedOp.cpp
lib/gc/Transforms/TilingUtil.hpp
test/gc/Transform/deepTileContractionNamedOp.mlir

5 files changed: +99 -31 lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

+43-25
@@ -10,64 +10,82 @@
 #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
 
 #include "gc/Dialect/Linalgx/LinalgxOps.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include <cstring>
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 
 namespace mlir {
 namespace gc {
 
 using namespace mlir;
 
-// A mock for the taget information
-// TODO: replace it with upstream hardware description model
 struct SystemDesc {
-
-  static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
-    if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') {
-      return defaultValue;
-    }
-    auto val = std::stoi(str);
-    return val > 0 ? val : defaultValue;
-  }
-
   // get runtime OMP_NUM_THREADS
   uint32_t getNumThreads() {
-    char *numThreads = getenv("OMP_NUM_THREADS");
-    return getPositiveIntFromStr(numThreads, 1);
+    std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
+        Builder(ctx).getStringAttr("CPU" /* device ID*/),
+        Builder(ctx).getStringAttr("num_threads"));
+    if (numThreads && isa<IntegerAttr>(*numThreads)) {
+      return dyn_cast<IntegerAttr>(*numThreads).getInt();
+    }
+    return 1;
   }
   // get cache size by cacheLevel
   size_t getCacheSize(uint8_t cacheLevel) {
     if (cacheLevel == 1) {
-      char *cacheSize = getenv("L1_CACHE_SIZE");
-      return getPositiveIntFromStr(cacheSize, 0);
+      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
+          Builder(ctx).getStringAttr("CPU" /* device ID*/),
+          Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
+      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
+        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
+      }
     } else if (cacheLevel == 2) {
-      char *cacheSize = getenv("L2_CACHE_SIZE");
-      return getPositiveIntFromStr(cacheSize, 0);
+      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
+          Builder(ctx).getStringAttr("CPU" /* device ID*/),
+          Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
+      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
+        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
+      }
     } else if (cacheLevel == 3) {
-      char *cacheSize = getenv("L3_CACHE_SIZE");
-      return getPositiveIntFromStr(cacheSize, 0);
+      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
+          Builder(ctx).getStringAttr("CPU" /* device ID*/),
+          Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
+      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
+        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
+      }
     }
     return 0;
   }
 
   // get the maximum vector length in bits
   size_t getMaxVectorLength() {
-    char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
-    return getPositiveIntFromStr(maxVectorLanes, 512);
+    std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
+        Builder(ctx).getStringAttr("CPU" /* device ID*/),
+        Builder(ctx).getStringAttr("max_vector_width"));
+    if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
+      return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
+    }
+    return 512;
   }
+
+  SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
+
+private:
+  DataLayout layout;
+  MLIRContext *ctx;
 };
 
 // The configuration for matmul tiling
 // TODO: support batch matmul
 struct MatmulConfig {
   // The number of threads distributed to M, N, K
   uint32_t MThreads, NThreads, KThreads;
-  // The innermost block size for M, N, K which will be directly converted to
-  // brgemm.
-  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
   // The outer block size for M, N, K which will be used to decide the loop tile
   // size in single thread
   uint32_t MBlock, NBlock, KBlock;
+  // The innermost block size for M, N, K which will be directly converted to
+  // brgemm.
+  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
 };
 
 enum DimType { Batch, M, N, K };
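
For orientation, here is a minimal usage sketch of the reworked SystemDesc; it is not part of the commit, and the surrounding function is hypothetical. The struct is now constructed from the enclosing ModuleOp and answers each query from the module's dlti.target_system_spec through DataLayout, falling back to the defaults shown above (1 thread, 0 bytes, 512 bits) when an entry is missing or is not an IntegerAttr.

// Hypothetical usage sketch (not from this commit); assumes the header above
// is included and that llvm/Support/raw_ostream.h is available.
static void dumpCpuDesc(mlir::ModuleOp module) {
  mlir::gc::SystemDesc sysDesc(module);          // resolves DLTI entries via DataLayout
  uint32_t threads = sysDesc.getNumThreads();    // "num_threads", default 1
  size_t l2Bytes = sysDesc.getCacheSize(2);      // "L2_cache_size_in_bytes", default 0
  size_t vecBits = sysDesc.getMaxVectorLength(); // "max_vector_width", default 512
  llvm::outs() << threads << " threads, L2 = " << l2Bytes << " bytes, "
               << vecBits << "-bit vectors\n";
}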

lib/gc/Analysis/MatmulConfigAnalysis.cpp

+6-5
@@ -88,6 +88,7 @@ double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp,
   size_t dtypeSize = DataLayout().getTypeSizeInBits(
       ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType());
   size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize;
+  // TODO: take matrix register like amx into account
   double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) %
                     maxVectorLength * 1.0 / config.innerMostMBlock +
                 (maxVectorLength - config.innerMostKBlock % maxVectorLength) %
@@ -270,8 +271,8 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
           continue;
         }
         MatmulConfig config{
-            MBlock, NBlock, KBlock,
             MThreads, NThreads, KThreads,
+            MBlock, NBlock, KBlock,
             innerMostMBlock, innerMostNBlock, innerMostKBlock};
         configs.push_back(config);
       }
@@ -311,13 +312,13 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
     } else if (attr.getName() == "MThreads") {
      config.MThreads = cast<IntegerAttr>(attr.getValue()).getInt();
       cfgItemCnt++;
-    } else if (attr.getName() == "innerMostMBlock") {
+    } else if (attr.getName() == "innermostMBlock") {
       config.innerMostMBlock = cast<IntegerAttr>(attr.getValue()).getInt();
       cfgItemCnt++;
-    } else if (attr.getName() == "innerMostNBlock") {
+    } else if (attr.getName() == "innermostNBlock") {
       config.innerMostNBlock = cast<IntegerAttr>(attr.getValue()).getInt();
       cfgItemCnt++;
-    } else if (attr.getName() == "innerMostKBlock") {
+    } else if (attr.getName() == "innermostKBlock") {
       config.innerMostKBlock = cast<IntegerAttr>(attr.getValue()).getInt();
       cfgItemCnt++;
     }
@@ -338,7 +339,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
 // previous matmul
 MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
-    SystemDesc sysDesc;
+    SystemDesc sysDesc(root->getParentOfType<ModuleOp>());
     SmallVector<SmallVector<DimType>> oprandDimType =
         *getOprandDimType(linalgOp);
     // get the origin M,N,K size
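
Note that the attribute keys the override path now expects are spelled "innermostMBlock", "innermostNBlock", and "innermostKBlock" (lowercase "most"), even though the C++ fields keep the innerMost casing. Below is a hypothetical annotation sketch, assuming the analysis reads the config from the matmul op's attribute dictionary; the helper name and tile values are illustrative only and not part of this commit.

// Hypothetical sketch (not from this commit): pinning a config with the
// renamed keys so readConfigFromAttrs picks them up.
static void annotateMatmul(mlir::linalg::LinalgOp op, mlir::OpBuilder &b) {
  op->setAttr("MThreads", b.getI32IntegerAttr(56));
  op->setAttr("innermostMBlock", b.getI32IntegerAttr(32)); // was "innerMostMBlock"
  op->setAttr("innermostNBlock", b.getI32IntegerAttr(32)); // was "innerMostNBlock"
  op->setAttr("innermostKBlock", b.getI32IntegerAttr(32)); // was "innerMostKBlock"
}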

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

+1-1
@@ -243,7 +243,7 @@ static Operation *findParentFillOp(Value val) {
          llvm::find(skipOpList, currentOp->getName().getStringRef()) !=
              skipOpList.end() &&
          !isa<linalg::FillOp>(currentOp)) {
-    currentOp = currentOp->getResult(0).getDefiningOp();
+    currentOp = currentOp->getOperand(0).getDefiningOp();
   }
   if (currentOp && isa<linalg::FillOp>(currentOp)) {
     return currentOp;
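
A short aside on why this is the fix: for any operation, getResult(0).getDefiningOp() is that operation itself, so the old loop could never advance past an op in skipOpList, whereas getOperand(0).getDefiningOp() steps to the producer of the op's first input. The standalone snippet below illustrates the corrected walking pattern; it is a hypothetical sketch, not code from the patch.

// Hypothetical illustration (not from this commit) of walking up a use-def
// chain toward a producer, as findParentFillOp now does.
static mlir::Operation *stepToProducer(mlir::Operation *op) {
  if (!op || op->getNumOperands() == 0)
    return nullptr;
  // Value::getDefiningOp() is null when the operand is a block argument.
  return op->getOperand(0).getDefiningOp();
}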

lib/gc/Transforms/TilingUtil.hpp

+2
@@ -16,6 +16,8 @@
 namespace mlir {
 namespace linalgX {
 
+// An enahncement for the upstream pass to support tiling reduction for MKmk
+// like cases(with multiple reduction iterators).
 FailureOr<linalg::ForallReductionTilingResult> tileReductionUsingForall(
     RewriterBase &b, PartialReductionOpInterface op,
     ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,

test/gc/Transform/deepTileContractionNamedOp.mlir

+47
@@ -108,3 +108,50 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
   return %2 : tensor<4096x4096xbf16>
 }
 
+// -----
+
+module attributes {
+  dlti.target_system_spec = #dlti.target_system_spec<
+    "CPU": #dlti.target_device_spec<
+      #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
+      #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
+      #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
+      #dlti.dl_entry<"num_threads", 56 : i32>,
+      #dlti.dl_entry<"max_vector_width", 512 : i32>>
+>} {
+  /// CHECK-LABEL: @matmul_2Dx4D_bf16_with_dlti
+  func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> {
+    %cst_0 = arith.constant 0.000000e+00 : bf16
+    %0 = tensor.empty() : tensor<4096x4096xbf16>
+    %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
+    // CHECK: scf.forall
+    // CHECK: tensor.extract_slice
+    // CHECK: scf.forall
+    // CHECK: tensor.extract_slice
+    // CHECK: scf.forall
+    // CHECK: tensor.extract_slice
+    // CHECK: scf.for
+    // CHECK: tensor.extract_slice
+    // CHECK: scf.for
+    // CHECK: scf.for
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    // CHECK: scf.for
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    // CHECK: linalg.transpose
+    // CHECK: scf.if
+    // CHECK: linalg.fill
+    // CHECK: linalgx.batch_reduce_matmul_vnni
+    // CHECK: else
+    // CHECK: linalgx.batch_reduce_matmul_vnni
+    // CHECK: scf.forall.in_parallel
+    // CHECK: scf.forall.in_parallel
+    // CHECK: scf.forall.in_parallel
+    // CHECK: linalg.reduce
+    // CHECK: linalg.copy
+    %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
+    return %2 : tensor<4096x4096xbf16>
+  }
+
+}

0 commit comments