Skip to content

Commit f67e15a

Browse files
committed
update: example/BuddySparse, add more comments, verify the correctness with FileCheck
1 parent d850e9e commit f67e15a

File tree

4 files changed

+58
-40
lines changed

4 files changed

+58
-40
lines changed

examples/BuddyMatmul/linalg-batchmatmul-f32.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ func.func @main(){
7777
%m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
7878
%m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>
7979

80-
// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data =
81-
// CHECK-NEXT: [
82-
// CHECK: [
83-
// CHECK: [6144{{(, 6144)*}}]
80+
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 3] strides = [3, 1] data =
81+
// CHECK-NEXT: [[12, 12, 12],
82+
// CHECK-NEXT: [12, 12, 12],
83+
// CHECK-NEXT: [12, 12, 12],
84+
// CHECK-NEXT: [12, 12, 12]]
8485
call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
8586

8687
return

examples/BuddySpMM/verify_spmm.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

examples/BuddySpMM/linalg-spmm-f32.mlir renamed to examples/BuddySparse/linalg-spmm-f32.mlir

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,74 @@
1616
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
1717
// RUN: | FileCheck %s
1818

19+
20+
// External functions for utilities
21+
// printMemrefF32: Prints contents of a float32 memref buffer
22+
// rtclock: Returns current time for performance measurement
1923
func.func private @printMemrefF32(memref<*xf32>)
2024
func.func private @rtclock() -> f64
2125

