Is there a way to force jax into CPU-only mode? #28587
Jacob-Stevens-Haas asked this question in General
-
By default, JAX places new arrays on … To make JAX ignore GPUs entirely, you could try setting the environment variable `JAX_PLATFORMS` to `cpu` before importing JAX:

```python
import os

os.environ['JAX_PLATFORMS'] = 'cpu'  # must be set before `import jax`
print(os.environ.get('JAX_PLATFORMS'))  # cpu

import jax

print(jax.devices())  # [CpuDevice(id=0)]
```

Regarding your question about a GPU being truly out of memory: if JAX (without …
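If you would rather not edit the script itself, a minimal sketch (assuming the JAX workload is launched as a separate Python process) is to put the same variable into the child's environment. The helper name `cpu_only_env` is hypothetical. It also sets `CUDA_VISIBLE_DEVICES` to an empty string; that is a CUDA runtime variable, not a JAX option, and it hides every NVIDIA GPU from the process. Setting it to a list of indices such as `"0"` instead exposes only the listed GPUs, which is one way to keep a particular device out of sight. The child command below just echoes the variable so the sketch runs without JAX installed.

```python
import os
import subprocess
import sys


def cpu_only_env(base=None):
    """Return a copy of an environment mapping with GPUs hidden from JAX.

    JAX_PLATFORMS limits which backends JAX initializes; an empty
    CUDA_VISIBLE_DEVICES additionally hides all NVIDIA GPUs from CUDA.
    """
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cpu"
    env["CUDA_VISIBLE_DEVICES"] = ""
    return env


# Launch a child interpreter with the CPU-only environment. In practice the
# command would be your own script; here it only prints the variable back.
result = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['JAX_PLATFORMS'])"],
    env=cpu_only_env(),
    capture_output=True,
    text=True,
    check=True,
)
print(result.stdout.strip())  # cpu
```

Because the variables are set before the child process starts, there is no risk of some earlier `import jax` in the same interpreter having already initialized a GPU backend.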
-
At the top of my file I have
Yet somehow my computation stalls with a warning that the GPU is out of memory:
I verified that `jax.config.__getattribute__("jax_default_device")` is indeed CPU. The problem turned out to be a context manager that temporarily set it to the GPU. I'm wondering whether there's a way to specifically prohibit JAX from seeing a particular device. I'm also wondering what would happen if the GPU were truly out of memory, so that JAX couldn't even allocate the 1% of the GPU it uses even when the default device is CPU.
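A sketch of why `JAX_PLATFORMS=cpu` also guards against the context-manager problem described above: `jax_default_device` is only a default, and a `jax.default_device(...)` block overrides it for its body. Once JAX initializes only the CPU backend, though, there is no GPU device object for such a block to switch to, and asking for one raises an error. The helper names `visible_platforms` and `gpu_is_reachable` are hypothetical, and the sketch assumes a JAX version that provides `jax.default_device` and `Array.devices()`.

```python
import os

# Must run before the first `import jax` anywhere in the process.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp


def visible_platforms():
    """Return the set of platform names JAX can currently see."""
    return {d.platform for d in jax.devices()}


def gpu_is_reachable():
    """Return True if code could still obtain a GPU device handle."""
    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # The GPU backend was never initialized, so requesting it raises
        # instead of returning a device.
        return False


cpu = jax.devices("cpu")[0]
jax.config.update("jax_default_device", cpu)

# A context manager overrides jax_default_device only inside its body;
# with JAX_PLATFORMS=cpu the only device it can be given is a CPU one.
with jax.default_device(cpu):
    x = jnp.arange(4)

print(visible_platforms())  # {'cpu'}
print(gpu_is_reachable())   # False
```

Checking `gpu_is_reachable()` at startup is a cheap way to fail loudly if some library code later tries to move work onto a GPU.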