# (target_arch, resolved_arch) pairs that have already been reported, so a
# caller that resolves the same arch repeatedly prints the warning only once.
_PTX_ARCH_FALLBACKS_REPORTED = set()


def _resolve_supported_ptx_arch(target_arch: int, supported_archs: set[int]) -> int:
    """Return ``target_arch`` if NVRTC supports it, otherwise the closest supported arch.

    Preference is given to the lowest supported architecture that is greater
    than ``target_arch``. If no such architecture exists, the highest
    supported value is returned instead.

    When a substitution is made, a warning is printed to stdout the first
    time each ``(target_arch, resolved)`` pair is seen; later calls with the
    same outcome are silent.

    Args:
        target_arch: Requested PTX architecture (e.g. ``75`` for ``sm_75``).
        supported_archs: Architectures NVRTC can target. Must be non-empty.

    Returns:
        ``target_arch`` when it is in ``supported_archs``, otherwise the
        substitute chosen as described above.

    Raises:
        ValueError: If ``supported_archs`` is empty.

    NOTE(review): PTX is normally loadable only on devices whose compute
    capability is >= the PTX target, so clamping *up* past the device arch
    may yield PTX the driver cannot load -- confirm this is the intended
    trade-off for callers that pass a device-limited ``target_arch``.
    """
    if not supported_archs:
        raise ValueError("supported_archs must be non-empty")
    if target_arch in supported_archs:
        return target_arch

    # Equality was handled above, so only strictly higher archs are candidates.
    resolved = min((arch for arch in supported_archs if arch > target_arch), default=None)
    if resolved is None:
        # Nothing above the target: fall back to the newest arch NVRTC offers.
        resolved = max(supported_archs)

    fallback = (target_arch, resolved)
    if fallback not in _PTX_ARCH_FALLBACKS_REPORTED:
        _PTX_ARCH_FALLBACKS_REPORTED.add(fallback)
        print(f"Warning: PTX target arch sm_{target_arch} is not supported by NVRTC; using sm_{resolved} instead")
    return resolved
def test_resolve_supported_ptx_arch_exact_match(self):
    """A target already present in the supported set comes back unchanged."""
    supported = {70, 75, 80, 86}
    for arch in (75, 86):
        self.assertEqual(wp._src.context._resolve_supported_ptx_arch(arch, supported), arch)

def test_resolve_supported_ptx_arch_clamp_up(self):
    """An unsupported target resolves to the lowest supported arch above it."""
    cases = (
        (75, {70, 80, 86}, 80),  # 75 missing, next above is 80
        (72, {60, 75, 80}, 75),  # 72 missing, next above is 75
    )
    for target, supported, expected in cases:
        self.assertEqual(wp._src.context._resolve_supported_ptx_arch(target, supported), expected)

def test_resolve_supported_ptx_arch_fallback_to_max(self):
    """With every supported arch below the target, the highest one is returned."""
    picked = wp._src.context._resolve_supported_ptx_arch(90, {70, 75, 80})
    self.assertEqual(picked, 80)

def test_resolve_supported_ptx_arch_single_element(self):
    """A one-element supported set always resolves to that element."""
    only_choice = {80}
    for target in (75, 90, 80):
        self.assertEqual(wp._src.context._resolve_supported_ptx_arch(target, only_choice), 80)

def test_resolve_supported_ptx_arch_empty_raises(self):
    """An empty supported set is rejected with ValueError."""
    self.assertRaises(ValueError, wp._src.context._resolve_supported_ptx_arch, 75, set())