@@ -37,10 +37,9 @@ def unparse_version(t, prefix=""):
37
37
38
38
39
39
def max_version (versions , prefix = "" , check = lambda x : True ):
40
+ versions = [x for x in versions if check (x )]
40
41
versions = [parse_version (x , prefix ) for x in versions ]
41
- versions = [
42
- x for x in versions if x is not None and check (unparse_version (x , prefix ))
43
- ]
42
+ versions = [x for x in versions if x is not None ]
44
43
if not versions :
45
44
return None
46
45
version = unparse_version (max (versions ), prefix )
@@ -272,7 +271,30 @@ def find_cuda_pip():
272
271
return None , [], []
273
272
274
273
275
- def check_cuda (cuda_base_path ):
274
+ def check_cuda_conda (cuda_base_path ):
275
+ return all (
276
+ x .exists ()
277
+ for x in [
278
+ cuda_base_path / "bin" / "ptxas.exe" ,
279
+ cuda_base_path / "include" / "cuda.h" ,
280
+ cuda_base_path / "lib" / "cuda.lib" ,
281
+ ]
282
+ )
283
+
284
+
285
+ def find_cuda_conda ():
286
+ cuda_base_path = Path (sys .exec_prefix ) / "Library"
287
+ if check_cuda_conda (cuda_base_path ):
288
+ return (
289
+ str (cuda_base_path / "bin" ),
290
+ [str (cuda_base_path / "include" )],
291
+ [str (cuda_base_path / "lib" )],
292
+ )
293
+
294
+ return None , [], []
295
+
296
+
297
+ def check_cuda_system_wide (cuda_base_path ):
276
298
return all (
277
299
x .exists ()
278
300
for x in [
@@ -290,7 +312,7 @@ def find_cuda_env():
290
312
continue
291
313
292
314
cuda_base_path = Path (cuda_base_path )
293
- if check_cuda (cuda_base_path ):
315
+ if check_cuda_system_wide (cuda_base_path ):
294
316
return cuda_base_path
295
317
296
318
return None
@@ -306,7 +328,7 @@ def find_cuda_hardcoded():
306
328
paths = sorted (paths )[::- 1 ]
307
329
for path in paths :
308
330
cuda_base_path = Path (path )
309
- if check_cuda (cuda_base_path ):
331
+ if check_cuda_system_wide (cuda_base_path ):
310
332
return cuda_base_path
311
333
312
334
return None
@@ -318,6 +340,10 @@ def find_cuda():
318
340
if cuda_bin_path :
319
341
return cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs
320
342
343
+ cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = find_cuda_conda ()
344
+ if cuda_bin_path :
345
+ return cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs
346
+
321
347
cuda_base_path = find_cuda_env ()
322
348
if cuda_base_path is None :
323
349
cuda_base_path = find_cuda_hardcoded ()
0 commit comments