From 3b1ecd09981313e9b668de63e1ab32d033ce2ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 23 Jun 2025 13:01:23 +0000 Subject: [PATCH] Add `cuda-minver` option to the `general` section This option is used to set the minimum required version of CUDA. Older versions are rejected by CMake and not included in Nix builds. --- build2cmake/src/config/v2.rs | 5 ++ build2cmake/src/main.rs | 2 + build2cmake/src/templates/cuda/preamble.cmake | 7 ++ build2cmake/src/torch/cuda.rs | 11 ++- build2cmake/src/version.rs | 68 +++++++++++++++++++ docs/writing-kernels.md | 5 ++ lib/build.nix | 16 +++-- 7 files changed, 108 insertions(+), 6 deletions(-) create mode 100644 build2cmake/src/version.rs diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 39f16ca..023e870 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -9,6 +9,8 @@ use eyre::{bail, Result}; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::version::Version; + use super::v1::{self, Language}; #[derive(Debug, Deserialize, Serialize)] @@ -44,6 +46,8 @@ pub struct General { pub name: String, #[serde(default)] pub universal: bool, + + pub cuda_minver: Option, } #[derive(Debug, Deserialize, Clone, Serialize)] @@ -207,6 +211,7 @@ impl General { Self { name: general.name, universal, + cuda_minver: None, } } } diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index 4f7f710..902ab82 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -17,6 +17,8 @@ use config::{Backend, Build, BuildCompat}; mod fileset; use fileset::FileSet; +mod version; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Cli { diff --git a/build2cmake/src/templates/cuda/preamble.cmake b/build2cmake/src/templates/cuda/preamble.cmake index dbb92ad..ff8c08c 100644 --- a/build2cmake/src/templates/cuda/preamble.cmake +++ b/build2cmake/src/templates/cuda/preamble.cmake @@ -36,6 +36,13 @@ endif() if (NOT HIP_FOUND AND CUDA_FOUND) set(GPU_LANG "CUDA") + + {% if cuda_minver %} + if (CUDA_VERSION VERSION_LESS {{ cuda_minver }}) + message(FATAL_ERROR "CUDA version ${CUDA_VERSION} is too old. " + "Minimum required version is {{ cuda_minver }}.") + endif() + {% endif %} elseif(HIP_FOUND) set(GPU_LANG "HIP") diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index 6f5d2dc..a6b30d1 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -8,6 +8,7 @@ use minijinja::{context, Environment}; use super::kernel_ops_identifier; use crate::config::{Backend, Build, Dependencies, Kernel, Torch}; +use crate::version::Version; use crate::FileSet; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); @@ -163,7 +164,7 @@ fn write_cmake( let cmake_writer = file_set.entry("CMakeLists.txt"); - render_preamble(env, name, cmake_writer)?; + render_preamble(env, name, build.general.cuda_minver.as_ref(), cmake_writer)?; render_deps(env, build, cmake_writer)?; @@ -338,12 +339,18 @@ pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Writ Ok(()) } -pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { +pub fn render_preamble( + env: &Environment, + name: &str, + cuda_minver: Option<&Version>, + write: &mut impl Write, +) -> Result<()> { env.get_template("cuda/preamble.cmake") .wrap_err("Cannot get CMake prelude template")? .render_to_write( context! { name => name, + cuda_minver => cuda_minver.map(|v| v.to_string()), cuda_supported_archs => cuda_supported_archs(), }, diff --git a/build2cmake/src/version.rs b/build2cmake/src/version.rs new file mode 100644 index 0000000..97bac83 --- /dev/null +++ b/build2cmake/src/version.rs @@ -0,0 +1,68 @@ +use std::{fmt::Display, str::FromStr}; + +use eyre::{ensure, Context}; +use itertools::Itertools; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct Version(Vec); + +impl<'de> Deserialize<'de> for Version { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(de::Error::custom) + } +} + +impl Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.0.iter().map(|v| v.to_string()).join(".")) + } +} + +impl Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + itertools::join(self.0.iter().map(|v| v.to_string()), ".") + ) + } +} + +impl From> for Version { + fn from(value: Vec) -> Self { + // Remove trailing zeros for normalization. + let mut normalized = value + .into_iter() + .rev() + .skip_while(|&x| x == 0) + .collect::>(); + normalized.reverse(); + Version(normalized) + } +} + +impl FromStr for Version { + type Err = eyre::Report; + + fn from_str(version: &str) -> Result { + let version = version.trim().to_owned(); + ensure!(!version.is_empty(), "Empty version string"); + let mut version_parts = Vec::new(); + for part in version.split('.') { + let version_part: usize = part + .parse() + .context(format!("Version must consist of numbers: {}", version))?; + version_parts.push(version_part); + } + + Ok(Version::from(version_parts)) + } +} diff --git a/docs/writing-kernels.md b/docs/writing-kernels.md index 540a002..72b552e 100644 --- a/docs/writing-kernels.md +++ b/docs/writing-kernels.md @@ -96,6 +96,11 @@ depends = [ "torch" ] Universal kernels do not use the other sections described below. A good example of a universal kernel is a Triton kernel. Default: `false` +- `cuda-minver`: the minimum required CUDA toolkit version. This option + _must not_ be set under normal circumstances, since it can exclude Torch + build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). + This option is provided for kernels that require functionality only + provided by newer CUDA toolkits. ### `torch` diff --git a/lib/build.nix b/lib/build.nix index 7b6f9dc..56faa8d 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -62,12 +62,20 @@ rec { buildConfig: buildSets: let backends' = backends buildConfig; + requiredCuda = buildConfig.general.cuda-minver or "11.8"; supportedBuildSet = buildSet: - (buildSet.gpu == "cuda" && backends'.cuda) - || (buildSet.gpu == "rocm" && backends'.rocm) - || (buildSet.gpu == "metal" && backends'.metal) - || (buildConfig.general.universal or false); + let + backendSupported = + (buildSet.gpu == "cuda" && backends'.cuda) + || (buildSet.gpu == "rocm" && backends'.rocm) + || (buildSet.gpu == "metal" && backends'.metal) + || (buildConfig.general.universal or false); + cudaVersionSupported = + buildSet.gpu != "cuda" + || (lib.strings.versionAtLeast buildSet.pkgs.cudaPackages.cudaMajorMinorVersion requiredCuda); + in + backendSupported && cudaVersionSupported; in builtins.filter supportedBuildSet buildSets;