Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion warp/_src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3870,6 +3870,13 @@ def get_cuda_compile_arch(self) -> int | None:
output_arch = min(self.arch, warp.config.ptx_target_arch)
else:
output_arch = min(self.arch, runtime.default_ptx_arch)

# Ensure the chosen PTX arch is actually supported by NVRTC.
# Note: _resolve_supported_ptx_arch may clamp *up* to a higher arch
# than self.arch. This is intentional — PTX is forward-compatible,
# so the CUDA driver's JIT will translate it down at load time.
if runtime.nvrtc_supported_archs and output_arch not in runtime.nvrtc_supported_archs:
output_arch = _resolve_supported_ptx_arch(output_arch, runtime.nvrtc_supported_archs)
else:
output_arch = self.arch

Expand Down Expand Up @@ -3947,6 +3954,26 @@ def _validate_cuda_arch_suffix(
return suffix


def _resolve_supported_ptx_arch(target_arch: int, supported_archs: set[int]) -> int:
"""Return ``target_arch`` if NVRTC supports it, otherwise the closest supported arch.

Preference is given to the lowest supported architecture that is >=
``target_arch`` so that the resulting PTX is as broadly forward-compatible
as possible. If no such architecture exists the highest supported value
is returned instead.

``supported_archs`` must be non-empty; a ``ValueError`` is raised otherwise.
"""
if not supported_archs:
raise ValueError("supported_archs must be non-empty")
if target_arch in supported_archs:
return target_arch
above = sorted(a for a in supported_archs if a >= target_arch)
resolved = above[0] if above else max(supported_archs)
print(f"Warning: PTX target arch sm_{target_arch} is not supported by NVRTC; using sm_{resolved} instead")
return resolved


""" Meta-type for arguments that can be resolved to a concrete Device.
"""
DeviceLike = Union[Device, str, None]
Expand Down Expand Up @@ -5199,11 +5226,16 @@ def __init__(self):
)
except ValueError:
pass # no eligible NVRTC-supported arch ≥ default, retain existing

# Validate that the chosen PTX arch is actually supported by NVRTC
if self.nvrtc_supported_archs and self.default_ptx_arch not in self.nvrtc_supported_archs:
self.default_ptx_arch = _resolve_supported_ptx_arch(self.default_ptx_arch, self.nvrtc_supported_archs)
else:
self.set_default_device("cpu")
if self.nvrtc_supported_archs:
# NVRTC available but no devices/driver — enable offline compilation
self.default_ptx_arch = warp.config.ptx_target_arch if warp.config.ptx_target_arch is not None else 75
target = warp.config.ptx_target_arch if warp.config.ptx_target_arch is not None else 75
self.default_ptx_arch = _resolve_supported_ptx_arch(target, self.nvrtc_supported_archs)
else:
self.default_ptx_arch = None

Expand Down
32 changes: 32 additions & 0 deletions warp/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,38 @@ def test_context_type_str(self):
self.assertEqual(wp._src.context.type_str(tuple[int, float]), "tuple[int, float]")
self.assertEqual(wp._src.context.type_str(tuple[int, ...]), "tuple[int, ...]")

def test_resolve_supported_ptx_arch_exact_match(self):
"""Target arch is in the supported set — returned as-is."""
resolve = wp._src.context._resolve_supported_ptx_arch
self.assertEqual(resolve(75, {70, 75, 80, 86}), 75)
self.assertEqual(resolve(86, {70, 75, 80, 86}), 86)

def test_resolve_supported_ptx_arch_clamp_up(self):
"""Target arch missing — lowest supported arch >= target is chosen."""
resolve = wp._src.context._resolve_supported_ptx_arch
# 75 not in set, next above is 80
self.assertEqual(resolve(75, {70, 80, 86}), 80)
# 72 not in set, next above is 75
self.assertEqual(resolve(72, {60, 75, 80}), 75)

def test_resolve_supported_ptx_arch_fallback_to_max(self):
"""All supported archs are below the target — highest supported is returned."""
resolve = wp._src.context._resolve_supported_ptx_arch
self.assertEqual(resolve(90, {70, 75, 80}), 80)

def test_resolve_supported_ptx_arch_single_element(self):
"""Only one supported arch available."""
resolve = wp._src.context._resolve_supported_ptx_arch
self.assertEqual(resolve(75, {80}), 80)
self.assertEqual(resolve(90, {80}), 80)
self.assertEqual(resolve(80, {80}), 80)

def test_resolve_supported_ptx_arch_empty_raises(self):
"""Empty supported set must raise an explicit error."""
resolve = wp._src.context._resolve_supported_ptx_arch
with self.assertRaises(ValueError):
resolve(75, set())


if __name__ == "__main__":
unittest.main(verbosity=2)