|
| 1 | +//! Cooperative Groups compilation and linking. |
| 2 | +
|
| 3 | +use std::path::{Path, PathBuf}; |
| 4 | + |
| 5 | +use anyhow::Context; |
| 6 | + |
| 7 | +use crate::{CudaBuilderError, CudaBuilderResult}; |
| 8 | + |
/// An artifact which may be linked together with the Cooperative Groups API bridge PTX code.
///
/// Each variant carries the filesystem path of the artifact; the variant itself tells the
/// linker how the file contents must be interpreted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LinkableArtifact {
    /// A PTX artifact.
    Ptx(PathBuf),
    /// A cubin artifact.
    Cubin(PathBuf),
    /// A fatbin artifact.
    Fatbin(PathBuf),
}
| 18 | + |
| 19 | +impl LinkableArtifact { |
| 20 | + /// Add this artifact to the given linker. |
| 21 | + fn link_artifact(&self, linker: &mut cust::link::Linker) -> CudaBuilderResult<()> { |
| 22 | + match &self { |
| 23 | + LinkableArtifact::Ptx(path) => { |
| 24 | + let mut data = std::fs::read_to_string(&path).with_context(|| { |
| 25 | + format!("error reading PTX file for linking, file={:?}", path) |
| 26 | + })?; |
| 27 | + if !data.ends_with('\0') { |
| 28 | + // If the PTX is not null-terminated, then linking will fail. Only required for PTX. |
| 29 | + data.push('\0'); |
| 30 | + } |
| 31 | + linker |
| 32 | + .add_ptx(&data) |
| 33 | + .with_context(|| format!("error linking PTX file={:?}", path))?; |
| 34 | + } |
| 35 | + LinkableArtifact::Cubin(path) => { |
| 36 | + let data = std::fs::read(&path).with_context(|| { |
| 37 | + format!("error reading cubin file for linking, file={:?}", path) |
| 38 | + })?; |
| 39 | + linker |
| 40 | + .add_cubin(&data) |
| 41 | + .with_context(|| format!("error linking cubin file={:?}", path))?; |
| 42 | + } |
| 43 | + LinkableArtifact::Fatbin(path) => { |
| 44 | + let data = std::fs::read(&path).with_context(|| { |
| 45 | + format!("error reading fatbin file for linking, file={:?}", path) |
| 46 | + })?; |
| 47 | + linker |
| 48 | + .add_fatbin(&data) |
| 49 | + .with_context(|| format!("error linking fatbin file={:?}", path))?; |
| 50 | + } |
| 51 | + } |
| 52 | + Ok(()) |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +/// A builder which will compile the Cooperative Groups API bridging code, and will then link it |
| 57 | +/// together with any other artifacts provided to this builder. |
| 58 | +/// |
| 59 | +/// The result of this process will be a `cubin` file containing the linked Cooperative Groups |
| 60 | +/// PTX code along with any other linked artifacts provided to this builder. The output `cubin` |
| 61 | +/// may then be loaded via `cust::module::Module::from_cubin(..)` and used as normal. |
| 62 | +#[derive(Default)] |
| 63 | +pub struct CooperativeGroups { |
| 64 | + /// Artifacts to be linked together with the Cooperative Groups bridge code. |
| 65 | + artifacts: Vec<LinkableArtifact>, |
| 66 | + /// Flags to pass to nvcc for Cooperative Groups API bridge compilation. |
| 67 | + nvcc_flags: Vec<String>, |
| 68 | +} |
| 69 | + |
| 70 | +impl CooperativeGroups { |
| 71 | + /// Construct a new instance. |
| 72 | + pub fn new() -> Self { |
| 73 | + Self::default() |
| 74 | + } |
| 75 | + |
| 76 | + /// Add the artifact at the given path for linking. |
| 77 | + /// |
| 78 | + /// This only applies to linking with the Cooperative Groups API bridge code. Typically, |
| 79 | + /// this will be the PTX of your main program which has already been built via `CudaBuilder`. |
| 80 | + pub fn link(mut self, artifact: LinkableArtifact) -> Self { |
| 81 | + self.artifacts.push(artifact); |
| 82 | + self |
| 83 | + } |
| 84 | + |
| 85 | + /// Add a flag to be passed along to `nvcc` during compilation of the Cooperative Groups API bridge code. |
| 86 | + /// |
| 87 | + /// This provides maximum flexibility for code generation. If needed, multiple architectures |
| 88 | + /// may be generated by adding the appropriate flags to the `nvcc` call. |
| 89 | + /// |
| 90 | + /// By default, `nvcc` will generate code for `sm_52`. Override by specifying any of `--gpu-architecture`, |
| 91 | + /// `--gpu-code`, or `--generate-code` flags. |
| 92 | + /// |
| 93 | + /// Regardless of the flags added via this method, this builder will always added the following flags: |
| 94 | + /// - `-I<cudaRoot>/include`: ensuring `cooperative_groups.h` can be found. |
| 95 | + /// - `-Icg`: ensuring the bridging header can be found. |
| 96 | + /// - `--ptx`: forces the compiled output to be in PTX form. |
| 97 | + /// - `--device-c`: to compile the bridging code as relocatable device code. |
| 98 | + /// - `src/cg_bridge.cu` will be added as the code to be compiled, which generates the |
| 99 | + /// Cooperative Groups API bridge. |
| 100 | + /// |
| 101 | + /// Docs: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#command-option-description |
| 102 | + pub fn nvcc_flag(mut self, val: impl AsRef<str>) -> Self { |
| 103 | + self.nvcc_flags.push(val.as_ref().to_string()); |
| 104 | + self |
| 105 | + } |
| 106 | + |
| 107 | + /// Compile the Cooperative Groups API bridging code, and then link it together |
| 108 | + /// with any other artifacts provided to this builder. |
| 109 | + /// |
| 110 | + /// - `cg_out` specifies the output location for the Cooperative Groups API bridge PTX. |
| 111 | + /// - `cubin_out` specifies the output location for the fully linked `cubin`. |
| 112 | + /// |
| 113 | + /// ## Errors |
| 114 | + /// - At least one artifact must be provided to this builder for linking. |
| 115 | + /// - Any errors which take place from the `nvcc` compilation of the Cooperative Groups briding |
| 116 | + /// code, or any errors which take place during module linking. |
| 117 | + pub fn compile( |
| 118 | + mut self, |
| 119 | + cg_out: impl AsRef<Path>, |
| 120 | + cubin_out: impl AsRef<Path>, |
| 121 | + ) -> CudaBuilderResult<()> { |
| 122 | + // Perform some initial validation. |
| 123 | + if self.artifacts.is_empty() { |
| 124 | + return Err(anyhow::anyhow!("must provide at least 1 ptx/cubin/fatbin artifact to be linked with the Cooperative Groups API bridge code").into()); |
| 125 | + } |
| 126 | + |
| 127 | + // Find the cuda installation directory for compilation of CG API. |
| 128 | + let cuda_root = |
| 129 | + find_cuda_helper::find_cuda_root().ok_or(CudaBuilderError::CudaRootNotFound)?; |
| 130 | + let cuda_include = cuda_root.join("include"); |
| 131 | + let cg_src = std::path::Path::new(std::file!()) |
| 132 | + .parent() |
| 133 | + .context("error accessing parent dir cuda_builder/src")? |
| 134 | + .parent() |
| 135 | + .context("error accessing parent dir cuda_builder")? |
| 136 | + .join("cg") |
| 137 | + .canonicalize() |
| 138 | + .context("error taking canonical path to cooperative groups API bridge code")?; |
| 139 | + let cg_bridge_cu = cg_src.join("cg_bridge.cu"); |
| 140 | + |
| 141 | + // Build up the `nvcc` invocation and then build the bridging code. |
| 142 | + let mut nvcc = std::process::Command::new("nvcc"); |
| 143 | + nvcc.arg(format!("-I{:?}", &cuda_include).as_str()) |
| 144 | + .arg(format!("-I{:?}", &cg_src).as_str()) |
| 145 | + .arg("--ptx") |
| 146 | + .arg("-o") |
| 147 | + .arg(cg_out.as_ref().to_string_lossy().as_ref()) |
| 148 | + .arg("--device-c") |
| 149 | + .arg(cg_bridge_cu.to_string_lossy().as_ref()); |
| 150 | + for flag in self.nvcc_flags.iter() { |
| 151 | + nvcc.arg(flag.as_str()); |
| 152 | + } |
| 153 | + nvcc.status() |
| 154 | + .context("error calling nvcc for Cooperative Groups API bridge compilation")?; |
| 155 | + |
| 156 | + // Link together the briding code with any given PTX/cubin/fatbin artifacts. |
| 157 | + let _ctx = cust::quick_init().context("error building cuda context")?; |
| 158 | + let mut linker = cust::link::Linker::new().context("error building cust linker")?; |
| 159 | + self.artifacts |
| 160 | + .push(LinkableArtifact::Ptx(cg_out.as_ref().to_path_buf())); |
| 161 | + for artifact in self.artifacts.iter() { |
| 162 | + artifact.link_artifact(&mut linker)?; |
| 163 | + } |
| 164 | + let linked_cubin = linker |
| 165 | + .complete() |
| 166 | + .context("error linking artifacts with Cooperative Groups API bridge PTX")?; |
| 167 | + |
| 168 | + // Write finalized cubin. |
| 169 | + std::fs::write(&cubin_out, &linked_cubin) |
| 170 | + .with_context(|| format!("error writing linked cubin to {:?}", cubin_out.as_ref()))?; |
| 171 | + |
| 172 | + Ok(()) |
| 173 | + } |
| 174 | +} |
0 commit comments