Skip to content

Commit 95b3cc8

Browse files
authored
Merge branch 'main' into xurui/add_benchmark
2 parents 524bd7d + 8948c6b commit 95b3cc8

17 files changed

+3438
-59
lines changed

docs/deep_tile_matmul_design.md

Lines changed: 594 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
10+
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
17+
namespace mlir {
18+
namespace gc {
19+
20+
using namespace mlir;
21+
22+
// Tiling configuration for a matmul op.
// TODO: support batch matmul
struct MatmulConfig {
  // Number of threads distributed along M.
  uint32_t MThreads;
  // Number of threads distributed along N.
  uint32_t NThreads;
  // Number of threads distributed along K.
  uint32_t KThreads;
  // Outer block size for M, used to decide the loop tile size in a single
  // thread.
  uint32_t MBlock;
  // Outer block size for N, used to decide the loop tile size in a single
  // thread.
  uint32_t NBlock;
  // Outer block size for K, used to decide the loop tile size in a single
  // thread.
  uint32_t KBlock;
  // Innermost block size for M, directly converted to brgemm.
  uint32_t innerMostMBlock;
  // Innermost block size for N, directly converted to brgemm.
  uint32_t innerMostNBlock;
  // Innermost block size for K, directly converted to brgemm.
  uint32_t innerMostKBlock;
};
34+
35+
// Classification of iteration dimensions in a contraction op.
enum DimType {
  Batch, // batch dimension shared by all operands
  M,     // parallel dimension of operand A and the result
  N,     // parallel dimension of operand B and the result
  K      // reduction dimension
};
36+
37+
// Extract the index of the given DimType in the DimType list
38+
inline SmallVector<unsigned> extractDimTypeIdx(ArrayRef<DimType> tyList,
39+
DimType ty) {
40+
SmallVector<unsigned> idxList;
41+
for (auto [idx, type] : llvm::enumerate(tyList)) {
42+
if (type == ty) {
43+
idxList.push_back(idx);
44+
}
45+
}
46+
return idxList;
47+
}
48+
49+
// Get the operand dim type for every operand for the given linalg op
50+
inline FailureOr<SmallVector<SmallVector<DimType>>>
51+
getOprandDimType(linalg::LinalgOp &linalgOp) {
52+
// TODO: replace the linalgx op with generic op
53+
if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
54+
return SmallVector<SmallVector<DimType>>{
55+
SmallVector<DimType>{DimType::M, DimType::K},
56+
SmallVector<DimType>{DimType::K, DimType::N},
57+
SmallVector<DimType>{DimType::M, DimType::N}};
58+
} else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
59+
return SmallVector<SmallVector<DimType>>{
60+
SmallVector<DimType>{DimType::M, DimType::K},
61+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
62+
DimType::K},
63+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
64+
} else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
65+
return SmallVector<SmallVector<DimType>>{
66+
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
67+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
68+
DimType::K},
69+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
70+
} else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
71+
return SmallVector<SmallVector<DimType>>{
72+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
73+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
74+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
75+
} else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
76+
return SmallVector<SmallVector<DimType>>{
77+
SmallVector<DimType>{DimType::K, DimType::M},
78+
SmallVector<DimType>{DimType::K, DimType::N},
79+
SmallVector<DimType>{DimType::M, DimType::N}};
80+
} else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
81+
return SmallVector<SmallVector<DimType>>{
82+
SmallVector<DimType>{DimType::M, DimType::K},
83+
SmallVector<DimType>{DimType::N, DimType::K},
84+
SmallVector<DimType>{DimType::M, DimType::N}};
85+
} else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
86+
return SmallVector<SmallVector<DimType>>{
87+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
88+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
89+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
90+
} else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
91+
return SmallVector<SmallVector<DimType>>{
92+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
93+
SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
94+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
95+
}
96+
return failure();
97+
}
98+
99+
// The analysis to extract the matmul configuration from the given linalg op
100+
struct MatmulConfigAnalysis {
101+
public:
102+
explicit MatmulConfigAnalysis(Operation *root);
103+
MatmulConfig getConfig() { return config; }
104+
105+
private:
106+
MatmulConfig config;
107+
};
108+
109+
} // namespace gc
110+
} // namespace mlir
111+
112+
#endif

include/gc/Transforms/Passes.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
15-
let summary = "Tile linalg named operations.";
16-
let dependentDialects =
17-
["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
18-
}
19-
2014
#ifdef GC_HAS_ONEDNN_DIALECT
2115
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
2216
let summary =
@@ -71,6 +65,18 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
7165
"Decide if enable cost model to control iterative fusion.">,
7266
ListOption<"defaultTileSize", "default-tile-size", "std::string",
7367
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`">,
68+
];
69+
}
70+
def DeepTileContractionNamedOp
71+
: Pass<"deep-tile-contraction-named-op", "func::FuncOp"> {
72+
let summary = "Tile linalg contraction named operation deeply";
73+
let description =
74+
[{The pass tries to tile the linalg contraction named op deeply.}];
75+
let dependentDialects = [
76+
"func::FuncDialect",
77+
"arith::ArithDialect",
78+
"tensor::TensorDialect",
79+
"linalg::LinalgDialect",
7480
];
7581
}
7682

@@ -87,4 +93,17 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
8793
];
8894
}
8995

96+
def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> {
97+
let summary = "Sink operations into inner loops";
98+
let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization.
99+
}];
100+
let dependentDialects = [];
101+
}
102+
103+
def MergeNestedForall : Pass<"merge-nested-forall"> {
104+
let summary = "Merge nested scf.forall operations";
105+
let description = [{The pass tries to merge nested forall operations.}];
106+
let dependentDialects = ["scf::SCFDialect"];
107+
}
108+
90109
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
44

55
gc_add_mlir_library(GcAnalysis
66
TargetDescriptionAnalysis.cpp
7+
MatmulConfigAnalysis.cpp
78

89
DEPENDS
910
GraphCompilerPassIncGen
@@ -12,4 +13,4 @@ gc_add_mlir_library(GcAnalysis
1213
${mlir_dialect_libs}
1314
${MLIR_LINK_COMPONENTS}
1415
GcInterface
15-
)
16+
)

0 commit comments

Comments
 (0)