@@ -7,7 +7,7 @@
 
 from cuda import cuda, cudart
 from cuda.core.experimental._context import Context, ContextOptions
-from cuda.core.experimental._memory import Buffer, MemoryResource, _DefaultAsyncMempool
+from cuda.core.experimental._memory import Buffer, MemoryResource, _AsyncMemoryResource, _DefaultAsyncMempool
 from cuda.core.experimental._stream import Stream, StreamOptions, default_stream
 from cuda.core.experimental._utils import ComputeCapability, CUDAError, handle_return, precondition
 
@@ -62,15 +62,21 @@ def __new__(cls, device_id=None):
             for dev_id in range(total):
                 dev = super().__new__(cls)
                 dev._id = dev_id
-                dev._mr = _DefaultAsyncMempool(dev_id)
+                # If the device is in TCC mode, or does not support memory pools for some other reason,
+                # use the _AsyncMemoryResource, which does not use memory pools.
+                if (handle_return(cudart.cudaGetDeviceProperties(dev_id))).memoryPoolsSupported == 0:
+                    dev._mr = _AsyncMemoryResource(dev_id)
+                else:
+                    dev._mr = _DefaultAsyncMempool(dev_id)
+
                 dev._has_inited = False
                 _tls.devices.append(dev)
 
         return _tls.devices[device_id]
 
     def _check_context_initialized(self, *args, **kwargs):
         if not self._has_inited:
-            raise CUDAError("the device is not yet initialized, perhaps you forgot to call .set_current() first?")
+            raise CUDAError("the device is not yet initialized, " "perhaps you forgot to call .set_current() first?")
 
     @property
     def device_id(self) -> int: