@@ -84,11 +84,33 @@ def build_variant() -> str:
8484 return f"torch{ torch_version .major } { torch_version .minor } -{ cxxabi } -{ compute_framework } -{ cpu } -{ os } "
8585
8686
87- def universal_build_variant () -> str :
87+ def build_variant_noarch () -> str :
88+ import torch
89+
90+ if torch .version .cuda is not None :
91+ return "torch-cuda"
92+ elif torch .version .hip is not None :
93+ return "torch-rocm"
94+ elif torch .backends .mps .is_available ():
95+ return "torch-metal"
96+ elif hasattr (torch .version , "xpu" ) and torch .version .xpu is not None :
97+ return "torch-xpu"
98+ elif _get_privateuse_backend_name () == "npu" :
99+ return "torch-npu"
100+ else :
101+ return "torch-cpu"
102+
103+
104+ def build_variant_universal () -> str :
88105 # Once we support other frameworks, detection goes here.
89106 return "torch-universal"
90107
91108
109+ def build_variants () -> List [str ]:
110+ """Return compatible build variants in preferred order."""
111+ return [build_variant (), build_variant_noarch (), build_variant_universal ()]
112+
113+
92114def _import_from_path (module_name : str , variant_path : Path ) -> ModuleType :
93115 metadata_path = variant_path / "metadata.json"
94116 if metadata_path .exists ():
@@ -146,13 +168,12 @@ def install_kernel(
146168 `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
147169 """
148170 package_name = package_name_from_repo_id (repo_id )
149- variant = build_variant ()
150- universal_variant = universal_build_variant ()
171+ allow_patterns = [f"build/{ variant } /*" for variant in build_variants ()]
151172 user_agent = _get_user_agent (user_agent = user_agent )
152173 repo_path = Path (
153174 snapshot_download (
154175 repo_id ,
155- allow_patterns = [ f"build/ { variant } /*" , f"build/ { universal_variant } /*" ] ,
176+ allow_patterns = allow_patterns ,
156177 cache_dir = CACHE_DIR ,
157178 revision = revision ,
158179 local_files_only = local_files_only ,
@@ -173,23 +194,22 @@ def _find_kernel_in_repo_path(
173194 package_name : str ,
174195 variant_locks : Optional [Dict [str , VariantLock ]] = None ,
175196) -> Tuple [str , Path ]:
176- specific_variant = build_variant ()
177- universal_variant = universal_build_variant ()
178-
179- specific_variant_path = repo_path / "build" / specific_variant
180- universal_variant_path = repo_path / "build" / universal_variant
181-
182- if specific_variant_path .exists ():
183- variant = specific_variant
184- variant_path = specific_variant_path
185- elif universal_variant_path .exists ():
186- variant = universal_variant
187- variant_path = universal_variant_path
188- else :
197+ variants = build_variants ()
198+ variant = None
199+ variant_path = None
200+ for candidate_variant in variants :
201+ variant_path = repo_path / "build" / candidate_variant
202+ if variant_path .exists ():
203+ variant = candidate_variant
204+ break
205+
206+ if variant is None :
189207 raise FileNotFoundError (
190- f"Kernel at path `{ repo_path } ` does not have one of build variants: { specific_variant } , { universal_variant } "
208+ f"Kernel at path `{ repo_path } ` does not have one of build variants: { ', ' . join ( variants ) } "
191209 )
192210
211+ assert variant_path is not None
212+
193213 if variant_locks is not None :
194214 variant_lock = variant_locks .get (variant )
195215 if variant_lock is None :
@@ -295,13 +315,9 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
295315 Returns:
296316 `ModuleType`: The imported kernel module.
297317 """
298- variant = build_variant ()
299- universal_variant = universal_build_variant ()
300-
301318 # Presume we were given the top level path of the kernel repository.
302319 for base_path in [repo_path , repo_path / "build" ]:
303- # Prefer the universal variant if it exists.
304- for v in [universal_variant , variant ]:
320+ for v in build_variants ():
305321 variant_path = base_path / v
306322 if variant_path .exists ():
307323 return _import_from_path (package_name , variant_path )
@@ -337,9 +353,8 @@ def has_kernel(
337353
338354 package_name = package_name_from_repo_id (repo_id )
339355 variant = build_variant ()
340- universal_variant = universal_build_variant ()
341356
342- for variant in [ universal_variant , variant ] :
357+ for variant in build_variants () :
343358 for init_file in ["__init__.py" , f"{ package_name } /__init__.py" ]:
344359 if file_exists (
345360 repo_id ,
@@ -379,13 +394,11 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
379394
380395 package_name = package_name_from_repo_id (repo_id )
381396
382- variant = build_variant ()
383- universal_variant = universal_build_variant ()
384-
397+ allow_patterns = [f"build/{ variant } /*" for variant in build_variants ()]
385398 repo_path = Path (
386399 snapshot_download (
387400 repo_id ,
388- allow_patterns = [ f"build/ { variant } /*" , f"build/ { universal_variant } /*" ] ,
401+ allow_patterns = allow_patterns ,
389402 cache_dir = CACHE_DIR ,
390403 revision = locked_sha ,
391404 local_files_only = True ,
@@ -399,7 +412,7 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
399412 return _import_from_path (package_name , variant_path )
400413 except FileNotFoundError :
401414 raise FileNotFoundError (
402- f"Locked kernel `{ repo_id } ` does not have build ` { variant } ` or was not downloaded with `kernels download <project>`"
415+ f"Locked kernel `{ repo_id } ` does not have applicable variant or was not downloaded with `kernels download <project>`"
403416 )
404417
405418
0 commit comments