@@ -1,4 +1,4 @@
- // ===-- MatmulConfigAnalysis.h - DESC -------------------------- -*- C++ -*-===//
+ // ===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,66 +11,70 @@
#include " gc/Dialect/Linalgx/LinalgxOps.h"
13
13
#include " mlir/Dialect/Linalg/IR/Linalg.h"
14
- #include " mlir/Dialect/Tensor/IR/Tensor.h"
15
- #include " mlir/Pass/Pass.h"
16
- #include " mlir/Support/LLVM.h"
17
- #include " llvm/ADT/DenseMap.h"
18
- #include < llvm/Support/Debug.h>
19
- #include < memory>
20
- #include < numeric>
14
+ #include < cstring>
21
15
22
16
namespace mlir {
23
17
namespace gc {
24
18
25
19
using namespace mlir ;
26
20
21
+ // A mock for the taget information
22
+ // TODO: replace it with upstream hardware description model
27
23
struct SystemDesc {
+
+   static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
+     if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') {
+       return defaultValue;
+     }
+     auto val = std::stoi(str);
+     return val > 0 ? val : defaultValue;
+   }
+
  // get runtime OMP_NUM_THREADS
  uint32_t getNumThreads() {
    char *numThreads = getenv("OMP_NUM_THREADS");
-     if (numThreads) {
-       return std::stoi(numThreads);
-     }
-     return 1;
+     return getPositiveIntFromStr(numThreads, 1);
  }
  // get cache size by cacheLevel
  size_t getCacheSize(uint8_t cacheLevel) {
    if (cacheLevel == 1) {
      char *cacheSize = getenv("L1_CACHE_SIZE");
-       if (cacheSize) {
-         return std::stoi(cacheSize);
-       }
+       return getPositiveIntFromStr(cacheSize, 0);
    } else if (cacheLevel == 2) {
      char *cacheSize = getenv("L2_CACHE_SIZE");
-       if (cacheSize) {
-         return std::stoi(cacheSize);
-       }
+       return getPositiveIntFromStr(cacheSize, 0);
    } else if (cacheLevel == 3) {
      char *cacheSize = getenv("L3_CACHE_SIZE");
-       if (cacheSize) {
-         return std::stoi(cacheSize);
-       }
+       return getPositiveIntFromStr(cacheSize, 0);
    }
    return 0;
  }

-   SmallVector<size_t> getContractionOperationMaxVectorLength() {
-     return {512UL, 512UL};
+   // get the maximum vector length in bits
+   size_t getMaxVectorLength() {
+     char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
+     return getPositiveIntFromStr(maxVectorLanes, 512);
  }
};
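
A quick usage sketch for the getters above (not part of this patch; the local variable names are invented):

  SystemDesc desc;
  uint32_t threads = desc.getNumThreads();    // OMP_NUM_THREADS if positive, else 1
  size_t l2Bytes = desc.getCacheSize(2);      // L2_CACHE_SIZE if positive, else 0
  size_t vecBits = desc.getMaxVectorLength(); // MAX_VECTOR_LENGTH if positive, else 512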

+ // The configuration for matmul tiling
+ // TODO: support batch matmul
struct MatmulConfig {
-   uint32_t MBlock, NBlock, KBlock;
+   // The number of threads distributed to M, N, K
  uint32_t MThreads, NThreads, KThreads;
+   // The innermost block size for M, N, K, which will be converted directly
+   // to brgemm.
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
-   friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
-                                        const MatmulConfig &config);
+   // The outer block size for M, N, K, which will be used to decide the loop
+   // tile size within a single thread
+   uint32_t MBlock, NBlock, KBlock;
};
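
For intuition, each dimension is tiled at three levels. The sketch below illustrates the M dimension under assumed names (cfg is a populated MatmulConfig and M is the total M extent); N and K are handled analogously, and none of this code is part of the patch:

  // M is first distributed across MThreads threads; each thread walks its
  // share in steps of MBlock, and each MBlock is stepped by innerMostMBlock,
  // which is the M size handed to brgemm.
  for (uint32_t t = 0; t < cfg.MThreads; ++t)
    for (uint32_t m = t * (M / cfg.MThreads);
         m < (t + 1) * (M / cfg.MThreads); m += cfg.MBlock)
      for (uint32_t mi = m; mi < m + cfg.MBlock; mi += cfg.innerMostMBlock) {
        // an innerMostMBlock-sized tile reaches the brgemm microkernel
      }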

enum DimType { Batch, M, N, K };

- [[maybe_unused]] static SmallVector<unsigned>
- extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
+ // Extract the indices of the given DimType in the DimType list
+ inline SmallVector<unsigned> extractDimTypeIdx(ArrayRef<DimType> tyList,
+                                                DimType ty) {
  SmallVector<unsigned> idxList;
  for (auto [idx, type] : llvm::enumerate(tyList)) {
    if (type == ty) {
@@ -80,9 +84,11 @@ extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
  return idxList;
}
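
A small worked example (illustrative only): for operand A of a plain matmul the dim type list is {M, K}, so asking for K yields the single index 1:

  SmallVector<DimType> aDims = {DimType::M, DimType::K};
  SmallVector<unsigned> kIdx = extractDimTypeIdx(aDims, DimType::K); // {1}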

- static FailureOr<SmallVector<SmallVector<DimType>>>
+ // Get the dim type of each operand of the given linalg op
+ inline FailureOr<SmallVector<SmallVector<DimType>>>
getOprandDimType(linalg::LinalgOp &linalgOp) {
-   if (isa<linalg::MatmulOp>(linalgOp)) {
+   // TODO: replace the linalgx op with generic op
+   if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
    return SmallVector<SmallVector<DimType>>{
        SmallVector<DimType>{DimType::M, DimType::K},
        SmallVector<DimType>{DimType::K, DimType::N},
@@ -104,10 +110,31 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
        SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
+   } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
+     return SmallVector<SmallVector<DimType>>{
+         SmallVector<DimType>{DimType::K, DimType::M},
+         SmallVector<DimType>{DimType::K, DimType::N},
+         SmallVector<DimType>{DimType::M, DimType::N}};
+   } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+     return SmallVector<SmallVector<DimType>>{
+         SmallVector<DimType>{DimType::M, DimType::K},
+         SmallVector<DimType>{DimType::N, DimType::K},
+         SmallVector<DimType>{DimType::M, DimType::N}};
+   } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
+     return SmallVector<SmallVector<DimType>>{
+         SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
+         SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
+         SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
+   } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
+     return SmallVector<SmallVector<DimType>>{
+         SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
+         SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
+         SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
  }
  return failure();
}
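
A sketch of how the two helpers compose (illustrative; linalgOp is assumed to be a matmul-like op obtained elsewhere):

  // Which position of operand B carries the K dimension?
  auto dimTypes = getOprandDimType(linalgOp);
  if (succeeded(dimTypes)) {
    // For linalg.matmul, (*dimTypes)[1] is {K, N}, so this yields {0};
    // for linalg.matmul_transpose_b it is {N, K}, yielding {1}.
    SmallVector<unsigned> kIdxOfB =
        extractDimTypeIdx((*dimTypes)[1], DimType::K);
  }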

+ // The analysis to extract the matmul configuration from the given linalg op
struct MatmulConfigAnalysis {
public:
  explicit MatmulConfigAnalysis(Operation *root);