Skip to content

Commit 2d4b2db

Browse files
committed
Find CUDA in conda cuda-toolkit
1 parent 37e903e commit 2d4b2db

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

python/triton/windows_utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ def unparse_version(t, prefix=""):
3737

3838

3939
def max_version(versions, prefix="", check=lambda x: True):
40+
versions = [x for x in versions if check(x)]
4041
versions = [parse_version(x, prefix) for x in versions]
41-
versions = [
42-
x for x in versions if x is not None and check(unparse_version(x, prefix))
43-
]
42+
versions = [x for x in versions if x is not None]
4443
if not versions:
4544
return None
4645
version = unparse_version(max(versions), prefix)
@@ -272,7 +271,30 @@ def find_cuda_pip():
272271
return None, [], []
273272

274273

275-
def check_cuda(cuda_base_path):
274+
def check_cuda_conda(cuda_base_path):
275+
return all(
276+
x.exists()
277+
for x in [
278+
cuda_base_path / "bin" / "ptxas.exe",
279+
cuda_base_path / "include" / "cuda.h",
280+
cuda_base_path / "lib" / "cuda.lib",
281+
]
282+
)
283+
284+
285+
def find_cuda_conda():
286+
cuda_base_path = Path(sys.exec_prefix) / "Library"
287+
if check_cuda_conda(cuda_base_path):
288+
return (
289+
str(cuda_base_path / "bin"),
290+
[str(cuda_base_path / "include")],
291+
[str(cuda_base_path / "lib")],
292+
)
293+
294+
return None, [], []
295+
296+
297+
def check_cuda_system_wide(cuda_base_path):
276298
return all(
277299
x.exists()
278300
for x in [
@@ -290,7 +312,7 @@ def find_cuda_env():
290312
continue
291313

292314
cuda_base_path = Path(cuda_base_path)
293-
if check_cuda(cuda_base_path):
315+
if check_cuda_system_wide(cuda_base_path):
294316
return cuda_base_path
295317

296318
return None
@@ -306,7 +328,7 @@ def find_cuda_hardcoded():
306328
paths = sorted(paths)[::-1]
307329
for path in paths:
308330
cuda_base_path = Path(path)
309-
if check_cuda(cuda_base_path):
331+
if check_cuda_system_wide(cuda_base_path):
310332
return cuda_base_path
311333

312334
return None
@@ -318,6 +340,10 @@ def find_cuda():
318340
if cuda_bin_path:
319341
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
320342

343+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = find_cuda_conda()
344+
if cuda_bin_path:
345+
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
346+
321347
cuda_base_path = find_cuda_env()
322348
if cuda_base_path is None:
323349
cuda_base_path = find_cuda_hardcoded()

0 commit comments

Comments
 (0)