// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// This test shows how to implement a sparse vector-vector dot product with
// AVX512. It uses vp2intersect, mask.compress and vector.contract to compute
// the dot product of two sparse vectors, processed in segments of 8 float64
// elements (one hardware vector) at a time. Each sparse vector is represented
// by an index memref (A or C) and a data memref (B or D), containing M or N
// elements, respectively.
//
// There are two implementations (a scalar reference sketch follows below):
//  * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
//  * `memref_dot_optimized`: An optimized O(N*M) version of the previous
//    implementation, where the inner for loop skips over segments that can
//    no longer contain matching indices.

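// For reference, both implementations compute the same value as the following
// scalar C++ sketch (a hedged reference only, not part of the test; names are
// illustrative). A and C hold the sorted indices, B and D the data values:
//
//   #include <cstddef>
//   #include <cstdint>
//
//   double sparse_dot_reference(const int64_t *A, const double *B, size_t M,
//                               const int64_t *C, const double *D, size_t N) {
//     double sum = 0.0;
//     for (size_t i = 0; i < M; ++i)     // every nonzero of the first vector
//       for (size_t j = 0; j < N; ++j)   // every nonzero of the second vector
//         if (A[i] == C[j])              // indices match: accumulate product
//           sum += B[i] * D[j];
//     return sum;
//   }
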
#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"]
}
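// The trait above describes a 1-D contraction: both inputs are indexed by the
// single reduction dimension i and the result has no dimensions, i.e. it
// yields the scalar r = acc + sum_i (p0[i] * p1[i]).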

// Sparse dot product of one pair of 8-element segments (indices + data), used
// on each segment pair by the memref-level implementations below.
func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
                 %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
  // Compute intersection of indices.
  %k0, %k1 = avx512.vp2intersect %v_A, %v_C : vector<8xi64>

  // Filter out values without a match and compress the vectors.
  %p0 = avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = avx512.mask.compress %k1, %v_D : vector<8xf64>

  // Dense vector dot product.
  %acc = std.constant 0.0 : f64
  %r = vector.contract #contraction_trait %p0, %p1, %acc
      : vector<8xf64>, vector<8xf64> into f64

  return %r : f64
}
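
// The body of @vector_dot corresponds roughly to the following C++ intrinsics
// sequence (a hedged sketch assuming AVX512F + AVX512-VP2INTERSECT support;
// not part of the test, names are illustrative):
//
//   #include <immintrin.h>
//
//   double segment_dot(__m512i v_A, __m512d v_B, __m512i v_C, __m512d v_D) {
//     __mmask8 k0, k1;
//     // Intersect the two 8-lane index vectors; k0/k1 mark matching lanes.
//     _mm512_2intersect_epi64(v_A, v_C, &k0, &k1);
//     // Compress the matching data lanes to the front, zeroing the rest.
//     __m512d p0 = _mm512_maskz_compress_pd(k0, v_B);
//     __m512d p1 = _mm512_maskz_compress_pd(k1, v_D);
//     // Dense dot product of the compressed segments.
//     return _mm512_reduce_add_pd(_mm512_mul_pd(p0, p1));
//   }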

// Fill the input memrefs with padding values (0.0 for the data, 2^63-1 for
// the indices), so that they can be used with arbitrary input sizes of up to
// 128 elements per sparse vector. A padded index lane can only ever match
// another padded lane, and the corresponding data is zero, so padding does
// not change the dot product.
func @init_input(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                 %m_C : memref<?xi64>, %m_D : memref<?xf64>) {
  %c0 = constant 0 : index
  %v_data = constant dense<0.0> : vector<128xf64>
  %v_index = constant dense<9223372036854775807> : vector<128xi64>

  vector.transfer_write %v_index, %m_A[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_B[%c0] : vector<128xf64>, memref<?xf64>
  vector.transfer_write %v_index, %m_C[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_D[%c0] : vector<128xf64>, memref<?xf64>

  return
}

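// Test data 1: indices 1, 10, 12, 82 and 98 appear in both %m_A and %m_C, so
// the expected dot product is 5*1 + 8*3 + 3*1 + 2*9 + 9*4 = 86.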
func @fill_input_1(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                   %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  call @init_input(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = constant 0 : index

  %v_A = std.constant dense<[0, 1, 10, 12, 13, 17, 18, 21,
                             51, 52, 57, 61, 62, 82, 98, 99]> : vector<16xi64>
  %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 0., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = std.constant dense<[1, 2, 5, 10, 11, 12, 47, 48,
                             67, 68, 69, 70, 71, 72, 77, 78,
                             79, 82, 83, 84, 85, 90, 91, 98]> : vector<24xi64>
  %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.,
                             2., 9., 8., 7., 2., 0., 0., 4.]> : vector<24xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<24xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<24xf64>, memref<?xf64>

  %M = std.constant 16 : index
  %N = std.constant 24 : index

  return %M, %N : index, index
}

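// Test data 2: indices 6, 7, 62, 63, 65 and 66 appear in both %m_A and %m_C,
// so the expected dot product is 2*1 + 1*5 + 5*4 + 2*5 + 9*8 + 1*2 = 111.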
func @fill_input_2(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                   %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  call @init_input(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = constant 0 : index

  %v_A = std.constant dense<[0, 1, 3, 5, 6, 7, 8, 9,
                             51, 52, 57, 61, 62, 63, 65, 66]> : vector<16xi64>
  %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = std.constant dense<[6, 7, 11, 12, 15, 17, 19, 21,
                             30, 31, 33, 34, 37, 39, 40, 41,
                             42, 44, 45, 46, 47, 48, 49, 50,
                             62, 63, 64, 65, 66, 67, 68, 69,
                             70, 77, 78, 79, 81, 82, 89, 99]> : vector<40xi64>
  %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.,
                             2., 9., 8., 7., 2., 1., 2., 4.,
                             4., 5., 8., 8., 2., 3., 5., 1.,
                             8., 6., 6., 4., 3., 8., 9., 2.]> : vector<40xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<40xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<40xf64>, memref<?xf64>

  %M = std.constant 16 : index
  %N = std.constant 40 : index

  return %M, %N : index, index
}

// Simple vector dot product implementation: Intersect every segment of size 8
// in (%m_A, %m_B) with every segment of size 8 in (%m_C, %m_D).
func @memref_dot_simple(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                        %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = constant 0 : index
  %c8 = constant 8 : index

  %data_zero = constant 0.0 : f64
  %index_padding = constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.

  %r0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero) -> (f64) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a], %data_zero
        : memref<?xf64>, vector<8xf64>

    %r1 = scf.for %b = %c0 to %N step %c8
        iter_args(%sum1 = %sum0) -> (f64) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %v_D = vector.transfer_read %m_D[%b], %data_zero
          : memref<?xf64>, vector<8xf64>

      %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
          : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) -> f64
      %r2 = addf %sum1, %subresult : f64
      scf.yield %r2 : f64
    }

    scf.yield %r1 : f64
  }

  return %r0 : f64
}
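
// In scalar terms, @memref_dot_simple is the following double loop over
// 8-element segments (hedged C++ sketch, not part of the test; it reuses the
// hypothetical segment_dot helper sketched next to @vector_dot above):
//
//   double dot_simple(const int64_t *A, const double *B, size_t M,
//                     const int64_t *C, const double *D, size_t N) {
//     double sum = 0.0;
//     for (size_t a = 0; a < M; a += 8) {          // segments of (A, B)
//       __m512i v_A = _mm512_loadu_epi64(A + a);
//       __m512d v_B = _mm512_loadu_pd(B + a);
//       for (size_t b = 0; b < N; b += 8) {        // segments of (C, D)
//         __m512i v_C = _mm512_loadu_epi64(C + b);
//         __m512d v_D = _mm512_loadu_pd(D + b);
//         sum += segment_dot(v_A, v_B, v_C, v_D);  // intersect + accumulate
//       }
//     }
//     return sum;
//   }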

// Optimized vector dot product implementation: Taking advantage of the fact
// that the indices in %m_A and %m_C are sorted in ascending order, skip over
// segments in (%m_C, %m_D) that are known to have no intersection with the
// current segment from (%m_A, %m_B).
func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                           %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                           %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = constant 0 : index
  %i0 = constant 0 : i32
  %i7 = constant 7 : i32
  %c8 = constant 8 : index

  %data_zero = constant 0.0 : f64
  %index_padding = constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.
  // %b_start is the index at which the inner for loop starts iterating. It
  // keeps increasing as earlier segments of (%m_C, %m_D) become known to be
  // no longer needed.

  %r0, %t0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>

    %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
        iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
      %seg1_done = cmpi "slt", %segB_max, %segA_min : i64

      %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
        // This %v_C segment is done: all of its indices are smaller than
        // %segA_min, so it can never match again. Advance the inner loop
        // start so it is never examined again.
        %next_b_start2 = addi %b_start1, %c8 : index
        scf.yield %sum1, %next_b_start2 : f64, index
      } else {
        %v_B = vector.transfer_read %m_B[%a], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
            -> f64
        %r3 = addf %sum1, %subresult : f64
        scf.yield %r3, %b_start1 : f64, index
      }

      scf.yield %r2, %next_b_start1 : f64, index
    }

    scf.yield %r1, %next_b_start0 : f64, index
  }

  return %r0 : f64
}
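
// A hedged C++ sketch of the same skipping scheme (not part of the test;
// reuses the hypothetical segment_dot helper from above). Because both index
// arrays are sorted, a (C, D) segment whose largest index is smaller than the
// smallest index of the current (A, B) segment can never match again:
//
//   double dot_optimized(const int64_t *A, const double *B, size_t M,
//                        const int64_t *C, const double *D, size_t N) {
//     double sum = 0.0;
//     size_t b_start = 0;
//     for (size_t a = 0; a < M; a += 8) {
//       __m512i v_A = _mm512_loadu_epi64(A + a);
//       __m512d v_B = _mm512_loadu_pd(B + a);
//       int64_t segA_min = A[a];            // lane 0 holds the smallest index
//       for (size_t b = b_start; b < N; b += 8) {
//         if (C[b + 7] < segA_min) {        // segment exhausted for good
//           b_start += 8;                   // never revisit it
//           continue;
//         }
//         __m512i v_C = _mm512_loadu_epi64(C + b);
//         __m512d v_D = _mm512_loadu_pd(D + b);
//         sum += segment_dot(v_A, v_B, v_C, v_D);
//       }
//     }
//     return sum;
//   }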

func @entry() -> i32 {
  // Initialize large buffers that can be used for multiple test cases of
  // different sizes.
  %b_A = alloc() : memref<128xi64>
  %b_B = alloc() : memref<128xf64>
  %b_C = alloc() : memref<128xi64>
  %b_D = alloc() : memref<128xf64>

  %m_A = memref_cast %b_A : memref<128xi64> to memref<?xi64>
  %m_B = memref_cast %b_B : memref<128xf64> to memref<?xf64>
  %m_C = memref_cast %b_C : memref<128xi64> to memref<?xi64>
  %m_D = memref_cast %b_D : memref<128xf64> to memref<?xf64>

  // --- Test case 1 ---
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M1, %N1 = call @fill_input_1(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
    -> (index, index)

  %r0 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
       index, index) -> f64
  vector.print %r0 : f64
  // CHECK: 86

  %r1 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
       index, index) -> f64
  vector.print %r1 : f64
  // CHECK: 86

  // --- Test case 2 ---
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M2, %N2 = call @fill_input_2(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
    -> (index, index)

  %r3 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
       index, index) -> f64
  vector.print %r3 : f64
  // CHECK: 111

  %r4 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
       index, index) -> f64
  vector.print %r4 : f64
  // CHECK: 111

  // Release all resources.
  dealloc %b_A : memref<128xi64>
  dealloc %b_B : memref<128xf64>
  dealloc %b_C : memref<128xi64>
  dealloc %b_D : memref<128xf64>

  %r = constant 0 : i32
  return %r : i32
}