1+ use cuda_std:: * ;
2+ use core:: mem:: MaybeUninit ;
3+
4+ // SAFETY: This function is unsafe because it dereferences raw pointers.
5+ #[ kernel]
6+ pub unsafe fn matrix_mul_cuda ( C : * mut f32 , A : * const f32 , B : * const f32 , wa : usize , wb : usize ) {
7+ let bx: usize = cuda_std:: thread:: block_idx ( ) . x as usize ;
8+ let by: usize = cuda_std:: thread:: block_idx ( ) . y as usize ;
9+
10+ let tx: usize = cuda_std:: thread:: thread_idx ( ) . x as usize ;
11+ let ty: usize = cuda_std:: thread:: thread_idx ( ) . y as usize ;
12+
13+ const BLOCK_SIZE : usize = 32 ;
14+ let a_begin = wa * BLOCK_SIZE * by;
15+ let a_end = a_begin + wa - 1 ;
16+ let a_step = BLOCK_SIZE ;
17+
18+ let b_begin = BLOCK_SIZE * bx;
19+ let b_step = BLOCK_SIZE * wb;
20+
21+ let mut c_sub: f32 = 0.0 ;
22+ let mut b = b_begin;
23+
24+ for a in ( a_begin..=a_end) . step_by ( a_step) {
25+ #[ address_space( shared) ]
26+ static mut As : [ [ MaybeUninit < f32 > ; BLOCK_SIZE ] ; BLOCK_SIZE ] = [ [ const { MaybeUninit :: uninit ( ) } ; BLOCK_SIZE ] ; BLOCK_SIZE ] ;
27+ #[ address_space( shared) ]
28+ static mut Bs : [ [ MaybeUninit < f32 > ; BLOCK_SIZE ] ; BLOCK_SIZE ] = [ [ const { MaybeUninit :: uninit ( ) } ; BLOCK_SIZE ] ; BLOCK_SIZE ] ;
29+
30+ // Load A and B matrices into shared memory
31+ unsafe {
32+ As [ ty] [ tx] . write ( * A . add ( ( a + wa * ty + tx) as usize ) ) ;
33+ Bs [ ty] [ tx] . write ( * B . add ( ( b + wb * ty + tx) as usize ) ) ;
34+ }
35+
36+ // Synchronize to make sure the matrices are loaded
37+ cuda_std:: thread:: sync_threads ( ) ;
38+ for k in 0 ..BLOCK_SIZE {
39+ unsafe {
40+ c_sub += As [ ty] [ k] . assume_init ( ) * Bs [ k] [ tx] . assume_init ( ) ;
41+ }
42+ }
43+
44+ // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration
45+ cuda_std:: thread:: sync_threads ( ) ;
46+
47+ b += b_step;
48+ }
49+
50+ let c = wb * BLOCK_SIZE * by + BLOCK_SIZE * bx;
51+ unsafe { * C . add ( ( c + wb * ty + tx) as usize ) = c_sub; }
52+ }
0 commit comments