Commit ba356ab

Merge branch 'main' into xiaohui/vectorization
2 parents d49715c + 664024a commit ba356ab

File tree

18 files changed: +1314 -132 lines changed

cmake/imex-version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-6c2e414a953b9a118bce6adac21cf9d42630e674
+8209807be6148d81fda6f439a01b77696986dd3e

cmake/llvm-version-imex.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-f06563a5c0d239a6b98f74db522681613254ad08
+3191587666aa3d1e53966bc8876614c7197fac4f

cmake/llvm-version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-f06563a5c0d239a6b98f74db522681613254ad08
+3191587666aa3d1e53966bc8876614c7197fac4f
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+//===-- GpuOclRuntime.h - GPU OpenCL runtime --------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GC_GPUOCLRUNTIME_H
+#define GC_GPUOCLRUNTIME_H
+
+namespace mlir::gc::gpu {
+constexpr char GPU_OCL_MALLOC[] = "gcGpuOclMalloc";
+constexpr char GPU_OCL_DEALLOC[] = "gcGpuOclDealloc";
+constexpr char GPU_OCL_MEMCPY[] = "gcGpuOclMemcpy";
+constexpr char GPU_OCL_KERNEL_CREATE[] = "gcGpuOclKernelCreate";
+constexpr char GPU_OCL_KERNEL_DESTROY[] = "gcGpuOclKernelDestroy";
+constexpr char GPU_OCL_KERNEL_LAUNCH[] = "gcGpuOclKernelLaunch";
+constexpr char GPU_OCL_MOD_DESTRUCTOR[] = "gcGpuOclModuleDestructor";
+} // namespace mlir::gc::gpu
+
+#ifndef GC_GPU_OCL_CONST_ONLY
+
+// TBD
+
+#else
+#undef GC_GPU_OCL_CONST_ONLY
+#endif
+#endif
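
A minimal consumer-side sketch (not part of this commit): the block guarded by GC_GPU_OCL_CONST_ONLY is still marked TBD, so a compiler-side user that only needs the runtime symbol names can define that macro before including the header. The include path and the small main() below are illustrative assumptions.

#define GC_GPU_OCL_CONST_ONLY
#include "gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h" // illustrative path

#include <cstdio>

int main() {
  // The constants are plain C strings naming the OpenCL runtime entry points
  // that the GPU ops are lowered to.
  std::printf("%s\n", mlir::gc::gpu::GPU_OCL_MALLOC);        // gcGpuOclMalloc
  std::printf("%s\n", mlir::gc::gpu::GPU_OCL_KERNEL_LAUNCH); // gcGpuOclKernelLaunch
  return 0;
}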

include/gc/Transforms/Passes.h

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,6 @@ namespace func {
 class FuncOp;
 } // namespace func
 
-
 namespace LLVM {
 class LLVMDialect;
 }
@@ -116,7 +115,8 @@ void populateFrontendPasses(mlir::OpPassManager &);
 void populateCPUPipeline(mlir::OpPassManager &);
 
 #ifdef GC_USE_IMEX
-void populateGPUPipeline(mlir::OpPassManager &);
+struct GPUPipelineOption;
+void populateGPUPipeline(mlir::OpPassManager &, const GPUPipelineOption &);
 #endif
 
 #define GEN_PASS_DECL
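
A hedged caller-side sketch of the signature change (not from this commit): code that used to call populateGPUPipeline(pm) now also has to supply a GPUPipelineOption. The mlir::gc namespace is assumed from the rest of the commit, and only the forward declaration above is relied on, so the option is taken by const reference.

#include "gc/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

#ifdef GC_USE_IMEX
// Forwarding a const reference is enough with the forward declaration in
// Passes.h; the struct's actual fields are defined elsewhere in the
// GC_USE_IMEX build.
void buildGpuPipeline(mlir::OpPassManager &pm,
                      const mlir::gc::GPUPipelineOption &options) {
  mlir::gc::populateGPUPipeline(pm, options);
}
#endif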

include/gc/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -93,6 +93,20 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
            "DPAS register block sizes MxNxK">,
   ];
 }
+
+def AddContextArg : Pass<"add-ctx-arg", "func::FuncOp"> {
+  let summary = "Add a context argument.";
+  let description = [{
+    Add a new memref argument to the function that can be used to pass some context.
+  }];
+}
+
+def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> {
+  let summary = "Convert the GPU operations to GpuOclRuntime calls.";
+  let description = [{
+    Convert the gpu alloc, dealloc, memcpy and launch operations to GpuOclRuntime calls.
+  }];
+}
 #endif // GC_USE_IMEX
 
 def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
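
A hedged C++ sketch of wiring the two new passes into a pipeline (not from this commit). The createAddContextArg()/createGpuToGpuOcl() factories follow MLIR's usual tablegen-generated naming and the mlir::gc namespace used elsewhere in the commit; treat them as assumptions.

#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

void buildGpuOclLowering(mlir::PassManager &pm) {
  // add-ctx-arg is a func::FuncOp pass, so it is nested under the module.
  pm.nest<mlir::func::FuncOp>().addPass(mlir::gc::createAddContextArg());
  // gpu-to-gpuocl is a ModuleOp pass that rewrites gpu alloc/dealloc/memcpy/
  // launch into calls to the GpuOclRuntime entry points.
  pm.addPass(mlir::gc::createGpuToGpuOcl());
}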

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 47 additions & 33 deletions
@@ -53,15 +53,24 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
 using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
 static std::shared_mutex g_brgemm_lock;
 
-static std::vector<brgemm_desc_t> g_brgemm_desc_list;
-static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list;
-static std::vector<std::unique_ptr<char[]>> g_brgemm_palette;
+struct brgemm_cache_info_t {
+  brgemm_desc_t desc;
+  brgemm_kernel_t *kernel;
+  std::shared_ptr<char[]> palette;
+};
+
+static std::vector<brgemm_cache_info_t> g_cache;
 
 // TODO(haixin): use syscall to determine page size?
 static constexpr size_t SCRATCH_SIZE = 2 * 4096;
 // TODO(haixin): need to use custom thread management for scratch in the future?
 static thread_local char scratch[SCRATCH_SIZE] = {0};
 
+static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
+  thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
+  return tl_cache;
+}
+
 extern "C" {
 
 int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
@@ -93,33 +102,33 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
   brgemm_desc_set_attr(&desc, dnnl_attrs);
 
   // TODO(haixin): Reuse identical palettes across kernels
-  char *palette_buffer = nullptr;
+  std::shared_ptr<char[]> palette_buffer;
   if (desc.is_tmm) {
-    palette_buffer = new char[PALETTE_SIZE];
-    dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer);
+    palette_buffer.reset(new char[PALETTE_SIZE]);
+    dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
     assert(status == dnnl::impl::status::success &&
            "Failed to initialize palette for BRGEMM");
   }
 
   write_lock_guard_t g(g_brgemm_lock);
-  g_brgemm_desc_list.push_back(desc);
-  g_brgemm_kernel_list.push_back(kernel);
-  g_brgemm_palette.emplace_back(palette_buffer);
-
-  return g_brgemm_desc_list.size() - 1;
+  g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
+  return g_cache.size() - 1;
 }
 
 void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
-  char *palette_buffer = nullptr;
-  {
+  assert(kernel_idx >= 0 && "Invalid kernel handler");
+  auto &tl_cache = get_tl_cache();
+  auto it = tl_cache.find(kernel_idx);
+  if (it == tl_cache.end()) {
     read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
-           "Invalid kernel handler");
-    brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
-    if (!desc.is_tmm) {
-      return;
-    }
-    palette_buffer = g_brgemm_palette[kernel_idx].get();
+    assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
+    it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
+  }
+  brgemm_desc_t &desc = it->second.desc;
+  char *palette_buffer = it->second.palette.get();
+
+  if (!desc.is_tmm) {
+    return;
   }
 
   assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
@@ -137,24 +146,29 @@ void dnnl_brgemm_tilerelease() {
 void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
                          void *B, uint64_t B_offset, void *C, uint64_t C_offset,
                          int num) {
-  brgemm_kernel_t *kernel = nullptr;
-  size_t A_offset_in_bytes;
-  size_t B_offset_in_bytes;
-  size_t C_offset_in_bytes;
-  {
+  auto &tl_cache = get_tl_cache();
+  if (tl_cache.find(kernel_idx) == tl_cache.end()) {
     read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
+    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
            "Invalid kernel handler");
-
-    brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
-    kernel = g_brgemm_kernel_list[kernel_idx];
-
-    A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
-    B_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
-    C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
+    auto updated_cache =
+        tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
+    assert(updated_cache.second && "insert into thread local cache");
   }
+  auto it = tl_cache.find(kernel_idx);
+  brgemm_kernel_t *kernel = it->second.kernel;
+  brgemm_desc_t *desc_ptr = &it->second.desc;
 
   assert(kernel && "Invalid brgemm kernel pointer");
+  assert(desc_ptr && "Invalid brgemm descriptor pointer");
+
+  size_t A_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
+  size_t B_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
+  size_t C_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
+
   char *A_arith = (char *)A;
   char *B_arith = (char *)B;
   char *C_arith = (char *)C;
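
The change above replaces three parallel global vectors with a single g_cache plus a thread-local lookaside map, so steady-state tileconfig/execute calls stop taking g_brgemm_lock once a handle has been seen on a thread. A condensed, self-contained sketch of the same pattern, with illustrative names rather than the library's:

#include <cassert>
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

struct Entry {
  int payload;
};

static std::shared_mutex g_lock;
static std::vector<Entry> g_registry;

// Hot path: repeated lookups of the same handle hit the thread_local map
// and never touch the shared lock again.
static const Entry &lookup(int64_t idx) {
  thread_local std::unordered_map<int64_t, Entry> tl_cache;
  auto it = tl_cache.find(idx);
  if (it == tl_cache.end()) {
    std::shared_lock<std::shared_mutex> guard(g_lock); // readers share the lock
    assert(idx >= 0 && idx < (int64_t)g_registry.size() && "invalid handle");
    it = tl_cache.insert({idx, g_registry[idx]}).first; // copy into this thread
  }
  return it->second;
}

// Cold path: registration takes the lock exclusively, like dnnl_brgemm_dispatch.
static int64_t registerEntry(Entry e) {
  std::unique_lock<std::shared_mutex> guard(g_lock);
  g_registry.push_back(e);
  return (int64_t)g_registry.size() - 1;
}
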
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+//===-- AddContextArg.cpp - Add context argument ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir::gc {
+#define GEN_PASS_DECL_ADDCONTEXTARG
+#define GEN_PASS_DEF_ADDCONTEXTARG
+#include "gc/Transforms/Passes.h.inc"
+} // namespace mlir::gc
+
+using namespace mlir;
+
+namespace {
+struct AddContextArg final : gc::impl::AddContextArgBase<AddContextArg> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    if (func.isExternal()) {
+      return;
+    }
+
+    auto funcType = func.getFunctionType();
+    auto argTypes = llvm::to_vector<8>(funcType.getInputs());
+    auto resultTypes = llvm::to_vector<1>(funcType.getResults());
+    auto ctx = func->getContext();
+    auto newArgType = MemRefType::get({}, IntegerType::get(ctx, 8));
+    argTypes.emplace_back(newArgType);
+    auto newFuncType = FunctionType::get(ctx, argTypes, resultTypes);
+    func.setType(newFuncType);
+    func.getBody().front().addArgument(newArgType, func.getLoc());
+
+    // Find all function calls and append the last argument of the current
+    // function to the call.
+    auto module = func->getParentOfType<ModuleOp>();
+    func.walk([&](func::CallOp call) {
+      // If the callee is defined in the current module, the context argument
+      // is added to its signature as well, so we need to append the context
+      // argument to the call too.
+      if (auto callee = module.lookupSymbol<func::FuncOp>(call.getCallee());
+          !callee || callee.isExternal()) {
+        return;
+      }
+      auto args = llvm::to_vector<8>(call.getOperands());
+      args.emplace_back(func.getArgument(func.getNumArguments() - 1));
+      call->setOperands(args);
+    });
+  }
+};
+} // namespace

lib/gc/Transforms/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 gc_add_mlir_library(GcGpuPasses
+  AddContextArg.cpp
+  GpuToGpuOcl.cpp
   LinalgToXeGPU.cpp
   Pipeline.cpp
 