-
Moreover, I use environment variables to specify my CUDA version, and CUDA is also installed on the shared drive, so both machines use the same environment variables. The JAX version is 0.5.0. What could be causing this?
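Because both machines share the same virtual environment and CUDA install, one way to look for the difference is to print what JAX itself sees on each machine and diff the two outputs. The sketch below is only an assumption about what is worth comparing: `describe_environment` is a hypothetical helper, and the list of environment variables is a guess that should be replaced with whichever variables you actually set. `jax.print_environment_info()` is a real JAX utility that reports the JAX/jaxlib versions, the visible devices and, when available, `nvidia-smi` output.

```python
import os

import jax


def describe_environment():
    """Hypothetical helper: collect facts likely to differ between machine A and B."""
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
        # Assumed set of variables; swap in the ones you use to select CUDA.
        "env": {
            name: os.environ.get(name)
            for name in ("CUDA_HOME", "CUDA_VISIBLE_DEVICES", "LD_LIBRARY_PATH", "XLA_FLAGS")
        },
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
    # Reports jax/jaxlib versions, devices and driver details as JAX sees them.
    jax.print_environment_info()
```

Running this on A and on B and comparing the outputs (driver version, device list, library paths) may point at a machine-specific difference that the shared drive does not cover.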
-
Hi @Sumching, I could not reproduce the crash you are seeing with 2 GPUs on the nightly JAX version. Could you try the nightly or the latest release (0.6.0) and let me know whether the crash persists?
-
I have two machines, A and B, both equipped with 4090 GPUs, and they share a common hard drive. My uv virtual environment and code are stored on this shared drive. The code is as follows:
Machine A runs fine with a single GPU, but hits a segmentation fault on the line

x0 = jax.tree_map(lambda x: x[0], w_batched)

when using more than 2 GPUs. Machine B, however, works well with both single and multiple GPUs.
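The original code was not included in the post, so the following is only a hypothetical minimal reproduction, assuming `w_batched` is a pytree of arrays that share a leading batch axis sharded across all visible GPUs. `make_batched_params` and `shard_over_devices` are illustrative names, not part of the original code. The crashing line is written with `jax.tree.map`, the newer spelling of the deprecated `jax.tree_map`; it takes element 0 of every leaf.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def make_batched_params(batch_size):
    # Hypothetical stand-in for w_batched: every leaf has a leading batch axis.
    return {
        "w": jnp.arange(batch_size * 3, dtype=jnp.float32).reshape(batch_size, 3),
        "b": jnp.zeros((batch_size,), dtype=jnp.float32),
    }


def shard_over_devices(tree):
    # Split the leading batch axis across every visible device.
    mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
    sharding = NamedSharding(mesh, P("batch"))
    # A single sharding is applied to every leaf of the pytree.
    return jax.device_put(tree, sharding)


n_devices = len(jax.devices())
# Batch size chosen as a multiple of the device count so the split is even.
w_batched = shard_over_devices(make_batched_params(2 * n_devices))

# Same operation as the line that segfaults on machine A.
x0 = jax.tree.map(lambda x: x[0], w_batched)
print(jax.tree.map(lambda x: x.shape, x0))
```

If a snippet like this also crashes on machine A, running it with `CUDA_VISIBLE_DEVICES` restricted to different GPU subsets (for example `0`, then `0,1`) may help isolate whether one particular card or the multi-GPU path is involved.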