Skip to content

Commit aefa92a

Browse files
committed
Add Cooperative Groups API integration
This works as follows:
- Users build their CUDA code via `CudaBuilder` as normal.
- If they want to use the Cooperative Groups API, then in their `build.rs`, just after building their PTX, they will:
  - Create a `cuda_builder::cg::CooperativeGroups` instance,
  - Add any needed opts for building the Cooperative Groups API bridge code (`-arch=sm_*` and so on),
  - Add their newly built PTX code to be linked with the CG API, which can include multiple PTX, cubin or fatbin files,
  - Call `.compile(..)`, which will spit out a fully linked `cubin`.
- In the user's main application code, instead of using `launch!` to schedule their GPU work, they will now use `launch_cooperative!`.
1 parent 8a6cb73 commit aefa92a

File tree

14 files changed

+464
-30
lines changed

14 files changed

+464
-30
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ members = [
1414
]
1515

1616
exclude = [
17-
"crates/optix/examples/common"
17+
"crates/optix/examples/common",
1818
]
1919

2020
[profile.dev.package.rustc_codegen_nvvm]

crates/cuda_builder/Cargo.toml

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "cuda_builder"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
edition = "2021"
55
authors = ["Riccardo D'Ambrosio <[email protected]>", "The rust-gpu Authors"]
66
license = "MIT OR Apache-2.0"
@@ -9,8 +9,16 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"
99
readme = "../../README.md"
1010

1111
[dependencies]
12+
anyhow = "1"
13+
thiserror = "1"
14+
cc = { version = "1", default-features = false, optional = true }
15+
cust = { path = "../cust", optional = true }
1216
rustc_codegen_nvvm = { version = "0.3", path = "../rustc_codegen_nvvm" }
1317
nvvm = { path = "../nvvm", version = "0.1" }
1418
serde = { version = "1.0.130", features = ["derive"] }
1519
serde_json = "1.0.68"
1620
find_cuda_helper = { version = "0.2", path = "../find_cuda_helper" }
21+
22+
[features]
23+
default = []
24+
cooperative_groups = ["cc", "cust"]

crates/cuda_builder/cg/cg_bridge.cu

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "cooperative_groups.h"
2+
#include "cg_bridge.cuh"
3+
namespace cg = cooperative_groups;
4+
5+
__device__ GridGroup this_grid()
6+
{
7+
cg::grid_group gg = cg::this_grid();
8+
GridGroupWrapper* ggp = new GridGroupWrapper { gg };
9+
return ggp;
10+
}
11+
12+
__device__ void GridGroup_destroy(GridGroup gg)
13+
{
14+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
15+
delete g;
16+
}
17+
18+
__device__ bool GridGroup_is_valid(GridGroup gg)
19+
{
20+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
21+
return g->gg.is_valid();
22+
}
23+
24+
__device__ void GridGroup_sync(GridGroup gg)
25+
{
26+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
27+
return g->gg.sync();
28+
}
29+
30+
__device__ unsigned long long GridGroup_size(GridGroup gg)
31+
{
32+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
33+
return g->gg.size();
34+
}
35+
36+
__device__ unsigned long long GridGroup_thread_rank(GridGroup gg)
37+
{
38+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
39+
return g->gg.thread_rank();
40+
}
41+
42+
__device__ unsigned long long GridGroup_num_threads(GridGroup gg)
43+
{
44+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
45+
return g->gg.num_threads();
46+
}
47+
48+
__device__ unsigned long long GridGroup_num_blocks(GridGroup gg)
49+
{
50+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
51+
return g->gg.num_blocks();
52+
}
53+
54+
__device__ unsigned long long GridGroup_block_rank(GridGroup gg)
55+
{
56+
GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
57+
return g->gg.block_rank();
58+
}
59+
60+
__host__ int main()
61+
{}

