|
| 1 | +// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s |
| 2 | + |
| 3 | +!dtype=i16 |
| 4 | +!input_memref_type=memref<2x7x32x128x!dtype> |
| 5 | +!input_tensor_type=tensor<2x7x32x128x!dtype> |
| 6 | +!output_memref_type=memref<2x32x7x128x!dtype> |
| 7 | +!output_tensor_type=tensor<2x32x7x128x!dtype> |
| 8 | +!cos_sin_cache_memref_type=memref<1x1x7x128x!dtype> |
| 9 | +!cos_sin_cache_tensor_type=tensor<1x1x7x128x!dtype> |
| 10 | +!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype> |
| 11 | +!pos_ids_memref_type=memref<1x7xindex> |
| 12 | +module @fragment_name { |
| 13 | +memref.global "private" constant @_all_zeroes : !output_memref_type = dense<0> |
| 14 | + func.func @rope1(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type, |
| 15 | + %cos_cache : !cos_sin_cache_memref_type, %sin_cache : !cos_sin_cache_memref_type) { |
| 16 | + %input = bufferization.to_tensor %iinput restrict : !input_memref_type |
| 17 | + %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type |
| 18 | + %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type |
| 19 | + %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type |
| 20 | + %3 = tensor.empty(): !output_tensor_type |
| 21 | + %transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3] |
| 22 | + |
| 23 | + %c0 = arith.constant 0 : index |
| 24 | + %c3 = arith.constant 3 : index |
| 25 | + %cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type |
| 26 | + %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype> |
| 27 | + %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype> |
| 28 | + %pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex> |
| 29 | + |
| 30 | + %cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype> |
| 31 | + |
| 32 | + %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype> |
| 33 | + %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype> |
| 34 | + |
| 35 | + |
| 36 | + %cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1] |
| 37 | + %input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type |
| 38 | + |
| 39 | + %head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type |
| 40 | + %c2 = arith.constant 2 : index |
| 41 | + %half_head_dim = arith.floordivsi %head_dim, %c2 : index |
| 42 | + %transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype> |
| 43 | + %transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim][2, 32, 7, 64][1,1,1,1] : !output_tensor_type to tensor<2x32x7x64x!dtype> |
| 44 | + %cnegative1 = arith.constant dense<-1> : tensor<2x32x7x64x!dtype> |
| 45 | + %empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype> |
| 46 | + %transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1: tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor: tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype> |
| 47 | + |
| 48 | + %transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type |
| 49 | + |
| 50 | + %sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type |
| 51 | + %sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype> |
| 52 | + %sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype> |
| 53 | + %sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype> |
| 54 | + |
| 55 | + %sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype> |
| 56 | + %sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype> |
| 57 | + %sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1] |
| 58 | + %input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type |
| 59 | + |
| 60 | + %result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type |
| 61 | + bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> () |
| 62 | + return |
| 63 | + } |
| 64 | + |
| 65 | + func.func @rope2(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %out: !output_memref_type, |
| 66 | + %cos_cache: !cos_sin_cache_memref_type, %sin_cache: !cos_sin_cache_memref_type) { |
| 67 | + %c0 = arith.constant 0 : index |
| 68 | + %c64 = arith.constant 64 : index |
| 69 | + %cm1 = arith.constant -1 : !dtype |
| 70 | + |
| 71 | + %input = bufferization.to_tensor %iinput restrict : !input_memref_type |
| 72 | + %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type |
| 73 | + %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type |
| 74 | + %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type |
| 75 | + %tmp = tensor.empty(): !output_tensor_type |
| 76 | + |
| 77 | + %result = linalg.generic { |
| 78 | + indexing_maps = [ |
| 79 | + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| 80 | + ], |
| 81 | + iterator_types = ["parallel", "parallel", "parallel", "parallel"] |
| 82 | + } outs(%tmp : !output_tensor_type) { |
| 83 | + ^bb0(%ignore: !dtype): |
| 84 | + %i0 = linalg.index 0 : index |
| 85 | + %i1 = linalg.index 1 : index |
| 86 | + %i2 = linalg.index 2 : index |
| 87 | + %i3 = linalg.index 3 : index |
| 88 | + %pos = tensor.extract %pos_ids[%c0, %i2] : tensor<1x7xindex> |
| 89 | + %cos = tensor.extract %cos_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type |
| 90 | + %sin = tensor.extract %sin_cache_tensor[%c0, %c0, %pos, %i3] : !cos_sin_cache_tensor_type |
| 91 | + %in = tensor.extract %input[%i0, %i2, %i1, %i3] : !input_tensor_type |
| 92 | + %cos_val = arith.muli %cos, %in : !dtype |
| 93 | + |
| 94 | + %cond = arith.cmpi slt, %i3, %c64 : index |
| 95 | + %sin_val = scf.if %cond -> (!dtype) { |
| 96 | + %i3_plus_64 = arith.addi %i3, %c64 : index |
| 97 | + %v = tensor.extract %input[%i0, %i2, %i1, %i3_plus_64] : !input_tensor_type |
| 98 | + %minusv = arith.muli %cm1, %v : !dtype |
| 99 | + %mul = arith.muli %sin, %minusv : !dtype |
| 100 | + scf.yield %mul : !dtype |
| 101 | + } else { |
| 102 | + %i3_minus_64 = arith.addi %i3, %c64 : index |
| 103 | + %v = tensor.extract %input[%i0, %i2, %i1, %i3_minus_64] : !input_tensor_type |
| 104 | + %mul = arith.muli %sin, %v : !dtype |
| 105 | + scf.yield %mul : !dtype |
| 106 | + } |
| 107 | + |
| 108 | + %sum = arith.addi %cos_val, %sin_val : !dtype |
| 109 | + linalg.yield %sum : !dtype |
| 110 | + } -> !output_tensor_type |
| 111 | + |
| 112 | + bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> () |
| 113 | + return |
| 114 | + } |
| 115 | + |
| 116 | + func.func @main() { |
| 117 | + %in_tmp = tensor.empty(): !input_tensor_type |
| 118 | + %input_values = linalg.generic { |
| 119 | + indexing_maps = [ |
| 120 | + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| 121 | + ], |
| 122 | + iterator_types = ["parallel", "parallel", "parallel", "parallel"] |
| 123 | + } outs(%in_tmp : !input_tensor_type) { |
| 124 | + ^bb0(%ignore: !dtype): |
| 125 | + %i3 = linalg.index 3 : index |
| 126 | + %val = arith.index_cast %i3 : index to !dtype |
| 127 | + linalg.yield %val : !dtype |
| 128 | + } -> !input_tensor_type |
| 129 | + %inp = memref.alloc() {alignment = 64 : i64} : !input_memref_type |
| 130 | + bufferization.materialize_in_destination %input_values in restrict writable %inp : (!input_tensor_type, !input_memref_type) -> () |
| 131 | + |
| 132 | + %ipos_ids_tmp = tensor.empty() : tensor<1x7xindex> |
| 133 | + %ipos_ids_values = linalg.generic { |
| 134 | + indexing_maps = [ |
| 135 | + affine_map<(d0, d1) -> (d0, d1)> |
| 136 | + ], |
| 137 | + iterator_types = ["parallel", "parallel"] |
| 138 | + } outs(%ipos_ids_tmp : tensor<1x7xindex>) { |
| 139 | + ^bb0(%ignore: index): |
| 140 | + %c6 = arith.constant 6 : index |
| 141 | + %i1 = linalg.index 1 : index |
| 142 | + %val = arith.subi %c6, %i1 : index |
| 143 | + linalg.yield %i1 : index |
| 144 | + } -> tensor<1x7xindex> |
| 145 | + %ipos_ids = memref.alloc() {alignment = 64 : i64} : !pos_ids_memref_type |
| 146 | + bufferization.materialize_in_destination %ipos_ids_values in restrict writable %ipos_ids : (tensor<1x7xindex>, !pos_ids_memref_type) -> () |
| 147 | + |
| 148 | + %cos_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type |
| 149 | + %sin_cache_tmp = tensor.empty() : !cos_sin_cache_tensor_type |
| 150 | + %cos_cache_values, %sin_cache_values = linalg.generic { |
| 151 | + indexing_maps = [ |
| 152 | + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, |
| 153 | + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| 154 | + ], |
| 155 | + iterator_types = ["parallel", "parallel", "parallel", "parallel"] |
| 156 | + } outs(%cos_cache_tmp, %sin_cache_tmp : !cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type) { |
| 157 | + ^bb0(%ignore_cos: !dtype, %ignore_sin: !dtype): |
| 158 | + %c3 = arith.constant 3 : !dtype |
| 159 | + %c2 = arith.constant 2 : !dtype |
| 160 | + %i3 = linalg.index 3 : index |
| 161 | + %val = arith.index_cast %i3 : index to !dtype |
| 162 | + %cos = arith.addi %c3, %val : !dtype |
| 163 | + %sin = arith.addi %c2, %val : !dtype |
| 164 | + linalg.yield %cos, %sin : !dtype, !dtype |
| 165 | + } -> (!cos_sin_cache_tensor_type, !cos_sin_cache_tensor_type) |
| 166 | + %cos_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type |
| 167 | + %sin_cache = memref.alloc() {alignment = 64 : i64} : !cos_sin_cache_memref_type |
| 168 | + bufferization.materialize_in_destination %cos_cache_values in restrict writable %cos_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> () |
| 169 | + bufferization.materialize_in_destination %sin_cache_values in restrict writable %sin_cache : (!cos_sin_cache_tensor_type, !cos_sin_cache_memref_type) -> () |
| 170 | + |
| 171 | + %out1 = memref.alloc() {alignment = 64 : i64} : !output_memref_type |
| 172 | + %start1 = call @nanoTime() : () -> i64 |
| 173 | + func.call @rope1(%inp, %ipos_ids, %out1, %cos_cache, %cos_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> () |
| 174 | + %end1 = call @nanoTime() : () -> i64 |
| 175 | + %time1 = arith.subi %end1, %start1 : i64 |
| 176 | + |
| 177 | + %out2 = memref.alloc() {alignment = 64 : i64} : !output_memref_type |
| 178 | + %start2 = call @nanoTime() : () -> i64 |
| 179 | + func.call @rope2(%inp, %ipos_ids, %out2, %cos_cache, %cos_cache) : (!input_memref_type, !pos_ids_memref_type, !output_memref_type, !cos_sin_cache_memref_type, !cos_sin_cache_memref_type) -> () |
| 180 | + %end2 = call @nanoTime() : () -> i64 |
| 181 | + %time2 = arith.subi %end2, %start2 : i64 |
| 182 | + |
| 183 | + %out1_tensor = bufferization.to_tensor %out1 restrict : !output_memref_type |
| 184 | + %out2_tensor = bufferization.to_tensor %out2 restrict : !output_memref_type |
| 185 | + %out_buf = tensor.empty(): !output_tensor_type |
| 186 | + %out_tensor = linalg.sub ins(%out1_tensor, %out2_tensor : !output_tensor_type, !output_tensor_type) |
| 187 | + outs(%out_buf : !output_tensor_type) -> !output_tensor_type |
| 188 | + %out = memref.alloc() {alignment = 64 : i64} : !output_memref_type |
| 189 | + bufferization.materialize_in_destination %out_tensor in restrict writable %out : (!output_tensor_type, !output_memref_type) -> () |
| 190 | + |
| 191 | + // %cast = memref.cast %out : !output_memref_type to memref<*xi16> |
| 192 | + // call @printMemrefI16(%cast) : (memref<*xi16>) -> () |
| 193 | + |
| 194 | +// CHECK: [[TIME1:[0-9]+]] |
| 195 | + llvm.call @printI64(%time1) : (i64) -> () |
| 196 | + llvm.call @printNewline() : () -> () |
| 197 | +// CHECK: [[TIME2:[0-9]+]] |
| 198 | + llvm.call @printI64(%time2) : (i64) -> () |
| 199 | + llvm.call @printNewline() : () -> () |
| 200 | + |
| 201 | + %all_zeroes = memref.get_global @_all_zeroes : !output_memref_type |
| 202 | + %cast_all_zeroes = memref.cast %all_zeroes : !output_memref_type to memref<*xi16> |
| 203 | + %cast_out = memref.cast %out : !output_memref_type to memref<*xi16> |
| 204 | + %cmp = call @verifyMemRefI16(%cast_all_zeroes, %cast_out) : (memref<*xi16>, memref<*xi16>) -> (i64) |
| 205 | +// CHECK: 0 |
| 206 | + llvm.call @printI64(%cmp) : (i64) -> () |
| 207 | + llvm.call @printNewline() : () -> () |
| 208 | + |
| 209 | + return |
| 210 | + } |
| 211 | + |
| 212 | +func.func private @printMemrefI16(%ptr : memref<*xi16>) |
| 213 | +func.func private @verifyMemRefI16(%a : memref<*xi16>, %b : memref<*xi16>) -> i64 attributes { llvm.emit_c_interface } |
| 214 | +func.func private @nanoTime() -> i64 |
| 215 | +llvm.func @printI64(i64) |
| 216 | +llvm.func @printNewline() |
| 217 | + |
| 218 | +} |
0 commit comments