1
+ // ===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
2
+ //
3
+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+
9
+ #ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
10
+ #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
11
+
12
+ #include " gc/Dialect/Linalgx/LinalgxOps.h"
13
+ #include " mlir/Dialect/DLTI/DLTI.h"
14
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
15
+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
16
+
17
+ namespace mlir {
18
+ namespace gc {
19
+
20
+ using namespace mlir ;
21
+
22
+ // The configuration for matmul tiling
23
+ // TODO: support batch matmul
24
+ struct MatmulConfig {
25
+ // The number of threads distributed to M, N, K
26
+ uint32_t MThreads, NThreads, KThreads;
27
+ // The outer block size for M, N, K which will be used to decide the loop tile
28
+ // size in single thread
29
+ uint32_t MBlock, NBlock, KBlock;
30
+ // The innermost block size for M, N, K which will be directly converted to
31
+ // brgemm.
32
+ uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
33
+ };
34
+
35
+ enum DimType { Batch, M, N, K };
36
+
37
+ // Extract the index of the given DimType in the DimType list
38
+ inline SmallVector<unsigned > extractDimTypeIdx (ArrayRef<DimType> tyList,
39
+ DimType ty) {
40
+ SmallVector<unsigned > idxList;
41
+ for (auto [idx, type] : llvm::enumerate (tyList)) {
42
+ if (type == ty) {
43
+ idxList.push_back (idx);
44
+ }
45
+ }
46
+ return idxList;
47
+ }
48
+
49
+ // Get the operand dim type for every operand for the given linalg op
50
+ inline FailureOr<SmallVector<SmallVector<DimType>>>
51
+ getOprandDimType (linalg::LinalgOp &linalgOp) {
52
+ // TODO: replace the linalgx op with generic op
53
+ if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
54
+ return SmallVector<SmallVector<DimType>>{
55
+ SmallVector<DimType>{DimType::M, DimType::K},
56
+ SmallVector<DimType>{DimType::K, DimType::N},
57
+ SmallVector<DimType>{DimType::M, DimType::N}};
58
+ } else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
59
+ return SmallVector<SmallVector<DimType>>{
60
+ SmallVector<DimType>{DimType::M, DimType::K},
61
+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
62
+ DimType::K},
63
+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
64
+ } else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
65
+ return SmallVector<SmallVector<DimType>>{
66
+ SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
67
+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
68
+ DimType::K},
69
+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
70
+ } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
71
+ return SmallVector<SmallVector<DimType>>{
72
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
73
+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
74
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
75
+ } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
76
+ return SmallVector<SmallVector<DimType>>{
77
+ SmallVector<DimType>{DimType::K, DimType::M},
78
+ SmallVector<DimType>{DimType::K, DimType::N},
79
+ SmallVector<DimType>{DimType::M, DimType::N}};
80
+ } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
81
+ return SmallVector<SmallVector<DimType>>{
82
+ SmallVector<DimType>{DimType::M, DimType::K},
83
+ SmallVector<DimType>{DimType::N, DimType::K},
84
+ SmallVector<DimType>{DimType::M, DimType::N}};
85
+ } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
86
+ return SmallVector<SmallVector<DimType>>{
87
+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
88
+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
89
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
90
+ } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
91
+ return SmallVector<SmallVector<DimType>>{
92
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
93
+ SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
94
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
95
+ }
96
+ return failure ();
97
+ }
98
+
99
+ // The analysis to extract the matmul configuration from the given linalg op
100
+ struct MatmulConfigAnalysis {
101
+ public:
102
+ explicit MatmulConfigAnalysis (Operation *root);
103
+ MatmulConfig getConfig () { return config; }
104
+
105
+ private:
106
+ MatmulConfig config;
107
+ };
108
+
109
+ } // namespace gc
110
+ } // namespace mlir
111
+
112
+ #endif
0 commit comments