Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added RoPE unit test #423

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h
Original file line number Diff line number Diff line change
@@ -229,7 +229,9 @@ struct OclModuleBuilderOpts {
bool spirvDump = false;
bool enableObjectDump = false;
ArrayRef<StringRef> sharedLibPaths = {};
void (*pipeline)(OpPassManager &) = nullptr;
std::function<void(OpPassManager &)> pipeline;
std::function<void(llvm::orc::SymbolMap &, llvm::orc::MangleAndInterner &)>
symbolMaper;
};

struct OclModuleBuilder {
@@ -258,7 +260,8 @@ struct OclModuleBuilder {
const bool spirvDump;
const bool enableObjectDump;
const ArrayRef<StringRef> sharedLibPaths;
void (*const pipeline)(OpPassManager &);
std::function<void(OpPassManager &)> pipeline;
std::function<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> symbolMap;
const StringRef funcName;
const ArrayRef<Type> argTypes;
std::shared_mutex mux;
15 changes: 13 additions & 2 deletions lib/gc/ExecutionEngine/GPURuntime/ocl/GpuOclRuntime.cpp
Original file line number Diff line number Diff line change
@@ -824,7 +824,18 @@ OclModuleBuilder::OclModuleBuilder(ModuleOp module,
? opts.pipeline
: [](OpPassManager &pm) { populateGPUPipeline(pm, {}); }),
funcName(getFuncName(opts, mlirModule)),
argTypes(getArgTypes(funcName, mlirModule)) {}
argTypes(getArgTypes(funcName, mlirModule)) {
if (opts.symbolMaper) {
symbolMap = [mapper =
opts.symbolMaper](llvm::orc::MangleAndInterner interner) {
auto map = OclRuntime::Exports::symbolMap(interner);
mapper(map, interner);
return map;
};
} else {
symbolMap = OclRuntime::Exports::symbolMap;
}
}

llvm::Expected<std::shared_ptr<const OclModule>>
OclModuleBuilder::build(const OclRuntime &runtime) {
@@ -940,7 +951,7 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
staticMain = createStaticMain(builder, mod, funcName, argTypes);
auto expectedEng = ExecutionEngine::create(mod, opts);
CHECKE(expectedEng, "Failed to create ExecutionEngine!");
expectedEng->get()->registerSymbols(OclRuntime::Exports::symbolMap);
expectedEng->get()->registerSymbols(symbolMap);

// Find all kernels and query the workgroup size
size_t minSize = maxSize;
3 changes: 3 additions & 0 deletions lib/gc/Transforms/GPU/Pipeline.cpp
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ void populateGPUPipeline(OpPassManager &pm,
pm.addPass(createDecomposeTensorOperation());
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
pm.addPass(createCanonicalizerPass());
// pm.addPass(createPrintIRPass());

pm.addPass(bufferization::createEmptyTensorEliminationPass());
pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
@@ -123,6 +124,8 @@ void populateGPUPipeline(OpPassManager &pm,
pm.addPass(createLowerAffinePass());
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createReconcileUnrealizedCastsPass());

// pm.addPass(createPrintIRPass());
}

void registerGPUPipeline() {
218 changes: 218 additions & 0 deletions test/mlir/test/gc/gpu-runner/rope.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s

!dtype=i16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x7x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x7x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
module @fragment_name {
memref.global "private" constant @_all_zeroes : !output_memref_type = dense<0>
// RoPE (rotary position embedding) expressed with high-level tensor/linalg ops.
// For the transposed input t ([batch, head, seq, dim]):
//   out = t * cos_rows + rotate_half(t) * sin_rows
// where rotate_half(t) = concat(-t[..., 64:128], t[..., 0:64]) on the last dim,
// and cos/sin rows are selected per sequence position via tensor.gather on %ipos_ids.
func.func @rope1(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
%cos_cache : !cos_sin_cache_memref_type, %sin_cache : !cos_sin_cache_memref_type) {
%input = bufferization.to_tensor %iinput restrict : !input_memref_type
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
%3 = tensor.empty(): !output_tensor_type
// [batch, seq, head, dim] -> [batch, head, seq, dim]
%transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]

%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
// Flatten the cos cache to [seq, dim] so it can be gathered by position id.
%cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>

// Select one [dim]-row of the cos table per position id.
%cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

%cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>


// Broadcast gathered rows over [batch, head] and apply the cos term.
%cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

// rotate_half: negate the upper half and swap the halves on the last dim.
%head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
%c2 = arith.constant 2 : index
%half_head_dim = arith.floordivsi %head_dim, %c2 : index
%transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%cnegative1 = arith.constant dense<-1> : tensor<2x32x7x64x!dtype>
%empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
%transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>

%transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type

// Same gather/broadcast pipeline for the sin table.
%sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

%sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
%sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

// out = t * cos + rotate_half(t) * sin
%result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
return
}

// RoPE computed directly per output element with a single linalg.generic.
// Must produce exactly the same values as @rope1.
func.func @rope2(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
%cos_cache: !cos_sin_cache_memref_type, %sin_cache: !cos_sin_cache_memref_type) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%cm1 = arith.constant -1 : !dtype

%input = bufferization.to_tensor %iinput restrict : !input_memref_type
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
%tmp = tensor.empty(): !output_tensor_type

%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} outs(%tmp : !output_tensor_type) {
^bb0(%ignore: !dtype):
%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%i2 = linalg.index 2 : index
%i3 = linalg.index 3 : index
// Output layout is [batch, head, seq, dim]; input is [batch, seq, head, dim],
// hence the swapped %i1/%i2 when indexing %input.
%pos = tensor.extract %pos_ids[%c0, %i2] : tensor<1x7xindex>
%cos = tensor.extract %cos_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
%sin = tensor.extract %sin_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
%in = tensor.extract %input[%i0, %i2, %i1, %i3] : !input_tensor_type
%cos_val = arith.muli %cos, %in : !dtype

// rotate_half: the first half pairs with -input[i3 + 64], the second half
// pairs with input[i3 - 64].
%cond = arith.cmpi slt, %i3, %c64 : index
%sin_val = scf.if %cond -> (!dtype) {
%i3_plus_64 = arith.addi %i3, %c64 : index
%v = tensor.extract %input[%i0, %i2, %i1, %i3_plus_64] : !input_tensor_type
%minusv = arith.muli %cm1, %v : !dtype
%mul = arith.muli %sin, %minusv : !dtype
scf.yield %mul : !dtype
} else {
// Fixed: was arith.addi, which indexed i3 + 64 (up to 191 > 127, out of
// bounds) and disagreed with @rope1's rotate_half.
%i3_minus_64 = arith.subi %i3, %c64 : index
%v = tensor.extract %input[%i0, %i2, %i1, %i3_minus_64] : !input_tensor_type
%mul = arith.muli %sin, %v : !dtype
scf.yield %mul : !dtype
}

%sum = arith.addi %cos_val, %sin_val : !dtype
linalg.yield %sum : !dtype
} -> !output_tensor_type

bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
return
}

// Host driver: fills the input, position-id and cos/sin cache buffers, runs
// both RoPE implementations, and verifies out1 - out2 is all zeroes.
func.func @main() {
%in_tmp = tensor.empty(): !input_tensor_type
%input_values = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} outs(%in_tmp : !input_tensor_type) {
^bb0(%ignore: !dtype):
// input[...][i3] = i3
%i3 = linalg.index 3 : index
%val = arith.index_cast %i3 : index to !dtype
linalg.yield %val : !dtype
} -> !input_tensor_type
%inp = memref.alloc() {alignment = 64 : i64} : !input_memref_type
bufferization.materialize_in_destination %input_values in restrict writable %inp : (!input_tensor_type, !input_memref_type) -> ()

%ipos_ids_tmp = tensor.empty() : tensor<1x7xindex>
%ipos_ids_values = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel"]
} outs(%ipos_ids_tmp : tensor<1x7xindex>) {
^bb0(%ignore: index):
// Reversed position ids: pos_ids[0][i1] = 6 - i1 (exercises the gather).
// Fixed: previously yielded %i1, leaving %val dead and the ids unreversed.
%c6 = arith.constant 6 : index
%i1 = linalg.index 1 : index
%val = arith.subi %c6, %i1 : index
linalg.yield %val : index
} -> tensor<1x7xindex>
%ipos_ids = memref.alloc() {alignment = 64 : i64} : !pos_ids_memref_type
bufferization.materialize_in_destination %ipos_ids_values in restrict writable %ipos_ids : (tensor<1x7xindex>, !pos_ids_memref_type) -> ()

// cos[...][i3] = i3 + 3, sin[...][i3] = i3 + 2.
%cos_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type
%sin_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type
%cos_cache_values, %sin_cache_values = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} outs(%cos_cache_tmp, %sin_cache_tmp : !cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type) {
^bb0(%ignore_cos: !dtype, %ignore_sin: !dtype):
%c3 = arith.constant 3 : !dtype
%c2 = arith.constant 2 : !dtype
%i3 = linalg.index 3 : index
%val = arith.index_cast %i3 : index to !dtype
%cos = arith.addi %c3, %val : !dtype
%sin = arith.addi %c2, %val : !dtype
linalg.yield %cos, %sin : !dtype, !dtype
} -> (!cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type)
%cos_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type
%sin_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type
bufferization.materialize_in_destination %cos_cache_values in restrict writable %cos_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> ()
bufferization.materialize_in_destination %sin_cache_values in restrict writable %sin_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> ()