crates/cuda_builder/cg/cg_bridge.cuh

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
#include "cooperative_groups.h"
3+
namespace cg = cooperative_groups;
4+
5+
typedef struct GridGroupWrapper {
6+
cg::grid_group gg;
7+
} GridGroupWrapper;
8+
9+
extern "C" typedef void* GridGroup;
10+
extern "C" __device__ GridGroup this_grid();
11+
extern "C" __device__ void GridGroup_destroy(GridGroup gg);
12+
extern "C" __device__ bool GridGroup_is_valid(GridGroup gg);
13+
extern "C" __device__ void GridGroup_sync(GridGroup gg);
14+
extern "C" __device__ unsigned long long GridGroup_size(GridGroup gg);
15+
extern "C" __device__ unsigned long long GridGroup_thread_rank(GridGroup gg);
16+
// extern "C" dim3 GridGroup_group_dim(); // TODO: impl these.
17+
extern "C" __device__ unsigned long long GridGroup_num_threads(GridGroup gg);
18+
// extern "C" dim3 GridGroup_dim_blocks(); // TODO: impl these.
19+
extern "C" __device__ unsigned long long GridGroup_num_blocks(GridGroup gg);
20+
// extern "C" dim3 GridGroup_block_index(); // TODO: impl these.
21+
extern "C" __device__ unsigned long long GridGroup_block_rank(GridGroup gg);

