Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ AbstractFFTs = "0.5, 1.0"
Adapt = "4.4"
BFloat16s = "0.5, 0.6"
CEnum = "0.2, 0.3, 0.4, 0.5"
CUDA_Compiler_jll = "0.2"
CUDA_Compiler_jll = "0.3"
CUDA_Driver_jll = "13"
CUDA_Runtime_Discovery = "1"
CUDA_Runtime_jll = "0.19"
Expand All @@ -65,7 +65,7 @@ DataFrames = "1.5"
EnzymeCore = "0.8.2"
ExprTools = "0.1"
GPUArrays = "11.2.4"
GPUCompiler = "1.1"
GPUCompiler = "1.4"
GPUToolbox = "0.3, 1"
KernelAbstractions = "0.9.38"
LLVM = "9.3.1"
Expand Down
22 changes: 13 additions & 9 deletions src/compatibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,19 @@ function llvm_compat(version=LLVM.version())
return (cap=cap_support, ptx=ptx_support)
end

function cuda_compat(driver=driver_version(), compiler=compiler_version())
# devices have to be supported by both the compiler and the driver
driver_cap_support = cuda_cap_support(driver)
compiler_cap_support = cuda_cap_support(compiler)
cap_support = sort(collect(driver_cap_support ∩ compiler_cap_support))

# PTX code only has to be supported by the compiler
compiler_ptx_support = cuda_ptx_support(compiler)
ptx_support = cuda_ptx_support(compiler)
function cuda_compat(version=runtime_version())
# we don't have to check the driver version, because it offers backwards compatbility
# beyond the CUDA toolkit version (e.g. R580 for CUDA 13 still supports Volta as
# deprecated in CUDA 13), and we don't have a reliable way to query the actual version
# as NVML isn't available on all platforms. let's instead simply assume that unsupported
# devices will not be exposed to the CUDA runtime and thus won't be visible to us.

# we also don't have to check the compiler version, because CUDA_Compiler_jll is
# guaranteed to have the same major version as CUDA_Runtime_jll, meaning that the
# compiler will always support at least the same devices as the runtime.

cap_support = sort(collect(cuda_cap_support(version)))
ptx_support = sort(collect(cuda_ptx_support(version)))

return (cap=cap_support, ptx=ptx_support)
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ end
error("CUDA.jl requires PTX $requested_ptx, which is not supported by LLVM $(LLVM.version())")
llvm_ptx = maximum(llvm_ptxs)
isempty(cuda_ptxs) &&
error("CUDA.jl requires PTX $requested_ptx, which is not supported by CUDA driver $(driver_version()) / runtime $(runtime_version())")
error("CUDA.jl requires PTX $requested_ptx, which is not supported by CUDA $(runtime_version())")
cuda_ptx = maximum(cuda_ptxs)
end

Expand All @@ -229,7 +229,7 @@ end
## use the highest capability supported by CUDA
cuda_caps = filter(<=(capability(dev)), cuda_support.cap)
isempty(cuda_caps) &&
error("Compute capability $(requested_cap) is not supported by CUDA driver $(driver_version()) / runtime $(runtime_version())")
error("Compute capability $(requested_cap) is not supported by CUDA $(runtime_version())")
cuda_cap = maximum(cuda_caps)
end

Expand Down
10 changes: 3 additions & 7 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ function __init__()
return
end

if !(v"12" <= driver < v"14-")
@error "This version of CUDA.jl only supports NVIDIA drivers for CUDA 12.x or 13.x (yours is for CUDA $driver)"
_initialization_error[] = "CUDA driver unsupported"
if driver < v"12"
@error "This version of CUDA.jl requires an NVIDIA driver for CUDA 12.x or higher (yours only supports up to CUDA $driver)"
_initialization_error[] = "NVIDIA driver too old"
return
end

Expand Down Expand Up @@ -133,10 +133,6 @@ function __init__()
if runtime < v"12"
@error "This version of CUDA.jl only supports CUDA 12 or higher (your toolkit provides CUDA $runtime)"
end
if runtime.major != driver.major
@warn """You are using CUDA $runtime with a driver for CUDA $(driver.major).x.
It is recommended to upgrade your driver, or switch to automatic installation of CUDA."""
end

# ensure the loaded runtime matches what we precompiled for.
if toolkit_version == nothing
Expand Down