16
16
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
17
17
// RUN: | FileCheck %s
18
18
19
+
20
+ // External functions for utilities
21
+ // printMemrefF32: Prints contents of a float32 memref buffer
22
+ // rtclock: Returns current time for performance measurement
19
23
func.func private @printMemrefF32 (memref <*xf32 >)
20
24
func.func private @rtclock () -> f64
21
25
26
+ // Main SpMM computation kernel
27
+ // Parameters:
28
+ // - values: Non-zero values of sparse matrix A in CSR format
29
+ // - col_indices: Column indices for each non-zero in A
30
+ // - row_pointers: CSR row offsets; row_pointers[i] and row_pointers[i+1] bound the range of row i's non-zeros in values/col_indices
31
+ // - dense: Dense input matrix B
32
+ // - result: Output dense matrix C = A * B
22
33
func.func @spmm (%values: memref <?xf32 >, %col_indices: memref <?xi32 >,
23
34
%row_pointers: memref <?xi32 >, %dense: memref <?x?xf32 >,
24
35
%result: memref <?x?xf32 >) {
25
36
%c0 = arith.constant 0 : index
26
37
%c1 = arith.constant 1 : index
27
38
39
+ // Get dimensions for iteration bounds
40
+ // num_rows: Number of rows in sparse matrix A and result C
41
+ // num_cols: Number of columns in result C
42
+ // dense_cols: Number of columns in dense matrix B
28
43
%num_rows = memref.dim %result , %c0 : memref <?x?xf32 >
29
44
%num_cols = memref.dim %result , %c1 : memref <?x?xf32 >
30
45
%dense_cols = memref.dim %dense , %c1 : memref <?x?xf32 >
31
46
47
+ // Start timing the computation
32
48
%t_start = call @rtclock () : () -> f64
33
49
50
+ // Main computation loops:
51
+ // 1. Outer loop: Iterate over each row of sparse matrix A
52
+ // 2. Middle loop: Iterate over each column of dense matrix B
53
+ // 3. Inner loop: Process non-zeros in current row of A
34
54
scf.for %i = %c0 to %num_rows step %c1 {
55
+ // Get start and end indices for current row in CSR format
56
+ // These indices mark the range of non-zeros in the current row
35
57
%row_start_ptr = memref.load %row_pointers [%i ] : memref <?xi32 >
36
58
%row_start = arith.index_cast %row_start_ptr : i32 to index
37
59
38
60
%i_plus_1 = arith.addi %i , %c1 : index
39
61
%row_end_ptr = memref.load %row_pointers [%i_plus_1 ] : memref <?xi32 >
40
62
%row_end = arith.index_cast %row_end_ptr : i32 to index
41
63
64
+ // Process each column in result matrix C
42
65
scf.for %j = %c0 to %dense_cols step %c1 {
66
+ // Initialize accumulator for dot product computation
43
67
%sum = arith.constant 0.0 : f32
44
68
69
+ // Compute dot product of current row of A with column j of B
45
70
%result_sum = scf.for %k = %row_start to %row_end step %c1 iter_args (%current_sum = %sum ) -> (f32 ) {
71
+ // Load non-zero value from A and its column index
46
72
%val = memref.load %values [%k ] : memref <?xf32 >
47
73
%col_ptr = memref.load %col_indices [%k ] : memref <?xi32 >
48
74
%col = arith.index_cast %col_ptr : i32 to index
49
75
76
+ // Load corresponding value from dense matrix B
50
77
%dense_val = memref.load %dense [%col , %j ] : memref <?x?xf32 >
51
78
79
+ // Multiply and accumulate into partial sum
52
80
%prod = arith.mulf %val , %dense_val : f32
53
81
%new_sum = arith.addf %current_sum , %prod : f32
54
82
55
83
scf.yield %new_sum : f32
56
84
}
57
85
86
+ // Store computed result in output matrix C
58
87
memref.store %result_sum , %result [%i , %j ] : memref <?x?xf32 >
59
88
}
60
89
}
@@ -71,6 +100,8 @@ func.func @spmm(%values: memref<?xf32>, %col_indices: memref<?xi32>,
71
100
return
72
101
}
73
102
103
+ // Helper function to allocate and initialize values array
104
+ // Allocates memory for given size and fills with specified value
74
105
func.func @alloc_values (%size: index , %val: f32 ) -> memref <?xf32 > {
75
106
%c0 = arith.constant 0 : index
76
107
%c1 = arith.constant 1 : index
@@ -83,6 +114,8 @@ func.func @alloc_values(%size: index, %val: f32) -> memref<?xf32> {
83
114
return %values : memref <?xf32 >
84
115
}
85
116
117
+ // Helper function to allocate and initialize column indices array
118
+ // Creates cyclic pattern of indices modulo pattern size
86
119
func.func @alloc_col_indices (%size: index , %pattern: index ) -> memref <?xi32 > {
87
120
%c0 = arith.constant 0 : index
88
121
%c1 = arith.constant 1 : index
@@ -98,6 +131,8 @@ func.func @alloc_col_indices(%size: index, %pattern: index) -> memref<?xi32> {
98
131
return %indices : memref <?xi32 >
99
132
}
100
133
134
+ // Helper function to allocate and initialize row pointers array
135
+ // Creates regular pattern with fixed number of non-zeros per row
101
136
func.func @alloc_row_pointers (%rows: index , %nnz_per_row: index ) -> memref <?xi32 > {
102
137
%c0 = arith.constant 0 : index
103
138
%c1 = arith.constant 1 : index
@@ -113,6 +148,8 @@ func.func @alloc_row_pointers(%rows: index, %nnz_per_row: index) -> memref<?xi32
113
148
return %pointers : memref <?xi32 >
114
149
}
115
150
151
+ // Helper function to allocate and initialize 2D float array
152
+ // Allocates memory for given dimensions and fills with specified value
116
153
func.func @alloc_f32 (%arg0: index , %arg1: index , %arg2: f32 ) -> memref <?x?xf32 > {
117
154
%c0 = arith.constant 0 : index
118
155
%c1 = arith.constant 1 : index
@@ -127,6 +164,11 @@ func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: f32) -> memref<?x?xf32>
127
164
return %0 : memref <?x?xf32 >
128
165
}
129
166
167
+ // Main function that sets up test case and runs SpMM
168
+ // Creates a test case with:
169
+ // - 4x4 sparse matrix with 2 non-zeros per row (values = 2.0)
170
+ // - 4x3 dense matrix (values = 3.0)
171
+ // - 4x3 result matrix (initialized to 0.0)
130
172
func.func @main () {
131
173
%c3 = arith.constant 3 : index
132
174
%c4 = arith.constant 4 : index
@@ -135,19 +177,27 @@ func.func @main() {
135
177
%f2 = arith.constant 2.0 : f32
136
178
%f3 = arith.constant 3.0 : f32
137
179
138
- //test:
180
+ // Test parameters:
181
+ // nnz: Total number of non-zeros in sparse matrix (8)
182
+ // nnz_per_row: Number of non-zeros per row (2)
139
183
%nnz = arith.constant 8 : index
140
184
%nnz_per_row = arith.constant 2 : index
141
185
186
+ // Allocate and initialize sparse matrix components
142
187
%values = call @alloc_values (%nnz , %f2 ) : (index , f32 ) -> memref <?xf32 >
143
188
%col_indices = call @alloc_col_indices (%nnz , %c4 ) : (index , index ) -> memref <?xi32 >
144
189
%row_pointers = call @alloc_row_pointers (%c4 , %nnz_per_row ) : (index , index ) -> memref <?xi32 >
145
190
146
- //4x3
191
+ // Allocate and initialize dense input and result matrices
147
192
%dense = call @alloc_f32 (%c4 , %c3 , %f3 ) : (index , index , f32 ) -> memref <?x?xf32 >
148
193
// 4x3 result matrix, filled with %f0
149
194
%result = call @alloc_f32 (%c4 , %c3 , %f0 ) : (index , index , f32 ) -> memref <?x?xf32 >
150
195
196
+ // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 3] strides = [3, 1] data =
197
+ // CHECK-NEXT: [[12, 12, 12],
198
+ // CHECK-NEXT: [12, 12, 12],
199
+ // CHECK-NEXT: [12, 12, 12],
200
+ // CHECK-NEXT: [12, 12, 12]]
151
201
call @spmm (%values , %col_indices , %row_pointers , %dense , %result ) :
152
202
(memref <?xf32 >, memref <?xi32 >, memref <?xi32 >, memref <?x?xf32 >, memref <?x?xf32 >) -> ()
153
203
0 commit comments