Skip to content

Commit 290d711

Browse files
adamcavendish and jorge-ortega
authored and committed
chore(examples): restructure CUDA examples and add a GEMM example
- Refactored the CUDA examples directory for improved clarity and modularity, so each example is more self-contained.
- Added a new GEMM (General Matrix Multiply) example, including naive and tiled kernel implementations, build scripts, and benchmarks.
- The tiled-gemm kernel demonstrates shared memory usage.
1 parent 81aa642 commit 290d711

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+733
-86
lines changed

.github/workflows/ci_linux.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,18 @@ jobs:
7373
- name: Clippy
7474
env:
7575
RUSTFLAGS: -Dwarnings
76-
run: cargo clippy --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*"
76+
run: cargo clippy --workspace --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*"
7777

7878
- name: Build all bindings
7979
run: cargo build --all-features -p cust_raw
8080

8181
- name: Build workspace
82-
run: cargo build --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*"
82+
run: cargo build --workspace --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*"
8383

8484
- name: Check documentation
8585
env:
8686
RUSTDOCFLAGS: -Dwarnings
87-
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"
87+
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"
8888

8989
- name: Prepare artifact details
9090
id: artifact_details

.github/workflows/ci_windows.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
run: cargo build --all-features -p cust_raw
6767

6868
- name: Build
69-
run: cargo build --workspace --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*"
69+
run: cargo build --workspace --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*"
7070

7171
# Don't currently test because many tests rely on the system having a CUDA GPU
7272
# - name: Test
@@ -75,4 +75,4 @@ jobs:
7575
- name: Check documentation
7676
env:
7777
RUSTDOCFLAGS: -Dwarnings
78-
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path_tracer" --exclude "denoiser" --exclude "add" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"
78+
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"

Cargo.toml

+7-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@ members = [
88

99
"xtask",
1010

11+
"examples/cuda/vecadd",
12+
"examples/cuda/vecadd/kernels",
13+
"examples/cuda/gemm",
14+
"examples/cuda/gemm/kernels",
15+
"examples/cuda/path_tracer",
16+
"examples/cuda/path_tracer/kernels",
17+
1118
"examples/optix/*",
12-
"examples/cuda/cpu/*",
13-
"examples/cuda/gpu/*",
1419
]
1520

