Skip to content

Commit 55a8181

Browse files
authored
Add cuda-minver option to the general section (#165)
This option is used to set the minimum required version of CUDA. Older versions are rejected by CMake and not included in Nix builds.
1 parent 6704ae8 commit 55a8181

File tree

7 files changed

+108
-6
lines changed

7 files changed

+108
-6
lines changed

build2cmake/src/config/v2.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use eyre::{bail, Result};
99
use itertools::Itertools;
1010
use serde::{Deserialize, Serialize};
1111

12+
use crate::version::Version;
13+
1214
use super::v1::{self, Language};
1315

1416
#[derive(Debug, Deserialize, Serialize)]
@@ -44,6 +46,8 @@ pub struct General {
4446
pub name: String,
4547
#[serde(default)]
4648
pub universal: bool,
49+
50+
pub cuda_minver: Option<Version>,
4751
}
4852

4953
#[derive(Debug, Deserialize, Clone, Serialize)]
@@ -207,6 +211,7 @@ impl General {
207211
Self {
208212
name: general.name,
209213
universal,
214+
cuda_minver: None,
210215
}
211216
}
212217
}

build2cmake/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use config::{Backend, Build, BuildCompat};
1717
mod fileset;
1818
use fileset::FileSet;
1919

20+
mod version;
21+
2022
#[derive(Parser, Debug)]
2123
#[command(version, about, long_about = None)]
2224
struct Cli {

build2cmake/src/templates/cuda/preamble.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ endif()
3636

3737
if (NOT HIP_FOUND AND CUDA_FOUND)
3838
set(GPU_LANG "CUDA")
39+
40+
{% if cuda_minver %}
41+
if (CUDA_VERSION VERSION_LESS {{ cuda_minver }})
42+
message(FATAL_ERROR "CUDA version ${CUDA_VERSION} is too old. "
43+
"Minimum required version is {{ cuda_minver }}.")
44+
endif()
45+
{% endif %}
3946
elseif(HIP_FOUND)
4047
set(GPU_LANG "HIP")
4148

build2cmake/src/torch/cuda.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use minijinja::{context, Environment};
88

99
use super::kernel_ops_identifier;
1010
use crate::config::{Backend, Build, Dependencies, Kernel, Torch};
11+
use crate::version::Version;
1112
use crate::FileSet;
1213

1314
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -163,7 +164,7 @@ fn write_cmake(
163164

164165
let cmake_writer = file_set.entry("CMakeLists.txt");
165166

166-
render_preamble(env, name, cmake_writer)?;
167+
render_preamble(env, name, build.general.cuda_minver.as_ref(), cmake_writer)?;
167168

168169
render_deps(env, build, cmake_writer)?;
169170

@@ -338,12 +339,18 @@ pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Writ
338339
Ok(())
339340
}
340341

341-
pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
342+
pub fn render_preamble(
343+
env: &Environment,
344+
name: &str,
345+
cuda_minver: Option<&Version>,
346+
write: &mut impl Write,
347+
) -> Result<()> {
342348
env.get_template("cuda/preamble.cmake")
343349
.wrap_err("Cannot get CMake prelude template")?
344350
.render_to_write(
345351
context! {
346352
name => name,
353+
cuda_minver => cuda_minver.map(|v| v.to_string()),
347354
cuda_supported_archs => cuda_supported_archs(),
348355

349356
},

build2cmake/src/version.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use std::{fmt::Display, str::FromStr};
2+
3+
use eyre::{ensure, Context};
4+
use itertools::Itertools;
5+
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
6+
7+
/// A dotted numeric version such as `12.4`, stored as its list of components.
///
/// Values produced by the `From<Vec<usize>>` and `FromStr` implementations in
/// this module have trailing zero components removed, so `12.0` and `12` yield
/// the same value. Equality, ordering and hashing are derived from the
/// component vector.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
8+
pub struct Version(Vec<usize>);
9+
10+
impl<'de> Deserialize<'de> for Version {
11+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
12+
where
13+
D: Deserializer<'de>,
14+
{
15+
let s = String::deserialize(deserializer)?;
16+
FromStr::from_str(&s).map_err(de::Error::custom)
17+
}
18+
}
19+
20+
impl Serialize for Version {
21+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
22+
where
23+
S: Serializer,
24+
{
25+
serializer.serialize_str(&self.0.iter().map(|v| v.to_string()).join("."))
26+
}
27+
}
28+
29+
impl Display for Version {
30+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31+
write!(
32+
f,
33+
"{}",
34+
itertools::join(self.0.iter().map(|v| v.to_string()), ".")
35+
)
36+
}
37+
}
38+
39+
impl From<Vec<usize>> for Version {
40+
fn from(value: Vec<usize>) -> Self {
41+
// Remove trailing zeros for normalization.
42+
let mut normalized = value
43+
.into_iter()
44+
.rev()
45+
.skip_while(|&x| x == 0)
46+
.collect::<Vec<_>>();
47+
normalized.reverse();
48+
Version(normalized)
49+
}
50+
}
51+
52+
impl FromStr for Version {
53+
type Err = eyre::Report;
54+
55+
fn from_str(version: &str) -> Result<Self, Self::Err> {
56+
let version = version.trim().to_owned();
57+
ensure!(!version.is_empty(), "Empty version string");
58+
let mut version_parts = Vec::new();
59+
for part in version.split('.') {
60+
let version_part: usize = part
61+
.parse()
62+
.context(format!("Version must consist of numbers: {}", version))?;
63+
version_parts.push(version_part);
64+
}
65+
66+
Ok(Version::from(version_parts))
67+
}
68+
}

docs/writing-kernels.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ depends = [ "torch" ]
9696
Universal kernels do not use the other sections described below.
9797
A good example of a universal kernel is a Triton kernel.
9898
Default: `false`
99+
- `cuda-minver`: the minimum required CUDA toolkit version. This option
100+
_must not_ be set under normal circumstances, since it can exclude Torch
101+
build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md).
102+
This option is provided for kernels that require functionality only
103+
provided by newer CUDA toolkits.
99104

100105
### `torch`
101106

lib/build.nix

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,20 @@ rec {
6262
buildConfig: buildSets:
6363
let
6464
backends' = backends buildConfig;
65+
requiredCuda = buildConfig.general.cuda-minver or "11.8";
6566
supportedBuildSet =
6667
buildSet:
67-
(buildSet.gpu == "cuda" && backends'.cuda)
68-
|| (buildSet.gpu == "rocm" && backends'.rocm)
69-
|| (buildSet.gpu == "metal" && backends'.metal)
70-
|| (buildConfig.general.universal or false);
68+
let
69+
backendSupported =
70+
(buildSet.gpu == "cuda" && backends'.cuda)
71+
|| (buildSet.gpu == "rocm" && backends'.rocm)
72+
|| (buildSet.gpu == "metal" && backends'.metal)
73+
|| (buildConfig.general.universal or false);
74+
cudaVersionSupported =
75+
buildSet.gpu != "cuda"
76+
|| (lib.strings.versionAtLeast buildSet.pkgs.cudaPackages.cudaMajorMinorVersion requiredCuda);
77+
in
78+
backendSupported && cudaVersionSupported;
7179
in
7280
builtins.filter supportedBuildSet buildSets;
7381

0 commit comments

Comments
 (0)