// Fixed: both calls previously passed %cos_cache twice, leaving %sin_cache
// computed but unused; pass the distinct sin cache to both implementations.
%out1 = memref.alloc() {alignment = 64 : i64} : !output_memref_type
%start1 = call @nanoTime() : () -> i64
func.call @rope1(%inp, %ipos_ids, %out1, %cos_cache, %sin_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> ()
%end1 = call @nanoTime() : () -> i64
%time1 = arith.subi %end1, %start1 : i64

%out2 = memref.alloc() {alignment = 64 : i64} : !output_memref_type
%start2 = call @nanoTime() : () -> i64
func.call @rope2(%inp, %ipos_ids, %out2, %cos_cache, %sin_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> ()
%end2 = call @nanoTime() : () -> i64
%time2 = arith.subi %end2, %start2 : i64

// Element-wise difference: must be all zeroes if both implementations agree.
%out1_tensor = bufferization.to_tensor %out1 restrict : !output_memref_type
%out2_tensor = bufferization.to_tensor %out2 restrict : !output_memref_type
%out_buf = tensor.empty(): !output_tensor_type
%out_tensor = linalg.sub ins(%out1_tensor, %out2_tensor : !output_tensor_type, !output_tensor_type)
outs(%out_buf : !output_tensor_type) -> !output_tensor_type
%out = memref.alloc() {alignment = 64 : i64} : !output_memref_type
bufferization.materialize_in_destination %out_tensor in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()

// CHECK: [[TIME1:[0-9]+]]
llvm.call @printI64(%time1) : (i64) -> ()
llvm.call @printNewline() : () -> ()
// CHECK: [[TIME2:[0-9]+]]
llvm.call @printI64(%time2) : (i64) -> ()
llvm.call @printNewline() : () -> ()

%all_zeroes = memref.get_global @_all_zeroes : !output_memref_type
%cast_all_zeroes = memref.cast %all_zeroes : !output_memref_type to memref<*xi16>
%cast_out = memref.cast %out : !output_memref_type to memref<*xi16>
%cmp = call @verifyMemRefI16(%cast_all_zeroes, %cast_out) : (memref<*xi16>, memref<*xi16>) -> (i64)
// CHECK: 0
llvm.call @printI64(%cmp) : (i64) -> ()
llvm.call @printNewline() : () -> ()

return
}

func.func private @printMemrefI16(%ptr : memref<*xi16>)
func.func private @verifyMemRefI16(%a : memref<*xi16>, %b : memref<*xi16>) -> i64 attributes { llvm.emit_c_interface }
func.func private @nanoTime() -> i64
llvm.func @printI64(i64)
llvm.func @printNewline()

}
1 change: 1 addition & 0 deletions test/mlir/unittests/ExecutionEngine/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -5,4 +5,5 @@ target_link_libraries(GCExecutionEngineGpuTests
PRIVATE
GcJitWrapper
GcGpuOclRuntime
mlir_c_runner_utils
)
260 changes: 260 additions & 0 deletions test/mlir/unittests/ExecutionEngine/GPU/GpuOclRuntimeTest.cpp
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
#include "gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h"
#include "gc/Utils/Error.h"

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
@@ -77,6 +78,136 @@ module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"G
}
)mlir";