crates/cuda_builder/src/cg.rs

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
//! Cooperative Groups compilation and linking.
2+
3+
use std::path::{Path, PathBuf};
4+
5+
use anyhow::Context;
6+
7+
use crate::{CudaBuilderError, CudaBuilderResult};
8+
9+
/// An artifact which may be linked together with the Cooperative Groups API bridge PTX code.
10+
pub enum LinkableArtifact {
11+
/// A PTX artifact.
12+
Ptx(PathBuf),
13+
/// A cubin artifact.
14+
Cubin(PathBuf),
15+
/// A fatbin artifact.
16+
Fatbin(PathBuf),
17+
}
18+
19+
impl LinkableArtifact {
20+
/// Add this artifact to the given linker.
21+
fn link_artifact(&self, linker: &mut cust::link::Linker) -> CudaBuilderResult<()> {
22+
match &self {
23+
LinkableArtifact::Ptx(path) => {
24+
let mut data = std::fs::read_to_string(&path).with_context(|| {
25+
format!("error reading PTX file for linking, file={:?}", path)
26+
})?;
27+
if !data.ends_with('\0') {
28+
// If the PTX is not null-terminated, then linking will fail. Only required for PTX.
29+
data.push('\0');
30+
}
31+
linker
32+
.add_ptx(&data)
33+
.with_context(|| format!("error linking PTX file={:?}", path))?;
34+
}
35+
LinkableArtifact::Cubin(path) => {
36+
let data = std::fs::read(&path).with_context(|| {
37+
format!("error reading cubin file for linking, file={:?}", path)
38+
})?;
39+
linker
40+
.add_cubin(&data)
41+
.with_context(|| format!("error linking cubin file={:?}", path))?;
42+
}
43+
LinkableArtifact::Fatbin(path) => {
44+
let data = std::fs::read(&path).with_context(|| {
45+
format!("error reading fatbin file for linking, file={:?}", path)
46+
})?;
47+
linker
48+
.add_fatbin(&data)
49+
.with_context(|| format!("error linking fatbin file={:?}", path))?;
50+
}
51+
}
52+
Ok(())
53+
}
54+
}
55+
56+
/// A builder which will compile the Cooperative Groups API bridging code, and will then link it
57+
/// together with any other artifacts provided to this builder.
58+
///
59+
/// The result of this process will be a `cubin` file containing the linked Cooperative Groups
60+
/// PTX code along with any other linked artifacts provided to this builder. The output `cubin`
61+
/// may then be loaded via `cust::module::Module::from_cubin(..)` and used as normal.
62+
#[derive(Default)]
63+
pub struct CooperativeGroups {
64+
/// Artifacts to be linked together with the Cooperative Groups bridge code.
65+
artifacts: Vec<LinkableArtifact>,
66+
/// Flags to pass to nvcc for Cooperative Groups API bridge compilation.
67+
nvcc_flags: Vec<String>,
68+
}
69+
70+
impl CooperativeGroups {
71+
/// Construct a new instance.
72+
pub fn new() -> Self {
73+
Self::default()
74+
}
75+
76+
/// Add the artifact at the given path for linking.
77+
///
78+
/// This only applies to linking with the Cooperative Groups API bridge code. Typically,
79+
/// this will be the PTX of your main program which has already been built via `CudaBuilder`.
80+
pub fn link(mut self, artifact: LinkableArtifact) -> Self {
81+
self.artifacts.push(artifact);
82+
self
83+
}
84+
85+
/// Add a flag to be passed along to `nvcc` during compilation of the Cooperative Groups API bridge code.
86+
///
87+
/// This provides maximum flexibility for code generation. If needed, multiple architectures
88+
/// may be generated by adding the appropriate flags to the `nvcc` call.
89+
///
90+
/// By default, `nvcc` will generate code for `sm_52`. Override by specifying any of `--gpu-architecture`,
91+
/// `--gpu-code`, or `--generate-code` flags.
92+
///
93+
/// Regardless of the flags added via this method, this builder will always add the following flags:
94+
/// - `-I<cudaRoot>/include`: ensuring `cooperative_groups.h` can be found.
95+
/// - `-Icg`: ensuring the bridging header can be found.
96+
/// - `--ptx`: forces the compiled output to be in PTX form.
97+
/// - `--device-c`: to compile the bridging code as relocatable device code.
98+
/// - `cg/cg_bridge.cu` will be added as the code to be compiled, which generates the
99+
/// Cooperative Groups API bridge.
100+
///
101+
/// Docs: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#command-option-description
102+
pub fn nvcc_flag(mut self, val: impl AsRef<str>) -> Self {
103+
self.nvcc_flags.push(val.as_ref().to_string());
104+
self
105+
}
106+
107+
/// Compile the Cooperative Groups API bridging code, and then link it together
108+
/// with any other artifacts provided to this builder.
109+
///
110+
/// - `cg_out` specifies the output location for the Cooperative Groups API bridge PTX.
111+
/// - `cubin_out` specifies the output location for the fully linked `cubin`.
112+
///
113+
/// ## Errors
114+
/// - At least one artifact must be provided to this builder for linking.
115+
/// - Any errors which take place from the `nvcc` compilation of the Cooperative Groups bridging
116+
/// code, or any errors which take place during module linking.
117+
pub fn compile(
118+
mut self,
119+
cg_out: impl AsRef<Path>,
120+
cubin_out: impl AsRef<Path>,
121+
) -> CudaBuilderResult<()> {
122+
// Perform some initial validation.
123+
if self.artifacts.is_empty() {
124+
return Err(anyhow::anyhow!("must provide at least 1 ptx/cubin/fatbin artifact to be linked with the Cooperative Groups API bridge code").into());
125+
}
126+
127+
// Find the cuda installation directory for compilation of CG API.
128+
let cuda_root =
129+
find_cuda_helper::find_cuda_root().ok_or(CudaBuilderError::CudaRootNotFound)?;
130+
let cuda_include = cuda_root.join("include");
131+
let cg_src = std::path::Path::new(std::file!())
132+
.parent()
133+
.context("error accessing parent dir cuda_builder/src")?
134+
.parent()
135+
.context("error accessing parent dir cuda_builder")?
136+
.join("cg")
137+
.canonicalize()
138+
.context("error taking canonical path to cooperative groups API bridge code")?;
139+
let cg_bridge_cu = cg_src.join("cg_bridge.cu");
140+
141+
// Build up the `nvcc` invocation and then build the bridging code.
142+
let mut nvcc = std::process::Command::new("nvcc");
143+
nvcc.arg(format!("-I{:?}", &cuda_include).as_str())
144+
.arg(format!("-I{:?}", &cg_src).as_str())
145+
.arg("--ptx")
146+
.arg("-o")
147+
.arg(cg_out.as_ref().to_string_lossy().as_ref())
148+
.arg("--device-c")
149+
.arg(cg_bridge_cu.to_string_lossy().as_ref());
150+
for flag in self.nvcc_flags.iter() {
151+
nvcc.arg(flag.as_str());
152+
}
153+
nvcc.status()
154+
.context("error calling nvcc for Cooperative Groups API bridge compilation")?;
155+
156+
// Link together the bridging code with any given PTX/cubin/fatbin artifacts.
157+
let _ctx = cust::quick_init().context("error building cuda context")?;
158+
let mut linker = cust::link::Linker::new().context("error building cust linker")?;
159+
self.artifacts
160+
.push(LinkableArtifact::Ptx(cg_out.as_ref().to_path_buf()));
161+
for artifact in self.artifacts.iter() {
162+
artifact.link_artifact(&mut linker)?;
163+
}
164+
let linked_cubin = linker
165+
.complete()
166+
.context("error linking artifacts with Cooperative Groups API bridge PTX")?;
167+
168+
// Write finalized cubin.
169+
std::fs::write(&cubin_out, &linked_cubin)
170+
.with_context(|| format!("error writing linked cubin to {:?}", cubin_out.as_ref()))?;
171+
172+
Ok(())
173+
}
174+
}

