@@ -36,10 +36,10 @@ fn matrix_multiply(
3636 DeviceBuffer :: from_slice ( h_c. as_slice ( ) ) . expect ( "device array couldn't be initialized!" ) ;
3737
3838 stream. synchronize ( ) . expect ( "Stream couldn't synchronize!" ) ;
39- let threads = BlockSize :: xy ( block_size as u32 , block_size as u32 ) ;
39+ let blocks = BlockSize :: xy ( block_size as u32 , block_size as u32 ) ;
4040 let grid = GridSize :: xy (
41- ( dims_b. 0 / ( threads . x as usize ) ) . try_into ( ) . unwrap ( ) ,
42- ( dims_a. 1 / ( threads . y as usize ) ) . try_into ( ) . unwrap ( ) ,
41+ ( dims_b. 0 / ( blocks . x as usize ) ) . try_into ( ) . unwrap ( ) ,
42+ ( dims_a. 1 / ( blocks . y as usize ) ) . try_into ( ) . unwrap ( ) ,
4343 ) ;
4444
4545 println ! ( "Computing result using CUDA Kernel..." ) ;
@@ -50,7 +50,7 @@ fn matrix_multiply(
5050 . expect ( "Kernel function not found!" ) ;
5151
5252 unsafe {
53- launch ! ( matrix_mul_cuda<<<grid, threads , 0 , stream>>>(
53+ launch ! ( matrix_mul_cuda<<<grid, blocks , 0 , stream>>>(
5454 d_c. as_device_ptr( ) ,
5555 d_a. as_device_ptr( ) ,
5656 d_b. as_device_ptr( ) ,
@@ -70,7 +70,7 @@ fn matrix_multiply(
7070
7171 for _ in 0 ..N_ITER {
7272 unsafe {
73- launch ! ( matrix_mul_cuda<<<grid, threads , 0 , stream>>>(
73+ launch ! ( matrix_mul_cuda<<<grid, blocks , 0 , stream>>>(
7474 d_c. as_device_ptr( ) ,
7575 d_a. as_device_ptr( ) ,
7676 d_b. as_device_ptr( ) ,
@@ -152,8 +152,8 @@ fn main() -> Result<(), cust::error::CudaError> {
152152 println ! ( "Device Name: {}" , device. name( ) . unwrap( ) ) ;
153153
154154 let block_size: u32 = 32 ;
155- let dims_a: ( usize , usize , usize ) = ( block_size as usize , block_size as usize , 1 ) ;
156- let dims_b: ( usize , usize , usize ) = ( block_size as usize , block_size as usize , 1 ) ;
155+ let dims_a: ( usize , usize , usize ) = ( 40 * block_size as usize , 40 * block_size as usize , 1 ) ;
156+ let dims_b: ( usize , usize , usize ) = ( 80 * block_size as usize , 40 * block_size as usize , 1 ) ;
157157
158158 if dims_a. 0 != dims_b. 1 {
159159 panic ! ( "Matrix multiplication not possible with the given dimensions!" ) ;
0 commit comments