Skip to content

Commit 5bc5e37

Browse files
committed
polish code
1 parent 8816159 commit 5bc5e37

21 files changed

+1350
-1303
lines changed

CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ set(GC_LIB_LINKED_LIBS
108108
GCCpuRuntime
109109
GCPasses
110110
GCAnalysis
111-
MLIROneDNNGraph
112111
)
113112
add_mlir_library(graph_compiler SHARED ${GC_LIB_SOURCES})
114113
target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES})

docs/deep_tile_matmul_design.md

+594
Large diffs are not rendered by default.

include/gc/Analysis/MatmulConfigAnalysis.h

+57-30
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===//
1+
//===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,66 +11,70 @@
1111

1212
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1313
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15-
#include "mlir/Pass/Pass.h"
16-
#include "mlir/Support/LLVM.h"
17-
#include "llvm/ADT/DenseMap.h"
18-
#include <llvm/Support/Debug.h>
19-
#include <memory>
20-
#include <numeric>
14+
#include <cstring>
2115

2216
namespace mlir {
2317
namespace gc {
2418

2519
using namespace mlir;
2620

21+
// A mock for the taget information
22+
// TODO: replace it with upstream hardware description model
2723
struct SystemDesc {
24+
25+
static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
26+
if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') {
27+
return defaultValue;
28+
}
29+
auto val = std::stoi(str);
30+
return val > 0 ? val : defaultValue;
31+
}
32+
2833
// get runtime OMP_NUM_THREADS
2934
uint32_t getNumThreads() {
3035
char *numThreads = getenv("OMP_NUM_THREADS");
31-
if (numThreads) {
32-
return std::stoi(numThreads);
33-
}
34-
return 1;
36+
return getPositiveIntFromStr(numThreads, 1);
3537
}
3638
// get cache size by cacheLevel
3739
size_t getCacheSize(uint8_t cacheLevel) {
3840
if (cacheLevel == 1) {
3941
char *cacheSize = getenv("L1_CACHE_SIZE");
40-
if (cacheSize) {
41-
return std::stoi(cacheSize);
42-
}
42+
return getPositiveIntFromStr(cacheSize, 0);
4343
} else if (cacheLevel == 2) {
4444
char *cacheSize = getenv("L2_CACHE_SIZE");
45-
if (cacheSize) {
46-
return std::stoi(cacheSize);
47-
}
45+
return getPositiveIntFromStr(cacheSize, 0);
4846
} else if (cacheLevel == 3) {
4947
char *cacheSize = getenv("L3_CACHE_SIZE");
50-
if (cacheSize) {
51-
return std::stoi(cacheSize);
52-
}
48+
return getPositiveIntFromStr(cacheSize, 0);
5349
}
5450
return 0;
5551
}
5652

57-
SmallVector<size_t> getContractionOperationMaxVectorLength() {
58-
return {512UL, 512UL};
53+
// get the maximum vector length in bits
54+
size_t getMaxVectorLength() {
55+
char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
56+
return getPositiveIntFromStr(maxVectorLanes, 512);
5957
}
6058
};
6159

60+
// The configuration for matmul tiling
61+
// TODO: support batch matmul
6262
struct MatmulConfig {
63-
uint32_t MBlock, NBlock, KBlock;
63+
// The number of threads distributed to M, N, K
6464
uint32_t MThreads, NThreads, KThreads;
65+
// The innermost block size for M, N, K which will be directly converted to
66+
// brgemm.
6567
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
66-
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
67-
const MatmulConfig &config);
68+
// The outer block size for M, N, K which will be used to decide the loop tile
69+
// size in single thread
70+
uint32_t MBlock, NBlock, KBlock;
6871
};
6972

7073
enum DimType { Batch, M, N, K };
7174

72-
[[maybe_unused]] static SmallVector<unsigned>
73-
extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
75+
// Extract the index of the given DimType in the DimType list
76+
inline SmallVector<unsigned> extractDimTypeIdx(ArrayRef<DimType> tyList,
77+
DimType ty) {
7478
SmallVector<unsigned> idxList;
7579
for (auto [idx, type] : llvm::enumerate(tyList)) {
7680
if (type == ty) {
@@ -80,9 +84,11 @@ extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
8084
return idxList;
8185
}
8286

83-
static FailureOr<SmallVector<SmallVector<DimType>>>
87+
// Get the operand dim type for every operand for the given linalg op
88+
inline FailureOr<SmallVector<SmallVector<DimType>>>
8489
getOprandDimType(linalg::LinalgOp &linalgOp) {
85-
if (isa<linalg::MatmulOp>(linalgOp)) {
90+
// TODO: replace the linalgx op with generic op
91+
if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
8692
return SmallVector<SmallVector<DimType>>{
8793
SmallVector<DimType>{DimType::M, DimType::K},
8894
SmallVector<DimType>{DimType::K, DimType::N},
@@ -104,10 +110,31 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
104110
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
105111
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
106112
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
113+
} else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
114+
return SmallVector<SmallVector<DimType>>{
115+
SmallVector<DimType>{DimType::K, DimType::M},
116+
SmallVector<DimType>{DimType::K, DimType::N},
117+
SmallVector<DimType>{DimType::M, DimType::N}};
118+
} else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
119+
return SmallVector<SmallVector<DimType>>{
120+
SmallVector<DimType>{DimType::M, DimType::K},
121+
SmallVector<DimType>{DimType::N, DimType::K},
122+
SmallVector<DimType>{DimType::M, DimType::N}};
123+
} else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
124+
return SmallVector<SmallVector<DimType>>{
125+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
126+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
127+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
128+
} else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
129+
return SmallVector<SmallVector<DimType>>{
130+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
131+
SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
132+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
107133
}
108134
return failure();
109135
}
110136

137+
// The analysis to extract the matmul configuration from the given linalg op
111138
struct MatmulConfigAnalysis {
112139
public:
113140
explicit MatmulConfigAnalysis(Operation *root);

0 commit comments

Comments
 (0)