Skip to content

Add cuda-minver option to the general section #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use eyre::{bail, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::version::Version;

use super::v1::{self, Language};

#[derive(Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -44,6 +46,8 @@ pub struct General {
pub name: String,
#[serde(default)]
pub universal: bool,

pub cuda_minver: Option<Version>,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
Expand Down Expand Up @@ -207,6 +211,7 @@ impl General {
Self {
name: general.name,
universal,
cuda_minver: None,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions build2cmake/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use config::{Backend, Build, BuildCompat};
mod fileset;
use fileset::FileSet;

mod version;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Cli {
Expand Down
7 changes: 7 additions & 0 deletions build2cmake/src/templates/cuda/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ endif()

if (NOT HIP_FOUND AND CUDA_FOUND)
set(GPU_LANG "CUDA")

{% if cuda_minver %}
if (CUDA_VERSION VERSION_LESS {{ cuda_minver }})
message(FATAL_ERROR "CUDA version ${CUDA_VERSION} is too old. "
"Minimum required version is {{ cuda_minver }}.")
endif()
{% endif %}
elseif(HIP_FOUND)
set(GPU_LANG "HIP")

Expand Down
11 changes: 9 additions & 2 deletions build2cmake/src/torch/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use minijinja::{context, Environment};

use super::kernel_ops_identifier;
use crate::config::{Backend, Build, Dependencies, Kernel, Torch};
use crate::version::Version;
use crate::FileSet;

static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
Expand Down Expand Up @@ -163,7 +164,7 @@ fn write_cmake(

let cmake_writer = file_set.entry("CMakeLists.txt");

render_preamble(env, name, cmake_writer)?;
render_preamble(env, name, build.general.cuda_minver.as_ref(), cmake_writer)?;

render_deps(env, build, cmake_writer)?;

Expand Down Expand Up @@ -338,12 +339,18 @@ pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Writ
Ok(())
}

pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
pub fn render_preamble(
env: &Environment,
name: &str,
cuda_minver: Option<&Version>,
write: &mut impl Write,
) -> Result<()> {
env.get_template("cuda/preamble.cmake")
.wrap_err("Cannot get CMake prelude template")?
.render_to_write(
context! {
name => name,
cuda_minver => cuda_minver.map(|v| v.to_string()),
cuda_supported_archs => cuda_supported_archs(),

},
Expand Down
68 changes: 68 additions & 0 deletions build2cmake/src/version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::{fmt::Display, str::FromStr};

use eyre::{ensure, Context};
use itertools::Itertools;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Version(Vec<usize>);

impl<'de> Deserialize<'de> for Version {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
FromStr::from_str(&s).map_err(de::Error::custom)
}
}

impl Serialize for Version {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.0.iter().map(|v| v.to_string()).join("."))
}
}

impl Display for Version {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
itertools::join(self.0.iter().map(|v| v.to_string()), ".")
)
}
}

impl From<Vec<usize>> for Version {
fn from(value: Vec<usize>) -> Self {
// Remove trailing zeros for normalization.
let mut normalized = value
.into_iter()
.rev()
.skip_while(|&x| x == 0)
.collect::<Vec<_>>();
normalized.reverse();
Version(normalized)
}
}

impl FromStr for Version {
type Err = eyre::Report;

fn from_str(version: &str) -> Result<Self, Self::Err> {
let version = version.trim().to_owned();
ensure!(!version.is_empty(), "Empty version string");
let mut version_parts = Vec::new();
for part in version.split('.') {
let version_part: usize = part
.parse()
.context(format!("Version must consist of numbers: {}", version))?;
version_parts.push(version_part);
}

Ok(Version::from(version_parts))
}
}
5 changes: 5 additions & 0 deletions docs/writing-kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ depends = [ "torch" ]
Universal kernels do not use the other sections described below.
A good example of a universal kernel is a Triton kernel.
Default: `false`
- `cuda-minver`: the minimum required CUDA toolkit version. This option
_must not_ be set under normal circumstances, since it can exclude Torch
build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md).
This option is provided for kernels that require functionality only
provided by newer CUDA toolkits.

### `torch`

Expand Down
16 changes: 12 additions & 4 deletions lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,20 @@ rec {
buildConfig: buildSets:
let
backends' = backends buildConfig;
requiredCuda = buildConfig.general.cuda-minver or "11.8";
supportedBuildSet =
buildSet:
(buildSet.gpu == "cuda" && backends'.cuda)
|| (buildSet.gpu == "rocm" && backends'.rocm)
|| (buildSet.gpu == "metal" && backends'.metal)
|| (buildConfig.general.universal or false);
let
backendSupported =
(buildSet.gpu == "cuda" && backends'.cuda)
|| (buildSet.gpu == "rocm" && backends'.rocm)
|| (buildSet.gpu == "metal" && backends'.metal)
|| (buildConfig.general.universal or false);
cudaVersionSupported =
buildSet.gpu != "cuda"
|| (lib.strings.versionAtLeast buildSet.pkgs.cudaPackages.cudaMajorMinorVersion requiredCuda);
in
backendSupported && cudaVersionSupported;
in
builtins.filter supportedBuildSet buildSets;

Expand Down