
Commit be37567

feat: increased the dimensions of the matrices being computed, and implemented Kahan's error correction to stop floating-point accumulation errors
1 parent 360c227 commit be37567

2 files changed: +30 -8 lines changed

samples/introduction/matmul/kernels/src/lib.rs

Lines changed: 23 additions & 1 deletion
@@ -19,15 +19,24 @@ pub unsafe fn matrix_mul_cuda(C: *mut f32, A: *const f32, B: *const f32, wa: usi
     let b_step = BLOCK_SIZE * wb;
 
     let mut c_sub: f32 = 0.0;
+    let mut kahan_correction_factor = 0.0f32;
     let mut b = b_begin;
 
     for a in (a_begin..=a_end).step_by(a_step) {
+        // The equivalent CUDA C++ code for the below is:
+        // ```
+        // __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
+        // ```
+        // This memory space is shared between threads of the same block
         #[address_space(shared)]
         static mut As: [[MaybeUninit<f32>; BLOCK_SIZE]; BLOCK_SIZE] = [[const { MaybeUninit::uninit() }; BLOCK_SIZE]; BLOCK_SIZE];
+
         #[address_space(shared)]
         static mut Bs: [[MaybeUninit<f32>; BLOCK_SIZE]; BLOCK_SIZE] = [[const { MaybeUninit::uninit() }; BLOCK_SIZE]; BLOCK_SIZE];
 
         // Load A and B matrices into shared memory
+        // A.add(index) returns the pointer to the index-th element of A,
+        // hence a dereference is needed to get the value at that index
         unsafe {
             As[ty][tx].write(*A.add((a + wa * ty + tx) as usize));
             Bs[ty][tx].write(*B.add((b + wb * ty + tx) as usize));
@@ -36,8 +45,21 @@ pub unsafe fn matrix_mul_cuda(C: *mut f32, A: *const f32, B: *const f32, wa: usi
         // Synchronize to make sure the matrices are loaded
         cuda_std::thread::sync_threads();
         for k in 0..BLOCK_SIZE {
+            // Typically, this would be a simple calculation:
+            // ```
+            // c_sub += As[ty][k] * Bs[k][tx];
+            // ```
+            // However, to improve numerical stability we use Kahan summation here, so that the error
+            // is isolated and not allowed to accumulate in c_sub
             unsafe {
-                c_sub += As[ty][k].assume_init() * Bs[k][tx].assume_init();
+                let input = As[ty][k].assume_init() * Bs[k][tx].assume_init();
+                let y = input - kahan_correction_factor;
+                let t = c_sub + y;
+
+                // It may seem like this correction factor must be zero; however, due to f32 precision
+                // limitations it captures the small error that would otherwise accumulate in c_sub
+                kahan_correction_factor = (t - c_sub) - y;
+                c_sub = t;
             }
         }
 
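For readers unfamiliar with the technique, the compensated (Kahan) summation added to the kernel can be sketched in plain host-side Rust. This is an illustrative sketch only, not part of the commit; the function name `kahan_sum` and the test values are made up for the example, but the update steps mirror the `y`/`t`/`kahan_correction_factor` logic above:

```rust
/// Minimal sketch of Kahan (compensated) summation over a slice of f32,
/// mirroring the per-element update the kernel performs on `c_sub`.
fn kahan_sum(values: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    let mut correction = 0.0f32; // running compensation for lost low-order bits

    for &input in values {
        let y = input - correction; // re-apply the error captured last iteration
        let t = sum + y;            // low-order bits of y may be lost here
        correction = (t - sum) - y; // recover what was just lost
        sum = t;
    }
    sum
}

fn main() {
    // Summing many tiny values alongside a large one shows the benefit:
    // each 1e-8 addend is below the rounding threshold of 1.0 in f32.
    let mut values = vec![1.0f32];
    values.extend(std::iter::repeat(1e-8f32).take(10_000));

    let naive: f32 = values.iter().sum();
    println!("naive: {naive}, kahan: {}", kahan_sum(&values));
}
```

Here the naive f32 sum stays at 1.0 because every individual addition rounds away, while the compensated sum recovers a value close to 1.0001.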

samples/introduction/matmul/src/main.rs

Lines changed: 7 additions & 7 deletions
@@ -36,10 +36,10 @@ fn matrix_multiply(
         DeviceBuffer::from_slice(h_c.as_slice()).expect("device array couldn't be initialized!");
 
     stream.synchronize().expect("Stream couldn't synchronize!");
-    let threads = BlockSize::xy(block_size as u32, block_size as u32);
+    let blocks = BlockSize::xy(block_size as u32, block_size as u32);
     let grid = GridSize::xy(
-        (dims_b.0 / (threads.x as usize)).try_into().unwrap(),
-        (dims_a.1 / (threads.y as usize)).try_into().unwrap(),
+        (dims_b.0 / (blocks.x as usize)).try_into().unwrap(),
+        (dims_a.1 / (blocks.y as usize)).try_into().unwrap(),
     );
 
     println!("Computing result using CUDA Kernel...");
@@ -50,7 +50,7 @@ fn matrix_multiply(
         .expect("Kernel function not found!");
 
     unsafe {
-        launch!(matrix_mul_cuda<<<grid, threads, 0, stream>>>(
+        launch!(matrix_mul_cuda<<<grid, blocks, 0, stream>>>(
            d_c.as_device_ptr(),
            d_a.as_device_ptr(),
            d_b.as_device_ptr(),
@@ -70,7 +70,7 @@ fn matrix_multiply(
 
     for _ in 0..N_ITER {
        unsafe {
-            launch!(matrix_mul_cuda<<<grid, threads, 0, stream>>>(
+            launch!(matrix_mul_cuda<<<grid, blocks, 0, stream>>>(
               d_c.as_device_ptr(),
               d_a.as_device_ptr(),
               d_b.as_device_ptr(),
@@ -152,8 +152,8 @@ fn main() -> Result<(), cust::error::CudaError> {
     println!("Device Name: {}", device.name().unwrap());
 
     let block_size: u32 = 32;
-    let dims_a: (usize, usize, usize) = (block_size as usize, block_size as usize, 1);
-    let dims_b: (usize, usize, usize) = (block_size as usize, block_size as usize, 1);
+    let dims_a: (usize, usize, usize) = (40 * block_size as usize, 40 * block_size as usize, 1);
+    let dims_b: (usize, usize, usize) = (80 * block_size as usize, 40 * block_size as usize, 1);
 
     if dims_a.0 != dims_b.1 {
         panic!("Matrix multiplication not possible with the given dimensions!");