1621
exclude = [

examples/cuda/cpu/add/Cargo.toml

-22
This file was deleted.

examples/cuda/cpu/add/build.rs

-8
This file was deleted.

examples/cuda/cpu/path_tracer/build.rs

-14
This file was deleted.

examples/cuda/gemm/Cargo.toml

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "gemm"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[dependencies]
7+
blastoff = { path = "../../../crates/blastoff" }
8+
cuda_std = { path = "../../../crates/cuda_std" }
9+
cust = { path = "../../../crates/cust" }
10+
cust_raw = { path = "../../../crates/cust_raw", features = ["driver"] }
11+
ndarray = { version = "0.16", features = ["approx"] }
12+
ndarray-rand = "0.15.0"
13+
rand = "0.9"
14+
15+
[build-dependencies]
16+
cuda_builder = { path = "../../../crates/cuda_builder" }

examples/cuda/gemm/build.rs

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use std::env;
2+
use std::path;
3+
4+
use cuda_builder::CudaBuilder;
5+
6+
fn main() {
7+
println!("cargo::rerun-if-changed=build.rs");
8+
println!("cargo::rerun-if-changed=kernels");
9+
10+
let out_path = path::PathBuf::from(env::var("OUT_DIR").unwrap());
11+
CudaBuilder::new("kernels")
12+
.copy_to(out_path.join("kernels.ptx"))
13+
.build()
14+
.unwrap();
15+
}

examples/cuda/gemm/kernels/Cargo.toml

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "gemm-kernels"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[dependencies]
7+
cuda_std = { path = "../../../../crates/cuda_std" }
8+
glam = { version = "0.30.1", default-features = false, features = ["cuda", "nostd-libm"] }
9+
10+
[lib]
11+
crate-type = ["cdylib", "rlib"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
use cuda_std::kernel;
2+
use cuda_std::thread;
3+
4+
#[kernel]
5+
#[allow(improper_ctypes_definitions)]
6+
/// Naive GEMM kernel for C = alpha * A * B + beta * C.
7+
///
8+
/// This kernel computes each element of the output matrix C independently, without any memory coalescing or tiling optimizations.
9+
///
10+
/// # Safety
11+
/// CUDA kernel requires unsafe.
12+
///
13+
/// # Parameters
14+
/// - `mat_a`: Input matrix A, shape (m x k), row-major order.
15+
/// - `mat_b`: Input matrix B, shape (k x n), row-major order.
16+
/// - `mat_c`: Output matrix C, shape (m x n), row-major order. Must be valid for writes.
17+
/// - `m`: Number of rows in A and C.
18+
/// - `n`: Number of columns in B and C.
19+
/// - `k`: Number of columns in A and rows in B.
20+
/// - `alpha`: Scalar multiplier for A * B.
21+
/// - `beta`: Scalar multiplier for C.
22+
///
23+
/// # Thread Mapping
24+
/// Each thread computes one element of C at (row, col).
25+
pub unsafe fn gemm_naive(
26+
mat_a: &[f32],
27+
mat_b: &[f32],
28+
mat_c: *mut f32,
29+
m: usize,
30+
n: usize,
31+
k: usize,
32+
alpha: f32,
33+
beta: f32,
34+
) {
35+
let row = (thread::block_dim_x() * thread::block_idx_x() + thread::thread_idx_x()) as usize;
36+
let col = (thread::block_dim_y() * thread::block_idx_y() + thread::thread_idx_y()) as usize;
37+
38+
if row < m && col < n {
39+
let mut sum = 0.0f32;
40+
for i in 0..k {
41+
sum += mat_a[row * k + i] * mat_b[i * n + col];
42+
}
43+
let elem = unsafe { &mut *mat_c.add((row * n + col) as usize) };
44+
*elem = alpha * sum + beta * *elem;
45+
}
46+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use cuda_std::address_space;
2+
use cuda_std::kernel;
3+
use cuda_std::thread;
4+
5+
#[kernel]
6+
#[allow(improper_ctypes_definitions)]
7+
/// Tiled GEMM kernel for C = alpha * A * B + beta * C.
8+
///
9+
/// This kernel uses shared memory tiling to improve memory access patterns and performance.
10+
///
11+
/// # Safety
12+
/// CUDA kernel requires unsafe.
13+
///
14+
/// # Parameters
15+
/// - `mat_a`: Input matrix A, shape (m x k), row-major order.
16+
/// - `mat_b`: Input matrix B, shape (k x n), row-major order.
17+
/// - `mat_c`: Output matrix C, shape (m x n), row-major order. Must be valid for writes.
18+
/// - `m`: Number of rows in A and C.
19+
/// - `n`: Number of columns in B and C.
20+
/// - `k`: Number of columns in A and rows in B.
21+
/// - `alpha`: Scalar multiplier for A * B.
22+
/// - `beta`: Scalar multiplier for C.
23+
///
24+
/// # Tiling
25+
/// Each block computes a TILE_SIZE x TILE_SIZE tile of C using shared memory for A and B tiles.
26+
/// Threads within a block collaboratively load tiles and compute partial sums.
27+
///
28+
/// # Thread Mapping
29+
/// Each thread computes one element of the output tile.
30+
pub unsafe fn gemm_tiled(
31+
mat_a: &[f32],
32+
mat_b: &[f32],
33+
mat_c: *mut f32,
34+
m: usize,
35+
n: usize,
36+
k: usize,
37+
alpha: f32,
38+
beta: f32,
39+
) {
40+
const TILE_SIZE: usize = 16;
41+
42+
#[address_space(shared)]
43+
static mut TILE_A: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
44+
#[address_space(shared)]
45+
static mut TILE_B: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
46+
47+
// Thread indices within the block.
48+
let tx = thread::thread_idx_x() as usize;
49+
let ty = thread::thread_idx_y() as usize;
50+
51+
// Calculate row and column in the mat_c.
52+
let row = thread::block_idx_x() as usize * TILE_SIZE + ty;
53+
let col = thread::block_idx_y() as usize * TILE_SIZE + tx;
54+
55+
let mut sum = 0.0f32;
56+
// Loop over tiles of mat_a and mat_b in the k dimension.
57+
for kk in (0..k).step_by(TILE_SIZE) {
58+
// Collaborative loading of tiles into shared memory.
59+
if row < m && (kk + tx) < k {
60+
unsafe { TILE_A[ty * TILE_SIZE + tx] = mat_a[row * k + (kk + tx)] };
61+
} else {
62+
unsafe { TILE_A[ty * TILE_SIZE + tx] = 0.0f32 };
63+
}
64+
if col < n && (kk + ty) < k {
65+
unsafe { TILE_B[ty * TILE_SIZE + tx] = mat_b[(kk + ty) * n + col] };
66+
} else {
67+
unsafe { TILE_B[ty * TILE_SIZE + tx] = 0.0f32 };
68+
}
69+
thread::sync_threads();
70+
71+
// Perform the computation on the tile.
72+
for i in 0..TILE_SIZE {
73+
sum += unsafe { TILE_A[ty * TILE_SIZE + i] * TILE_B[i * TILE_SIZE + tx] };
74+
}
75+
thread::sync_threads();
76+
}
77+
78+
// Write the result back to mat_c with alpha and beta scaling.
79+
if row < m && col < n {
80+
let c = unsafe { mat_c.add(row * n + col) };
81+
unsafe { *c = alpha * sum + beta * *c };
82+
}
83+
}

examples/cuda/gemm/kernels/src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod gemm_naive;
2+
mod gemm_tiled;
3+
4+
pub use crate::gemm_naive::gemm_naive;
5+
pub use crate::gemm_tiled::gemm_tiled;

0 commit comments

Comments
 (0)