|
10 | 10 | #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
|
11 | 11 |
|
12 | 12 | #include "gc/Dialect/Linalgx/LinalgxOps.h"
|
| 13 | +#include "mlir/Dialect/DLTI/DLTI.h" |
13 | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
14 |
| -#include <cstring> |
| 15 | +#include "mlir/Interfaces/DataLayoutInterfaces.h" |
15 | 16 |
|
16 | 17 | namespace mlir {
|
17 | 18 | namespace gc {
|
18 | 19 |
|
19 | 20 | using namespace mlir;
|
20 | 21 |
|
21 |
| -// A mock for the taget information |
22 |
| -// TODO: replace it with upstream hardware description model |
23 | 22 | struct SystemDesc {
|
24 |
| - |
25 |
| - static int getPositiveIntFromStr(char *str, int defaultValue = 1) { |
26 |
| - if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') { |
27 |
| - return defaultValue; |
28 |
| - } |
29 |
| - auto val = std::stoi(str); |
30 |
| - return val > 0 ? val : defaultValue; |
31 |
| - } |
32 |
| - |
33 | 23 | // get runtime OMP_NUM_THREADS
|
34 | 24 | uint32_t getNumThreads() {
|
35 |
| - char *numThreads = getenv("OMP_NUM_THREADS"); |
36 |
| - return getPositiveIntFromStr(numThreads, 1); |
| 25 | + std::optional<Attribute> numThreads = layout.getDevicePropertyValue( |
| 26 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 27 | + Builder(ctx).getStringAttr("num_threads")); |
| 28 | + if (numThreads && isa<IntegerAttr>(*numThreads)) { |
| 29 | + return dyn_cast<IntegerAttr>(*numThreads).getInt(); |
| 30 | + } |
| 31 | + return 1; |
37 | 32 | }
|
38 | 33 | // get cache size by cacheLevel
|
39 | 34 | size_t getCacheSize(uint8_t cacheLevel) {
|
40 | 35 | if (cacheLevel == 1) {
|
41 |
| - char *cacheSize = getenv("L1_CACHE_SIZE"); |
42 |
| - return getPositiveIntFromStr(cacheSize, 0); |
| 36 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 37 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 38 | + Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); |
| 39 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 40 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 41 | + } |
43 | 42 | } else if (cacheLevel == 2) {
|
44 |
| - char *cacheSize = getenv("L2_CACHE_SIZE"); |
45 |
| - return getPositiveIntFromStr(cacheSize, 0); |
| 43 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 44 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 45 | + Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); |
| 46 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 47 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 48 | + } |
46 | 49 | } else if (cacheLevel == 3) {
|
47 |
| - char *cacheSize = getenv("L3_CACHE_SIZE"); |
48 |
| - return getPositiveIntFromStr(cacheSize, 0); |
| 50 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 51 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 52 | + Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); |
| 53 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 54 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 55 | + } |
49 | 56 | }
|
50 | 57 | return 0;
|
51 | 58 | }
|
52 | 59 |
|
53 | 60 | // get the maximum vector length in bits
|
54 | 61 | size_t getMaxVectorLength() {
|
55 |
| - char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH"); |
56 |
| - return getPositiveIntFromStr(maxVectorLanes, 512); |
| 62 | + std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue( |
| 63 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 64 | + Builder(ctx).getStringAttr("max_vector_width")); |
| 65 | + if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) { |
| 66 | + return dyn_cast<IntegerAttr>(*maxVectorLength).getInt(); |
| 67 | + } |
| 68 | + return 512; |
57 | 69 | }
|
| 70 | + |
| 71 | + SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} |
| 72 | + |
| 73 | +private: |
| 74 | + DataLayout layout; |
| 75 | + MLIRContext *ctx; |
58 | 76 | };
|
59 | 77 |
|
60 | 78 | // The configuration for matmul tiling
|
61 | 79 | // TODO: support batch matmul
|
// The configuration for matmul tiling
// TODO: support batch matmul
// NOTE: plain aggregate — field order is part of the aggregate-initialization
// contract, so members must not be reordered without updating initializers.
struct MatmulConfig {
  // The number of threads distributed to M, N, K
  uint32_t MThreads, NThreads, KThreads;
  // The outer block size for M, N, K which will be used to decide the loop tile
  // size in single thread
  uint32_t MBlock, NBlock, KBlock;
  // The innermost block size for M, N, K which will be directly converted to
  // brgemm.
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
};
|
72 | 90 |
|
// Semantic role of a matmul iteration dimension: Batch, M (parallel rows),
// N (parallel cols), K (reduction). Intentionally a plain enum so values
// convert implicitly to indices.
enum DimType { Batch, M, N, K };
|
|
0 commit comments