Skip to content

Commit d1bd46b

Browse files
Added RoPE unit test
1 parent 6c407ec commit d1bd46b

File tree

5 files changed

+496
-2
lines changed

5 files changed

+496
-2
lines changed

Diff for: include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h

+4
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ struct OclModuleBuilderOpts {
230230
bool enableObjectDump = false;
231231
ArrayRef<StringRef> sharedLibPaths = {};
232232
void (*pipeline)(OpPassManager &) = nullptr;
233+
void (*symbolMaper)(llvm::orc::SymbolMap &,
234+
llvm::orc::MangleAndInterner &) = nullptr;
233235
};
234236

235237
struct OclModuleBuilder {
@@ -259,6 +261,8 @@ struct OclModuleBuilder {
259261
const bool enableObjectDump;
260262
const ArrayRef<StringRef> sharedLibPaths;
261263
void (*const pipeline)(OpPassManager &);
264+
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
265+
symbolMap;
262266
const StringRef funcName;
263267
const ArrayRef<Type> argTypes;
264268
std::shared_mutex mux;

Diff for: lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,18 @@ OclModuleBuilder::OclModuleBuilder(ModuleOp module,
824824
? opts.pipeline
825825
: [](OpPassManager &pm) { populateGPUPipeline(pm, {}); }),
826826
funcName(getFuncName(opts, mlirModule)),
827-
argTypes(getArgTypes(funcName, mlirModule)) {}
827+
argTypes(getArgTypes(funcName, mlirModule)) {
828+
if (opts.symbolMaper) {
829+
symbolMap = [mapper =
830+
opts.symbolMaper](llvm::orc::MangleAndInterner interner) {
831+
auto map = OclRuntime::Exports::symbolMap(interner);
832+
mapper(map, interner);
833+
return map;
834+
};
835+
} else {
836+
symbolMap = OclRuntime::Exports::symbolMap;
837+
}
838+
}
828839

829840
llvm::Expected<std::shared_ptr<const OclModule>>
830841
OclModuleBuilder::build(const OclRuntime &runtime) {
@@ -940,7 +951,7 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
940951
staticMain = createStaticMain(builder, mod, funcName, argTypes);
941952
auto expectedEng = ExecutionEngine::create(mod, opts);
942953
CHECKE(expectedEng, "Failed to create ExecutionEngine!");
943-
expectedEng->get()->registerSymbols(OclRuntime::Exports::symbolMap);
954+
expectedEng->get()->registerSymbols(symbolMap);
944955

945956
// Find all kernels and query the workgroup size
946957
size_t minSize = maxSize;

Diff for: test/mlir/test/gc/gpu-runner/rope.mlir

+218
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s

!dtype=i16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x7x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x7x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
module @fragment_name {
  // Expected difference between the two RoPE implementations: all zeroes.
  memref.global "private" constant @_all_zeroes : !output_memref_type = dense<0>

  // Rotary Position Embedding, graph-level formulation:
  //   out = transpose(in) * cos[pos_ids] + rotate_half(transpose(in)) * sin[pos_ids]
  // where rotate_half(x) = concat(-x[..., 64:], x[..., :64]) on the last dim.
  // Integer (!dtype = i16) arithmetic is used so results are exact and the two
  // implementations can be compared bit-for-bit.
  func.func @rope1(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
                   %cos_cache : !cos_sin_cache_memref_type, %sin_cache : !cos_sin_cache_memref_type) {
    %input = bufferization.to_tensor %iinput restrict : !input_memref_type
    %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
    %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
    %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
    %3 = tensor.empty(): !output_tensor_type
    // [batch, seq, heads, dim] -> [batch, heads, seq, dim]
    %transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]

    %c3 = arith.constant 3 : index
    // Select the cached cos rows for the requested positions:
    // squeeze [1,1,7,128] down to [7,128], gather rows by pos_ids, then
    // broadcast back up to the full output shape.
    %cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
    %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
    %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
    %pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>

    %cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

    %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
    %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

    %cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
    %input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

    // rotate_half: concat(-x[..., 64:128], x[..., 0:64]) along the head dim.
    %head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
    %c2 = arith.constant 2 : index
    %half_head_dim = arith.floordivsi %head_dim, %c2 : index
    %transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
    %transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
    %cnegative1 = arith.constant dense<-1> : tensor<2x32x7x64x!dtype>
    %empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
    %transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>

    %transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type

    // Same squeeze/gather/broadcast pipeline for the sin cache.
    %sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
    %sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
    %sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
    %sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

    %sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
    %sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
    %sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
    %input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

    %result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
    bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
    return
  }

  // Reference RoPE implementation: one linalg.generic computing each output
  // element directly. Must produce exactly the same values as @rope1.
  func.func @rope2(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
                   %cos_cache: !cos_sin_cache_memref_type, %sin_cache: !cos_sin_cache_memref_type) {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %cm1 = arith.constant -1 : !dtype

    %input = bufferization.to_tensor %iinput restrict : !input_memref_type
    %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
    %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
    %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
    %tmp = tensor.empty(): !output_tensor_type

    %result = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    } outs(%tmp : !output_tensor_type) {
    ^bb0(%ignore: !dtype):
      // Output layout is [batch, heads, seq, dim]; the input is
      // [batch, seq, heads, dim], hence the swapped %i1/%i2 below.
      %i0 = linalg.index 0 : index
      %i1 = linalg.index 1 : index
      %i2 = linalg.index 2 : index
      %i3 = linalg.index 3 : index
      %pos = tensor.extract %pos_ids[%c0, %i2] : tensor<1x7xindex>
      %cos = tensor.extract %cos_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
      %sin = tensor.extract %sin_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
      %in = tensor.extract %input[%i0, %i2, %i1, %i3] : !input_tensor_type
      %cos_val = arith.muli %cos, %in : !dtype

      %cond = arith.cmpi slt, %i3, %c64 : index
      %sin_val = scf.if %cond -> (!dtype) {
        // First half: rotate_half pairs element i3 with -input[i3 + 64].
        %i3_plus_64 = arith.addi %i3, %c64 : index
        %v = tensor.extract %input[%i0, %i2, %i1, %i3_plus_64] : !input_tensor_type
        %minusv = arith.muli %cm1, %v : !dtype
        %mul = arith.muli %sin, %minusv : !dtype
        scf.yield %mul : !dtype
      } else {
        // Second half: pairs with input[i3 - 64].
        // FIX: was arith.addi, which read out of bounds (i3 + 64 >= 128 on a
        // 128-wide dim); the rotate-half pairing requires i3 - 64 here.
        %i3_minus_64 = arith.subi %i3, %c64 : index
        %v = tensor.extract %input[%i0, %i2, %i1, %i3_minus_64] : !input_tensor_type
        %mul = arith.muli %sin, %v : !dtype
        scf.yield %mul : !dtype
      }

      %sum = arith.addi %cos_val, %sin_val : !dtype
      linalg.yield %sum : !dtype
    } -> !output_tensor_type

    bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
    return
  }

  func.func @main() {
    // input[b, s, h, d] = d — deterministic, exact in i16.
    %in_tmp = tensor.empty(): !input_tensor_type
    %input_values = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    } outs(%in_tmp : !input_tensor_type) {
    ^bb0(%ignore: !dtype):
      %i3 = linalg.index 3 : index
      %val = arith.index_cast %i3 : index to !dtype
      linalg.yield %val : !dtype
    } -> !input_tensor_type
    %inp = memref.alloc() {alignment = 64 : i64} : !input_memref_type
    bufferization.materialize_in_destination %input_values in restrict writable %inp : (!input_tensor_type, !input_memref_type) -> ()

    // pos_ids = [6, 5, 4, 3, 2, 1, 0] — reversed so the gather in @rope1 is
    // actually exercised (identity positions would hide gather bugs).
    %ipos_ids_tmp = tensor.empty() : tensor<1x7xindex>
    %ipos_ids_values = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1) -> (d0, d1)>
      ],
      iterator_types = ["parallel", "parallel"]
    } outs(%ipos_ids_tmp : tensor<1x7xindex>) {
    ^bb0(%ignore: index):
      %c6 = arith.constant 6 : index
      %i1 = linalg.index 1 : index
      // FIX: yield the reversed position %val; the original computed %val but
      // yielded %i1, leaving the reversal dead code.
      %val = arith.subi %c6, %i1 : index
      linalg.yield %val : index
    } -> tensor<1x7xindex>
    %ipos_ids = memref.alloc() {alignment = 64 : i64} : !pos_ids_memref_type
    bufferization.materialize_in_destination %ipos_ids_values in restrict writable %ipos_ids : (tensor<1x7xindex>, !pos_ids_memref_type) -> ()

    // cos[..., d] = d + 3, sin[..., d] = d + 2 — distinct caches so a
    // cos/sin mix-up changes the result.
    %cos_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type
    %sin_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type
    %cos_cache_values, %sin_cache_values = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    } outs(%cos_cache_tmp, %sin_cache_tmp : !cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type) {
    ^bb0(%ignore_cos: !dtype, %ignore_sin: !dtype):
      %c3 = arith.constant 3 : !dtype
      %c2 = arith.constant 2 : !dtype
      %i3 = linalg.index 3 : index
      %val = arith.index_cast %i3 : index to !dtype
      %cos = arith.addi %c3, %val : !dtype
      %sin = arith.addi %c2, %val : !dtype
      linalg.yield %cos, %sin : !dtype, !dtype
    } -> (!cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type)
    %cos_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type
    %sin_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type
    bufferization.materialize_in_destination %cos_cache_values in restrict writable %cos_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> ()
    bufferization.materialize_in_destination %sin_cache_values in restrict writable %sin_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> ()

    // FIX: pass %sin_cache as the sine cache; the original passed %cos_cache
    // twice, so the sin path was never exercised with distinct data.
    %out1 = memref.alloc() {alignment = 64 : i64} : !output_memref_type
    %start1 = call @nanoTime() : () -> i64
    func.call @rope1(%inp, %ipos_ids, %out1, %cos_cache, %sin_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> ()
    %end1 = call @nanoTime() : () -> i64
    %time1 = arith.subi %end1, %start1 : i64

    %out2 = memref.alloc() {alignment = 64 : i64} : !output_memref_type
    %start2 = call @nanoTime() : () -> i64
    func.call @rope2(%inp, %ipos_ids, %out2, %cos_cache, %sin_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> ()
    %end2 = call @nanoTime() : () -> i64
    %time2 = arith.subi %end2, %start2 : i64

    // Element-wise difference of the two results; must be all zeroes.
    %out1_tensor = bufferization.to_tensor %out1 restrict : !output_memref_type
    %out2_tensor = bufferization.to_tensor %out2 restrict : !output_memref_type
    %out_buf = tensor.empty(): !output_tensor_type
    %out_tensor = linalg.sub ins(%out1_tensor, %out2_tensor : !output_tensor_type, !output_tensor_type)
                             outs(%out_buf : !output_tensor_type) -> !output_tensor_type
    %out = memref.alloc() {alignment = 64 : i64} : !output_memref_type
    bufferization.materialize_in_destination %out_tensor in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()

    // Debugging aid — uncomment to dump the difference tensor.
    // %cast = memref.cast %out : !output_memref_type to memref<*xi16>
    // call @printMemrefI16(%cast) : (memref<*xi16>) -> ()

    // CHECK: [[TIME1:[0-9]+]]
    llvm.call @printI64(%time1) : (i64) -> ()
    llvm.call @printNewline() : () -> ()
    // CHECK: [[TIME2:[0-9]+]]
    llvm.call @printI64(%time2) : (i64) -> ()
    llvm.call @printNewline() : () -> ()

    %all_zeroes = memref.get_global @_all_zeroes : !output_memref_type
    %cast_all_zeroes = memref.cast %all_zeroes : !output_memref_type to memref<*xi16>
    %cast_out = memref.cast %out : !output_memref_type to memref<*xi16>
    %cmp = call @verifyMemRefI16(%cast_all_zeroes, %cast_out) : (memref<*xi16>, memref<*xi16>) -> (i64)
    // CHECK: 0
    llvm.call @printI64(%cmp) : (i64) -> ()
    llvm.call @printNewline() : () -> ()

    return
  }

  func.func private @printMemrefI16(%ptr : memref<*xi16>)
  func.func private @verifyMemRefI16(%a : memref<*xi16>, %b : memref<*xi16>) -> i64 attributes { llvm.emit_c_interface }
  func.func private @nanoTime() -> i64
  llvm.func @printI64(i64)
  llvm.func @printNewline()

}

Diff for: test/mlir/unittests/ExecutionEngine/GPU/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ target_link_libraries(GCExecutionEngineGpuTests
55
PRIVATE
66
GcJitWrapper
77
GcGpuOclRuntime
8+
mlir_c_runner_utils
89
)

0 commit comments

Comments
 (0)