// MLIR source for the tensor-op RoPE formulation:
//   out = transpose(in) * cos + rotate_half(transpose(in)) * sin,
// with per-position cos/sin rows selected via tensor.gather on the pos ids.
// Kept as a raw string so the JIT test can build it with OclModuleBuilder.
constexpr char rope1[] = R"mlir(
!dtype=i16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x7x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
module @fragment_name {
func.func @rope1(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
%cos_cache : !cos_sin_cache_memref_type, %sin_cache : !cos_sin_cache_memref_type) {
%input = bufferization.to_tensor %iinput restrict : !input_memref_type
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
%3 = tensor.empty(): !output_tensor_type
%transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>
%cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
%cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
%cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
%head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
%c2 = arith.constant 2 : index
%half_head_dim = arith.floordivsi %head_dim, %c2 : index
%transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
%cnegative1 = arith.constant dense<-1> : tensor<2x32x7x64x!dtype>
%empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
%transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>
%transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type
%sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
%sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
%sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
%sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
%sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
%sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
%sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
%input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
%result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type
bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
return
}
}
)mlir";
// MLIR source for the element-wise RoPE formulation (single linalg.generic).
// Must compute exactly the same values as rope1.
// Fix: the i3 >= 64 branch now uses arith.subi (was arith.addi, which indexed
// input at i3 + 64 — out of bounds for the 128-wide last dimension — and
// disagreed with rope1's rotate_half).
constexpr char rope2[] = R"mlir(
!dtype=i16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x7x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
module @fragment_name {
func.func @rope2(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type,
%cos_cache: !cos_sin_cache_memref_type, %sin_cache: !cos_sin_cache_memref_type) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%cm1 = arith.constant -1 : !dtype
%input = bufferization.to_tensor %iinput restrict : !input_memref_type
%cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
%sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
%tmp = tensor.empty(): !output_tensor_type
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} outs(%tmp : !output_tensor_type) {
^bb0(%ignore: !dtype):
%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%i2 = linalg.index 2 : index
%i3 = linalg.index 3 : index
%pos = tensor.extract %pos_ids[%c0, %i2] : tensor<1x7xindex>
%cos = tensor.extract %cos_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
%sin = tensor.extract %sin_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type
%in = tensor.extract %input[%i0, %i2, %i1, %i3] : !input_tensor_type
%cos_val = arith.muli %cos, %in : !dtype
%cond = arith.cmpi slt, %i3, %c64 : index
%sin_val = scf.if %cond -> (!dtype) {
%i3_plus_64 = arith.addi %i3, %c64 : index
%v = tensor.extract %input[%i0, %i2, %i1, %i3_plus_64] : !input_tensor_type
%minusv = arith.muli %cm1, %v : !dtype
%mul = arith.muli %sin, %minusv : !dtype
scf.yield %mul : !dtype
} else {
%i3_minus_64 = arith.subi %i3, %c64 : index
%v = tensor.extract %input[%i0, %i2, %i1, %i3_minus_64] : !input_tensor_type
%mul = arith.muli %sin, %v : !dtype
scf.yield %mul : !dtype
}
%sum = arith.addi %cos_val, %sin_val : !dtype
linalg.yield %sum : !dtype
} -> !output_tensor_type
bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
return
}
}
)mlir";