crates/cuda_builder/src/lib.rs

+18-17
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
11
//! Utility crate for easily building CUDA crates using rustc_codegen_nvvm. Derived from rust-gpu's spirv_builder.
22
3+
#[cfg(feature = "cooperative_groups")]
4+
pub mod cg;
5+
36
pub use nvvm::*;
47
use serde::Deserialize;
58
use std::{
69
borrow::Borrow,
710
env,
811
ffi::OsString,
9-
fmt,
1012
path::{Path, PathBuf},
1113
process::{Command, Stdio},
1214
};
1315

14-
#[derive(Debug)]
16+
/// Cuda builder result type.
17+
pub type CudaBuilderResult<T> = Result<T, CudaBuilderError>;
18+
19+
/// Cuda builder error type.
20+
#[derive(thiserror::Error, Debug)]
1521
#[non_exhaustive]
1622
pub enum CudaBuilderError {
23+
#[error("crate path {0} does not exist")]
1724
CratePathDoesntExist(PathBuf),
18-
FailedToCopyPtxFile(std::io::Error),
25+
#[error("build failed")]
1926
BuildFailed,
20-
}
21-
22-
impl fmt::Display for CudaBuilderError {
23-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24-
match self {
25-
CudaBuilderError::CratePathDoesntExist(path) => {
26-
write!(f, "Crate path {} does not exist", path.display())
27-
}
28-
CudaBuilderError::BuildFailed => f.write_str("Build failed"),
29-
CudaBuilderError::FailedToCopyPtxFile(err) => {
30-
f.write_str(&format!("Failed to copy PTX file: {:?}", err))
31-
}
32-
}
33-
}
27+
#[error("failed to copy PTX file: {0:?}")]
28+
FailedToCopyPtxFile(#[from] std::io::Error),
29+
#[cfg(feature = "cooperative_groups")]
30+
#[error("could not find cuda root installation dir")]
31+
CudaRootNotFound,
32+
#[cfg(feature = "cooperative_groups")]
33+
#[error("compilation of the Cooperative Groups API bridge code failed: {0}")]
34+
CGError(#[from] anyhow::Error),
3435
}
3536

3637
#[derive(Debug, Clone, Copy, PartialEq)]

crates/cuda_std/Cargo.toml

+7
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ cuda_std_macros = { version = "0.2", path = "../cuda_std_macros" }
1313
half = "1.7.1"
1414
bitflags = "1.3.2"
1515
paste = "1.0.5"
16+
17+
[features]
18+
default = []
19+
cooperative_groups = []
20+
21+
[package.metadata.docs.rs]
22+
all-features = true

0 commit comments

Comments
 (0)