
Commit d9f60ff

Support noarch build variants (#190)
* Fix issues handling CPU layers

* Support noarch build variants

  This change adds support for noarch build variants. So far we have used
  the universal variant for kernels that do not have any AoT-compiled code.
  However, the universal variant has two important issues:

  1. A kernel without AoT-compiled code might still be backend-specific.
     E.g. NVIDIA CuTe-based kernels are not universal in the sense that
     they don't work on non-NVIDIA GPUs.
  2. We cannot specify dependencies per backend.

  To solve these issues, we introduce noarch variants to replace universal
  kernels. Noarch kernels have variants of the shape `torch-<backend>`
  (e.g. `torch-xpu`), which resolves the issues outlined above. This change
  introduces support for loading noarch kernels. In the future, we will
  start emitting deprecation warnings for universal kernels (to eventually
  remove support).

* Fix build variant regex

* Remove outdated comment
1 parent 8b807fa commit d9f60ff
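
To make the naming scheme concrete, a kernel repository can now ship any mix of the three variant kinds under `build/`. A minimal sketch of such a layout (the AoT variant name is hypothetical; it depends on the local torch build):

    build/torch28-cxx11-cu128-x86_64-linux/   # AoT-compiled, fully specific
    build/torch-cuda/                         # noarch, but backend-specific
    build/torch-universal/                    # legacy universal fallback

Loading prefers the most specific variant present, in the order shown. The noarch names the loader can detect are torch-cuda, torch-rocm, torch-metal, torch-xpu, torch-npu, and torch-cpu (see src/kernels/utils.py below).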

File tree

7 files changed: +82 −44 lines changed

src/kernels/cli.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from .doc import generate_readme_for_kernel
 from .wheel import build_variant_to_wheel

-BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)")
+BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-)")


 def main():
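
As a quick, illustrative check of the broadened pattern (the variant names here are examples only), the regex now accepts any `torch-` prefixed variant rather than just `torch-universal`:

    import re

    BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-)")

    # Matches AoT variants, the legacy universal variant, and the new noarch variants:
    for name in ["torch27-cxx11-cu126-x86_64-linux", "torch-universal", "torch-cuda", "torch-xpu"]:
        assert BUILD_VARIANT_REGEX.match(name)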

src/kernels/layer/kernelize.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
+    supported_devices = {"cpu", "cuda", "mps", "npu", "rocm", "xpu"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"

src/kernels/layer/repos.py

Lines changed: 23 additions & 1 deletion
@@ -24,7 +24,9 @@ class DeviceRepos(ABC):
     @staticmethod
     def create_repo(device: Device) -> "DeviceRepos":
         """Create an appropriate repository set for this device type."""
-        if device.type == "cuda":
+        if device.type == "cpu":
+            return _CPURepos()
+        elif device.type == "cuda":
             return _CUDARepos()
         elif device.type == "rocm":
             return _ROCMRepos()
@@ -51,6 +53,26 @@ def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]):
         ...


+class _CPURepos(DeviceRepos):
+    _repos: Dict[Mode, RepositoryProtocol]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, RepositoryProtocol]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]):
+        if device.type != "cpu":
+            raise ValueError(f"Device type must be 'cpu', got {device.type}")
+
+        self._repos = repos
+
+
 class _XPURepos(DeviceRepos):
     _repos: Dict[Mode, RepositoryProtocol]
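
A minimal sketch of the new dispatch path, assuming the `Device` descriptor used elsewhere in this module (illustrative only, not part of the diff):

    # Hypothetical usage; Device comes from kernels.layer.
    repos = DeviceRepos.create_repo(Device(type="cpu"))
    assert isinstance(repos, _CPURepos)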

src/kernels/utils.py

Lines changed: 43 additions & 30 deletions
@@ -84,11 +84,33 @@ def build_variant() -> str:
     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"


-def universal_build_variant() -> str:
+def build_variant_noarch() -> str:
+    import torch
+
+    if torch.version.cuda is not None:
+        return "torch-cuda"
+    elif torch.version.hip is not None:
+        return "torch-rocm"
+    elif torch.backends.mps.is_available():
+        return "torch-metal"
+    elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
+        return "torch-xpu"
+    elif _get_privateuse_backend_name() == "npu":
+        return "torch-npu"
+    else:
+        return "torch-cpu"
+
+
+def build_variant_universal() -> str:
     # Once we support other frameworks, detection goes here.
     return "torch-universal"


+def build_variants() -> List[str]:
+    """Return compatible build variants in preferred order."""
+    return [build_variant(), build_variant_noarch(), build_variant_universal()]
+
+
 def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
     metadata_path = variant_path / "metadata.json"
     if metadata_path.exists():
@@ -146,13 +168,12 @@ def install_kernel(
         `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
     """
     package_name = package_name_from_repo_id(repo_id)
-    variant = build_variant()
-    universal_variant = universal_build_variant()
+    allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
     user_agent = _get_user_agent(user_agent=user_agent)
     repo_path = Path(
         snapshot_download(
             repo_id,
-            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
+            allow_patterns=allow_patterns,
             cache_dir=CACHE_DIR,
             revision=revision,
             local_files_only=local_files_only,
@@ -173,23 +194,22 @@ def _find_kernel_in_repo_path(
     package_name: str,
     variant_locks: Optional[Dict[str, VariantLock]] = None,
 ) -> Tuple[str, Path]:
-    specific_variant = build_variant()
-    universal_variant = universal_build_variant()
-
-    specific_variant_path = repo_path / "build" / specific_variant
-    universal_variant_path = repo_path / "build" / universal_variant
-
-    if specific_variant_path.exists():
-        variant = specific_variant
-        variant_path = specific_variant_path
-    elif universal_variant_path.exists():
-        variant = universal_variant
-        variant_path = universal_variant_path
-    else:
+    variants = build_variants()
+    variant = None
+    variant_path = None
+    for candidate_variant in variants:
+        variant_path = repo_path / "build" / candidate_variant
+        if variant_path.exists():
+            variant = candidate_variant
+            break
+
+    if variant is None:
         raise FileNotFoundError(
-            f"Kernel at path `{repo_path}` does not have one of build variants: {specific_variant}, {universal_variant}"
+            f"Kernel at path `{repo_path}` does not have one of build variants: {', '.join(variants)}"
         )

+    assert variant_path is not None
+
     if variant_locks is not None:
         variant_lock = variant_locks.get(variant)
         if variant_lock is None:
@@ -295,13 +315,9 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
     Returns:
         `ModuleType`: The imported kernel module.
     """
-    variant = build_variant()
-    universal_variant = universal_build_variant()
-
     # Presume we were given the top level path of the kernel repository.
     for base_path in [repo_path, repo_path / "build"]:
-        # Prefer the universal variant if it exists.
-        for v in [universal_variant, variant]:
+        for v in build_variants():
             variant_path = base_path / v
             if variant_path.exists():
                 return _import_from_path(package_name, variant_path)
@@ -337,9 +353,8 @@ def has_kernel(

     package_name = package_name_from_repo_id(repo_id)
     variant = build_variant()
-    universal_variant = universal_build_variant()

-    for variant in [universal_variant, variant]:
+    for variant in build_variants():
         for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
             if file_exists(
                 repo_id,
@@ -379,13 +394,11 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:

     package_name = package_name_from_repo_id(repo_id)

-    variant = build_variant()
-    universal_variant = universal_build_variant()
-
+    allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
     repo_path = Path(
         snapshot_download(
             repo_id,
-            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
+            allow_patterns=allow_patterns,
             cache_dir=CACHE_DIR,
             revision=locked_sha,
             local_files_only=True,
@@ -399,7 +412,7 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
         return _import_from_path(package_name, variant_path)
     except FileNotFoundError:
         raise FileNotFoundError(
-            f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
+            f"Locked kernel `{repo_id}` does not have an applicable variant or was not downloaded with `kernels download <project>`"
         )

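
To see how the preference order is consumed, here is a small sketch of the download filter that `install_kernel`/`load_kernel` now build (the AoT variant name is hypothetical; it depends on the local torch build):

    # Assumed build_variants() output on a CUDA build of PyTorch 2.8, x86_64 Linux:
    variants = ["torch28-cxx11-cu128-x86_64-linux", "torch-cuda", "torch-universal"]

    # snapshot_download only fetches the matching build directories:
    allow_patterns = [f"build/{v}/*" for v in variants]
    print(allow_patterns)
    # ['build/torch28-cxx11-cu128-x86_64-linux/*', 'build/torch-cuda/*', 'build/torch-universal/*']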

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def device():
     elif _get_privateuse_backend_name() == "npu":
         return "npu"

-    pytest.skip("No CUDA, NPU or XPU")
+    return "cpu"


 def pytest_runtest_setup(item):

tests/test_basic.py

Lines changed: 7 additions & 0 deletions
@@ -163,6 +163,13 @@ def test_universal_kernel(universal_kernel):
     torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)


+def test_noarch_kernel(device):
+    supported_devices = ["cpu", "cuda", "xpu"]
+    if device not in supported_devices:
+        pytest.skip(f"Device is not one of: {','.join(supported_devices)}")
+    get_kernel("kernels-test/silu-and-mul-noarch")
+
+
 @pytest.mark.parametrize(
     "repo_revision",
     [

tests/test_layer.py

Lines changed: 6 additions & 10 deletions
@@ -23,7 +23,6 @@
     _validate_layer,
 )
 from kernels.utils import (
-    _get_privateuse_backend_name,
     install_kernel,
 )

@@ -250,16 +249,13 @@ def test_hub_forward_npu():
     assert silu_and_mul_with_kernel.n_calls == 0


-@pytest.mark.skipif(
-    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
-    reason="Skip on xpu devices",
-)
-@pytest.mark.skipif(
-    _get_privateuse_backend_name() == "npu",
-    reason="Skip on npu devices",
-)
-def test_rocm_kernel_mapping():
+def test_rocm_kernel_mapping(device):
     """Test that ROCm shorthand device mapping works correctly."""
+
+    # Lookup uses the GPU capability, so it fails for non-ROCm/CUDA.
+    if device not in ["cuda", "rocm"]:
+        pytest.skip("Test only applicable to CUDA and ROCM devices")
+
     kernel_layer_mapping = {
         "SiluAndMul": {
             "rocm": LayerRepository(
