Commit c40e0d7
[mlir][AVX512] Implement sparse vector dot product integration test.
The test operates on hardware-vector-sized segments of two sparse vectors and uses vp2intersect and mask.compress.

Differential Revision: https://reviews.llvm.org/D98099
1 parent 9773cad commit c40e0d7
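
For reference, the quantity the new test checks is an ordinary sparse dot product over sorted (index, value) pairs. A minimal scalar sketch of that computation (illustration only, not part of the commit; the function name is made up):

#include <cstddef>
#include <cstdint>
#include <vector>

// Sparse dot product: each vector is a sorted index array plus a value array;
// only positions whose index occurs in both vectors contribute a product.
double sparse_dot_ref(const std::vector<int64_t> &idx_a,
                      const std::vector<double> &val_a,
                      const std::vector<int64_t> &idx_c,
                      const std::vector<double> &val_c) {
  double sum = 0.0;
  std::size_t i = 0, j = 0;
  while (i < idx_a.size() && j < idx_c.size()) {
    if (idx_a[i] == idx_c[j])
      sum += val_a[i++] * val_c[j++];  // matching index: accumulate product
    else if (idx_a[i] < idx_c[j])
      ++i;                             // advance whichever index is smaller
    else
      ++j;
  }
  return sum;
}

The MLIR test below computes the same value, but eight elements at a time: vp2intersect finds the matching indices within a pair of segments, mask.compress packs the matching data elements, and vector.contract performs the dense dot product of the packed segments.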

1 file changed: 286 additions, 0 deletions
@@ -0,0 +1,286 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// This test shows how to implement a sparse vector-vector dot product with
// AVX512. It uses vp2intersect, mask.compress and vector.contract to compute
// the dot product of two sparse HW vectors of 8 float64 elements ("segment").
// Each sparse vector is represented by an index memref (A or C) and by a data
// memref (B or D), containing M or N elements.
//
// There are two implementations:
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
//   implementation, where the second for loop skips over some elements.

#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"]
}
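// With these maps and a single "reduction" iterator, the vector.contract
// below computes acc + sum_i(lhs(i) * rhs(i)), i.e., a dense 1-D dot product.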

// Sparse vector dot product of two vectors.
func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
                 %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
  // Compute intersection of indices.
  %k0, %k1 = avx512.vp2intersect %v_A, %v_C : vector<8xi64>

  // Filter out values without match and compress vector.
  %p0 = avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = avx512.mask.compress %k1, %v_D : vector<8xf64>

  // Dense vector dot product.
  %acc = std.constant 0.0 : f64
  %r = vector.contract #contraction_trait %p0, %p1, %acc
      : vector<8xf64>, vector<8xf64> into f64

  return %r : f64
}

// Fill the input memrefs with index padding and zero data, so that they can
// be used with arbitrary input sizes of up to 128 elements per sparse vector.
func @init_input(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                 %m_C : memref<?xi64>, %m_D : memref<?xf64>) {
  %c0 = constant 0 : index
  %v_data = constant dense<0.0> : vector<128xf64>
  %v_index = constant dense<9223372036854775807> : vector<128xi64>

  vector.transfer_write %v_index, %m_A[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_B[%c0] : vector<128xf64>, memref<?xf64>
  vector.transfer_write %v_index, %m_C[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_D[%c0] : vector<128xf64>, memref<?xf64>

  return
}

func @fill_input_1(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                   %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  call @init_input(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = constant 0 : index

  %v_A = std.constant dense<[0, 1, 10, 12, 13, 17, 18, 21,
                             51, 52, 57, 61, 62, 82, 98, 99]> : vector<16xi64>
  %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 0., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = std.constant dense<[1, 2, 5, 10, 11, 12, 47, 48,
                             67, 68, 69, 70, 71, 72, 77, 78,
                             79, 82, 83, 84, 85, 90, 91, 98]> : vector<24xi64>
  %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.,
                             2., 9., 8., 7., 2., 0., 0., 4.]> : vector<24xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<24xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<24xf64>, memref<?xf64>

  %M = std.constant 16 : index
  %N = std.constant 24 : index

  return %M, %N : index, index
}

func @fill_input_2(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                   %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  call @init_input(%m_A, %m_B, %m_C, %m_D)
    : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = constant 0 : index

  %v_A = std.constant dense<[0, 1, 3, 5, 6, 7, 8, 9,
                             51, 52, 57, 61, 62, 63, 65, 66]> : vector<16xi64>
  %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = std.constant dense<[6, 7, 11, 12, 15, 17, 19, 21,
                             30, 31, 33, 34, 37, 39, 40, 41,
                             42, 44, 45, 46, 47, 48, 49, 50,
                             62, 63, 64, 65, 66, 67, 68, 69,
                             70, 77, 78, 79, 81, 82, 89, 99]> : vector<40xi64>
  %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                             6., 7., 7., 3., 5., 2., 9., 1.,
                             2., 9., 8., 7., 2., 1., 2., 4.,
                             4., 5., 8., 8., 2., 3., 5., 1.,
                             8., 6., 6., 4., 3., 8., 9., 2.]> : vector<40xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<40xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<40xf64>, memref<?xf64>

  %M = std.constant 16 : index
  %N = std.constant 40 : index

  return %M, %N : index, index
}

// Simple vector dot product implementation: Intersect every segment of size 8
// in (%m_A, %m_B) with every segment of size 8 in (%m_C, %m_D).
func @memref_dot_simple(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                        %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = constant 0 : index
  %c8 = constant 8 : index

  %data_zero = constant 0.0 : f64
  %index_padding = constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.

  %r0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero) -> (f64) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a], %data_zero
        : memref<?xf64>, vector<8xf64>

    %r1 = scf.for %b = %c0 to %N step %c8
        iter_args(%sum1 = %sum0) -> (f64) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %v_D = vector.transfer_read %m_D[%b], %data_zero
          : memref<?xf64>, vector<8xf64>

      %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
          : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) -> f64
      %r2 = addf %sum1, %subresult : f64
      scf.yield %r2 : f64
    }

    scf.yield %r1 : f64
  }

  return %r0 : f64
}

// Optimized vector dot product implementation: Taking advantage of the fact
// that indices in %m_A and %m_C are sorted in ascending order, skip over
// segments in (%m_C, %m_D) that are known to have no intersection with the
// current segment from (%m_A, %m_B).
func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                           %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                           %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = constant 0 : index
  %i0 = constant 0 : i32
  %i7 = constant 7 : i32
  %c8 = constant 8 : index

  %data_zero = constant 0.0 : f64
  %index_padding = constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.
  // %b_start is the index from which the inner for loop starts iterating. This
  // value keeps increasing if earlier segments of (%m_C, %m_D) are known to
  // be no longer needed.

  %r0, %t0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>

    %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
        iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
      %seg1_done = cmpi "slt", %segB_max, %segA_min : i64

      %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
        // %v_C segment is done, no need to examine this one again (ever).
        %next_b_start2 = addi %b_start1, %c8 : index
        scf.yield %sum1, %next_b_start2 : f64, index
      } else {
        %v_B = vector.transfer_read %m_B[%a], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
            -> f64
        %r3 = addf %sum1, %subresult : f64
        scf.yield %r3, %b_start1 : f64, index
      }

      scf.yield %r2, %next_b_start1 : f64, index
    }

    scf.yield %r1, %next_b_start0 : f64, index
  }

  return %r0 : f64
}

func @entry() -> i32 {
  // Initialize large buffers that can be used for multiple test cases of
  // different sizes.
  %b_A = alloc() : memref<128xi64>
  %b_B = alloc() : memref<128xf64>
  %b_C = alloc() : memref<128xi64>
  %b_D = alloc() : memref<128xf64>

  %m_A = memref_cast %b_A : memref<128xi64> to memref<?xi64>
  %m_B = memref_cast %b_B : memref<128xf64> to memref<?xf64>
  %m_C = memref_cast %b_C : memref<128xi64> to memref<?xi64>
  %m_D = memref_cast %b_D : memref<128xf64> to memref<?xf64>

  // --- Test case 1 ---.
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M1, %N1 = call @fill_input_1(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
      -> (index, index)

  %r0 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r0 : f64
  // CHECK: 86
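  // (The indices shared by A and C in test case 1 are 1, 10, 12, 82 and 98;
  // the matching B*D products sum to 5 + 24 + 3 + 18 + 36 = 86.)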

  %r1 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r1 : f64
  // CHECK: 86

  // --- Test case 2 ---.
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M2, %N2 = call @fill_input_2(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
      -> (index, index)

  %r3 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r3 : f64
  // CHECK: 111
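  // (The indices shared by A and C in test case 2 are 6, 7, 62, 63, 65 and 66;
  // the matching B*D products sum to 2 + 5 + 20 + 10 + 72 + 2 = 111.)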

  %r4 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r4 : f64
  // CHECK: 111

  // Release all resources.
  dealloc %b_A : memref<128xi64>
  dealloc %b_B : memref<128xf64>
  dealloc %b_C : memref<128xi64>
  dealloc %b_D : memref<128xf64>

  %r = constant 0 : i32
  return %r : i32
}
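
For readers mapping the ops above to hardware: a rough C++ intrinsics sketch of what a single @vector_dot segment corresponds to on an AVX512-VP2INTERSECT target (an illustration under that assumption, not the code the conversion passes actually emit):

#include <immintrin.h>  // build with e.g. -mavx512f -mavx512vp2intersect

// Dot product of one pair of 8-element segments: index vectors a and c are
// padded with INT64_MAX, data vectors b and d are padded with 0.0.
double segment_dot(__m512i a, __m512d b, __m512i c, __m512d d) {
  __mmask8 k0, k1;
  // k0[i] is set iff a[i] also occurs in c; k1[j] iff c[j] also occurs in a.
  _mm512_2intersect_epi64(a, c, &k0, &k1);
  // Pack the matching data elements to the front; unselected lanes become 0.0
  // and therefore do not affect the sum.
  __m512d p0 = _mm512_maskz_compress_pd(k0, b);
  __m512d p1 = _mm512_maskz_compress_pd(k1, d);
  // Dense dot product of the compressed segments.
  return _mm512_reduce_add_pd(_mm512_mul_pd(p0, p1));
}

Because both index lists are sorted, matching indices appear in the same relative order in both segments, so the compressed values line up element-wise; the zero data padding guarantees that the INT64_MAX index padding, which always intersects with itself, contributes nothing to the sum.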
