Make GPU CUDA plugin require JAX #8919

Draft: wants to merge 4 commits into master

tengyifei (Collaborator) commented Apr 1, 2025

This fixes internal b/419277657.

Some PyTorch/XLA GPU features require JAX. The CI tests install the latest version of JAX ad hoc, which creates a skew with the JAX version pinned by PyTorch/XLA and causes test failures.
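
To make the skew concrete, here is a minimal sketch of how one could detect it in a CI environment. The helper names and the pinned versions in the usage note are hypothetical and are not part of this PR; only `importlib.metadata` from the standard library is used.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_version_skew(
    pinned: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin
    to a (pinned, installed) pair. An empty result means no skew."""
    skew = {}
    for package, wanted in pinned.items():
        found = lookup(package)
        if found != wanted:
            skew[package] = (wanted, found)
    return skew
```

Calling `find_version_skew({"jax": "0.6.0", "jaxlib": "0.6.0"})` with illustrative pins would return an empty dict when the environment matches them, and the mismatching packages otherwise.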

Rather than only installing the latest version of JAX in CI, we'll just make the CUDA plugin depend on a version of JAX that's the same as what's used by PyTorch/XLA on TPU.

As a side note, there appear to be two ways of building PyTorch/XLA for GPU. One is to set XLA_CUDA=1, which builds the PyTorch/XLA C++ .so with CUDA support. The other is to build a "PyTorch/XLA CUDA plugin" similar to the libtpu plugin, which factors CUDA-specific functionality behind a backend .so. That shows up as the "Build PyTorch/XLA CUDA Plugin" job in CI. It appears that our cloud build jobs build the CUDA plugin but do not upload it to GCS (discussed earlier in #8876).

In any case, our GPU CI infra uses the CUDA plugin path and doesn't bake CUDA support into the PyTorch/XLA native .so. Therefore, to fix GPU CI, I made the CUDA plugin depend on the right versions of JAX.
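
As an illustration of the idea only (not the actual diff), the dependency could be expressed as exact-version requirement strings in the plugin's package metadata. The version numbers and the function name below are placeholders; the real pins are whatever PyTorch/XLA uses for TPU.

```python
from typing import List

# Hypothetical pins for illustration; the real values are whatever
# PyTorch/XLA pins for its TPU build.
PINNED_JAX = "0.6.0"
PINNED_JAXLIB = "0.6.0"


def cuda_plugin_install_requires(
    jax: str = PINNED_JAX, jaxlib: str = PINNED_JAXLIB
) -> List[str]:
    """Build exact-version requirement strings for the plugin package.

    JAX's own CUDA wheels are deliberately left out of the list.
    """
    return [f"jax=={jax}", f"jaxlib=={jaxlib}"]
```

A list like this could be passed as `install_requires` to `setuptools.setup(...)`, assuming the plugin is packaged with setuptools; installing the plugin would then pull in the pinned JAX instead of whatever is latest.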

Some XLA GPU features require JAX. Rather than only installing the
latest version of JAX in CI, we'll just make the CUDA plugin depend on a
version of JAX that's the same as what's used by PyTorch/XLA on TPU.
(Except the JAX CUDA wheels).
tengyifei force-pushed the yifeit/cuda-plugin branch from 1b6d9ac to 47d6ea1 on May 23, 2025 08:01
tengyifei force-pushed the yifeit/cuda-plugin branch from 4f623f5 to 2e67486 on May 23, 2025 18:43

tengyifei (Collaborator, Author) commented

This is not easily fixable because JAX 0.6.1 started requiring cuDNN 9.8 (see https://github.com/jax-ml/jax/blob/main/CHANGELOG.md?plain=1#L61), but cuDNN 9.8 requires Debian 12 (#8928).
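
The constraint can be written down as a small check, sketched here under the assumption that every JAX release from 0.6.1 onward keeps the cuDNN 9.8 requirement. The function names are hypothetical, and the parser only handles plain numeric versions such as `0.6.1`.

```python
from typing import Tuple

# First JAX release that requires cuDNN 9.8, per the comment above.
FIRST_JAX_NEEDING_CUDNN_9_8 = (0, 6, 1)


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse a plain dotted numeric version like '0.6.1' into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def jax_needs_cudnn_9_8(jax_version: str) -> bool:
    """Return True if the given JAX version is 0.6.1 or newer.

    Assumes later releases do not relax the cuDNN requirement.
    """
    return parse_version(jax_version) >= FIRST_JAX_NEEDING_CUDNN_9_8
```

Under that assumption, any pin for which this returns True cannot be used on a CI image older than Debian 12.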
