Skip to content

Commit f37cc7b

Browse files
fix: added shared memory space for matrix multiplication calculation
1 parent 23f8732 commit f37cc7b

File tree

1 file changed

+52
-0
lines changed
  • samples/introduction/matmul/kernels/src

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use cuda_std::*;
2+
use core::mem::MaybeUninit;
3+
4+
// SAFETY: This function is unsafe because it dereferences raw pointers.
5+
#[kernel]
6+
pub unsafe fn matrix_mul_cuda(C: *mut f32, A: *const f32, B: *const f32, wa: usize, wb: usize) {
7+
let bx: usize = cuda_std::thread::block_idx().x as usize;
8+
let by: usize = cuda_std::thread::block_idx().y as usize;
9+
10+
let tx: usize = cuda_std::thread::thread_idx().x as usize;
11+
let ty: usize = cuda_std::thread::thread_idx().y as usize;
12+
13+
const BLOCK_SIZE: usize = 32;
14+
let a_begin = wa * BLOCK_SIZE * by;
15+
let a_end = a_begin + wa - 1;
16+
let a_step = BLOCK_SIZE;
17+
18+
let b_begin = BLOCK_SIZE * bx;
19+
let b_step = BLOCK_SIZE * wb;
20+
21+
let mut c_sub: f32 = 0.0;
22+
let mut b = b_begin;
23+
24+
for a in (a_begin..=a_end).step_by(a_step) {
25+
#[address_space(shared)]
26+
static mut As: [[MaybeUninit<f32>; BLOCK_SIZE]; BLOCK_SIZE] = [[const { MaybeUninit::uninit() }; BLOCK_SIZE]; BLOCK_SIZE];
27+
#[address_space(shared)]
28+
static mut Bs: [[MaybeUninit<f32>; BLOCK_SIZE]; BLOCK_SIZE] = [[const { MaybeUninit::uninit() }; BLOCK_SIZE]; BLOCK_SIZE];
29+
30+
// Load A and B matrices into shared memory
31+
unsafe {
32+
As[ty][tx].write(*A.add((a + wa * ty + tx) as usize));
33+
Bs[ty][tx].write(*B.add((b + wb * ty + tx) as usize));
34+
}
35+
36+
// Synchronize to make sure the matrices are loaded
37+
cuda_std::thread::sync_threads();
38+
for k in 0..BLOCK_SIZE {
39+
unsafe {
40+
c_sub += As[ty][k].assume_init() * Bs[k][tx].assume_init();
41+
}
42+
}
43+
44+
// Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration
45+
cuda_std::thread::sync_threads();
46+
47+
b += b_step;
48+
}
49+
50+
let c = wb * BLOCK_SIZE * by + BLOCK_SIZE * bx;
51+
unsafe { *C.add((c + wb * ty + tx) as usize) = c_sub; }
52+
}

0 commit comments

Comments
 (0)