26+
// Main SpMM computation kernel
27+
// Parameters:
28+
// - values: Non-zero values of sparse matrix A in CSR format
29+
// - col_indices: Column indices for each non-zero in A
30+
// - row_pointers: CSR row offsets; the non-zeros of row i occupy positions [row_pointers[i], row_pointers[i+1]) in values and col_indices
31+
// - dense: Dense input matrix B
32+
// - result: Output dense matrix C = A * B
2233
func.func @spmm(%values: memref<?xf32>, %col_indices: memref<?xi32>,
2334
%row_pointers: memref<?xi32>, %dense: memref<?x?xf32>,
2435
%result: memref<?x?xf32>) {
2536
%c0 = arith.constant 0 : index
2637
%c1 = arith.constant 1 : index
2738

39+
// Get dimensions for iteration bounds
40+
// num_rows: Number of rows in sparse matrix A and result C
41+
// num_cols: Number of columns in result C
42+
// dense_cols: Number of columns in dense matrix B
2843
%num_rows = memref.dim %result, %c0 : memref<?x?xf32>
2944
%num_cols = memref.dim %result, %c1 : memref<?x?xf32>
3045
%dense_cols = memref.dim %dense, %c1 : memref<?x?xf32>
3146

47+
// Start timing the computation
3248
%t_start = call @rtclock() : () -> f64
3349

50+
// Main computation loops:
51+
// 1. Outer loop: Iterate over each row of sparse matrix A
52+
// 2. Middle loop: Iterate over each column of dense matrix B
53+
// 3. Inner loop: Process non-zeros in current row of A
3454
scf.for %i = %c0 to %num_rows step %c1 {
55+
// Get start and end indices for current row in CSR format
56+
// These indices mark the range of non-zeros in the current row
3557
%row_start_ptr = memref.load %row_pointers[%i] : memref<?xi32>
3658
%row_start = arith.index_cast %row_start_ptr : i32 to index
3759

3860
%i_plus_1 = arith.addi %i, %c1 : index
3961
%row_end_ptr = memref.load %row_pointers[%i_plus_1] : memref<?xi32>
4062
%row_end = arith.index_cast %row_end_ptr : i32 to index
4163

64+
// Process each column in result matrix C
4265
scf.for %j = %c0 to %dense_cols step %c1 {
66+
// Initialize accumulator for dot product computation
4367
%sum = arith.constant 0.0 : f32
4468

69+
// Compute dot product of current row of A with column j of B
4570
%result_sum = scf.for %k = %row_start to %row_end step %c1 iter_args(%current_sum = %sum) -> (f32) {
71+
// Load non-zero value from A and its column index
4672
%val = memref.load %values[%k] : memref<?xf32>
4773
%col_ptr = memref.load %col_indices[%k] : memref<?xi32>
4874
%col = arith.index_cast %col_ptr : i32 to index
4975

76+
// Load corresponding value from dense matrix B
5077
%dense_val = memref.load %dense[%col, %j] : memref<?x?xf32>
5178

79+
// Multiply and accumulate into partial sum
5280
%prod = arith.mulf %val, %dense_val : f32
5381
%new_sum = arith.addf %current_sum, %prod : f32
5482

5583
scf.yield %new_sum : f32
5684
}
5785

86+
// Store computed result in output matrix C
5887
memref.store %result_sum, %result[%i, %j] : memref<?x?xf32>
5988
}
6089
}
@@ -71,6 +100,8 @@ func.func @spmm(%values: memref<?xf32>, %col_indices: memref<?xi32>,
71100
return
72101
}
73102

103+
// Helper function to allocate and initialize values array
104+
// Allocates memory for given size and fills with specified value
74105
func.func @alloc_values(%size: index, %val: f32) -> memref<?xf32> {
75106
%c0 = arith.constant 0 : index
76107
%c1 = arith.constant 1 : index
@@ -83,6 +114,8 @@ func.func @alloc_values(%size: index, %val: f32) -> memref<?xf32> {
83114
return %values : memref<?xf32>
84115
}
85116

117+
// Helper function to allocate and initialize column indices array
118+
// Creates cyclic pattern of indices modulo pattern size
86119
func.func @alloc_col_indices(%size: index, %pattern: index) -> memref<?xi32> {
87120
%c0 = arith.constant 0 : index
88121
%c1 = arith.constant 1 : index
@@ -98,6 +131,8 @@ func.func @alloc_col_indices(%size: index, %pattern: index) -> memref<?xi32> {
98131
return %indices : memref<?xi32>
99132
}
100133

134+
// Helper function to allocate and initialize row pointers array
135+
// Creates regular pattern with fixed number of non-zeros per row
101136
func.func @alloc_row_pointers(%rows: index, %nnz_per_row: index) -> memref<?xi32> {
102137
%c0 = arith.constant 0 : index
103138
%c1 = arith.constant 1 : index
@@ -113,6 +148,8 @@ func.func @alloc_row_pointers(%rows: index, %nnz_per_row: index) -> memref<?xi32
113148
return %pointers : memref<?xi32>
114149
}
115150

151+
// Helper function to allocate and initialize 2D float array
152+
// Allocates memory for given dimensions and fills with specified value
116153
func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: f32) -> memref<?x?xf32> {
117154
%c0 = arith.constant 0 : index
118155
%c1 = arith.constant 1 : index
@@ -127,6 +164,11 @@ func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: f32) -> memref<?x?xf32>
127164
return %0 : memref<?x?xf32>
128165
}
129166

167+
// Main function that sets up test case and runs SpMM
168+
// Creates a test case with:
169+
// - 4x4 sparse matrix with 2 non-zeros per row (values = 2.0)
170+
// - 4x3 dense matrix (values = 3.0)
171+
// - 4x3 result matrix (initialized to 0.0)
130172
func.func @main() {
131173
%c3 = arith.constant 3 : index
132174
%c4 = arith.constant 4 : index
@@ -135,19 +177,27 @@ func.func @main() {
135177
%f2 = arith.constant 2.0 : f32
136178
%f3 = arith.constant 3.0 : f32
137179

138-
//test:
180+
// Test parameters:
181+
// nnz: Total number of non-zeros in sparse matrix (8)
182+
// nnz_per_row: Number of non-zeros per row (2)
139183
%nnz = arith.constant 8 : index
140184
%nnz_per_row = arith.constant 2 : index
141185

186+
// Allocate and initialize sparse matrix components
142187
%values = call @alloc_values(%nnz, %f2) : (index, f32) -> memref<?xf32>
143188
%col_indices = call @alloc_col_indices(%nnz, %c4) : (index, index) -> memref<?xi32>
144189
%row_pointers = call @alloc_row_pointers(%c4, %nnz_per_row) : (index, index) -> memref<?xi32>
145190

146-
//4x3
191+
// Allocate and initialize dense input and result matrices
147192
%dense = call @alloc_f32(%c4, %c3, %f3) : (index, index, f32) -> memref<?x?xf32>
148193
// 4x3 result matrix, zero-initialized
149194
%result = call @alloc_f32(%c4, %c3, %f0) : (index, index, f32) -> memref<?x?xf32>
150195

196+
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 3] strides = [3, 1] data =
197+
// CHECK-NEXT: [[12, 12, 12],
198+
// CHECK-NEXT: [12, 12, 12],
199+
// CHECK-NEXT: [12, 12, 12],
200+
// CHECK-NEXT: [12, 12, 12]]
151201
call @spmm(%values, %col_indices, %row_pointers, %dense, %result) :
152202
(memref<?xf32>, memref<?xi32>, memref<?xi32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
153203

examples/BuddySpMM/makefile renamed to examples/BuddySparse/makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ linalg-spmm-f32-run:
3434
-finalize-memref-to-llvm \
3535
-reconcile-unrealized-casts | \
3636
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
37-
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
37+
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

0 commit comments

Comments
 (0)