struct TestBase {
OclRuntime runtime = gcGetOrReport(OclRuntime::get());
cl_command_queue queue = gcGetOrReport(runtime.createQueue());
@@ -230,3 +361,132 @@ TEST(GpuOclRuntime, TestMatmulAddStatic) {
} test;
test.test(matmulAddStatic);
}

// Checks that the two RoPE lowerings (rope1: high-level tensor ops,
// rope2: element-wise linalg.generic) produce bit-identical results.
// rope1 is run with shared USM memory, rope2 with device USM memory.
TEST(GpuOclRuntime, TestRope) {
  struct Test : TestBase {
    int16_t *inputs;
    // Backs a memref<1x7xindex>; `index` lowers to a pointer-sized integer,
    // so the host/device buffers must hold 64-bit values. The original
    // int16_t buffer was under-sized and fed garbage position ids.
    int64_t *ipos;
    int16_t *outputs;
    int16_t *cosCache;
    int16_t *sinCache;

    explicit Test(bool sharedMem) {
      auto memAlloc = [&](size_t size) {
        return gcGetOrReport(sharedMem ? runtime.usmNewShared<int16_t>(size)
                                       : runtime.usmNewDev<int16_t>(size));
      };

      {
        int16_t inputsCpu[2][7][32][128];
        size_t bufLen = sizeof(inputsCpu) / sizeof(int16_t);
        inputs = memAlloc(bufLen);
        outputs = memAlloc(bufLen);
        for (int i = 0; i < 2; i++) {
          for (int j = 0; j < 7; j++) {
            for (int k = 0; k < 32; k++) {
              for (int l = 0; l < 128; l++) {
                // Independent of the head index k, which keeps rope1's
                // transpose easy to verify by hand.
                inputsCpu[i][j][k][l] = static_cast<int16_t>(i + j + l + 1);
              }
            }
          }
        }
        assert(runtime.usmCpy(ctx, inputsCpu, inputs, bufLen));
      }

      {
        int16_t cosCacheCpu[1][1][7][128];
        int16_t sinCacheCpu[1][1][7][128];
        size_t bufLen = sizeof(cosCacheCpu) / sizeof(int16_t);
        cosCache = memAlloc(bufLen);
        sinCache = memAlloc(bufLen);
        for (int i = 0; i < 1; i++) {
          for (int j = 0; j < 1; j++) {
            for (int k = 0; k < 7; k++) {
              for (int l = 0; l < 128; l++) {
                cosCacheCpu[i][j][k][l] = static_cast<int16_t>(i + j + l + 3);
                sinCacheCpu[i][j][k][l] = static_cast<int16_t>(i + j + l + 2);
              }
            }
          }
        }
        assert(runtime.usmCpy(ctx, cosCacheCpu, cosCache, bufLen));
        assert(runtime.usmCpy(ctx, sinCacheCpu, sinCache, bufLen));
      }

      // Reversed positions exercise the pos_ids gather/extract paths.
      int64_t iposCpu[]{6, 5, 4, 3, 2, 1, 0};
      constexpr size_t iposLen = sizeof(iposCpu) / sizeof(int64_t);
      ipos = gcGetOrReport(sharedMem ? runtime.usmNewShared<int64_t>(iposLen)
                                     : runtime.usmNewDev<int64_t>(iposLen));
      assert(runtime.usmCpy(ctx, iposCpu, ipos, iposLen));
    }

    ~Test() override {
      assert(runtime.usmFree(inputs));
      assert(runtime.usmFree(ipos));
      assert(runtime.usmFree(outputs));
      assert(runtime.usmFree(cosCache));
      assert(runtime.usmFree(sinCache));
    }

    void exec(std::shared_ptr<const OclModule> &mod) override {
      // NOTE(review): five kernel arguments are registered below; confirm the
      // StaticExecutor template argument (3) is unrelated to the arg count.
      StaticExecutor<3> exec(mod);
      exec.arg(inputs);
      exec.arg(ipos);
      exec.arg(outputs);
      exec.arg(cosCache);
      exec.arg(sinCache);

      auto start =
          std::chrono::duration_cast<std::chrono::nanoseconds>(
              std::chrono::high_resolution_clock::now().time_since_epoch())
              .count();
      exec(ctx);
      assert(ctx.finish());
      auto end =
          std::chrono::duration_cast<std::chrono::nanoseconds>(
              std::chrono::high_resolution_clock::now().time_since_epoch())
              .count();
      gcLogD("Execution time: ", end - start, " ns");
    }

    // Builds `code`, runs it and copies the device output into outputsCpu.
    void test(const char *code, int16_t (&outputsCpu)[2][32][7][128]) {
      OclModuleBuilderOpts opts;
      // The lowered module calls the external memrefCopy() helper; expose it
      // to the JIT. (Field name `symbolMaper` [sic] matches the opts struct.)
      opts.symbolMaper = [](llvm::orc::SymbolMap &map,
                            llvm::orc::MangleAndInterner &interner) {
        map.try_emplace(interner("memrefCopy"),
                        llvm::orc::ExecutorAddr::fromPtr(&memrefCopy),
                        llvm::JITSymbolFlags::Exported);
      };
      OclModuleBuilder builder(parse(code), opts);
      auto mod = gcGetOrReport(builder.build(runtime));
      exec(mod);

      assert(runtime.usmCpy(ctx, outputs, outputsCpu,
                            sizeof(outputsCpu) / sizeof(int16_t)));
      assert(ctx.finish());
    }
  };

  int16_t outputs1[2][32][7][128];
  int16_t outputs2[2][32][7][128];

  {
    Test test(true);
    test.test(rope1, outputs1);
  }
  {
    Test test(false);
    test.test(rope2, outputs2);
  }

  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 32; j++) {
      for (int k = 0; k < 7; k++) {
        for (int l = 0; l < 128; l++) {
          assert(outputs1[i][j][k][l] == outputs2[i][j][k][l]);
        }
      }
    }
  }
}