Make GPU CUDA plugin require JAX #8919
This fixes internal b/419277657.
Some PyTorch/XLA GPU features require JAX. The CI tests install the latest version of JAX ad hoc, creating a skew with the JAX version pinned by PyTorch/XLA and thus causing test failures.
Rather than installing the latest version of JAX in CI, we'll make the CUDA plugin depend on the same JAX version that PyTorch/XLA uses on TPU.
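As a sketch of what this pinning looks like, the snippet below builds an `install_requires` list for a hypothetical plugin package. The version string and the helper name are illustrative assumptions, not taken from this PR; the actual pin should mirror whatever JAX version PyTorch/XLA itself declares:

```python
# Hypothetical sketch: pin the CUDA plugin's JAX dependency to the exact
# version PyTorch/XLA uses, instead of letting CI install "latest".
# The version string is illustrative, not the one used in this PR.
JAX_PIN = "0.4.38"  # assumed: should mirror PyTorch/XLA's own JAX pin

def cuda_plugin_requirements(jax_pin: str) -> list[str]:
    """Return the install_requires entries for the hypothetical CUDA plugin."""
    # Pinning with == (rather than >=) prevents the version skew that
    # caused the CI test failures described above.
    return [f"jax=={jax_pin}", f"jaxlib=={jax_pin}"]

print(cuda_plugin_requirements(JAX_PIN))
```

This list would then be passed to `setuptools.setup(install_requires=...)` in the plugin's packaging config, so installing the plugin pulls in a matching JAX automatically.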
Side note: there appear to be two ways of building PyTorch/XLA for GPU. One is setting `XLA_CUDA=1`, which causes the PyTorch/XLA C++ `.so` to be built with CUDA support. The other is building a "PyTorch/XLA CUDA plugin" similar to the `libtpu` plugin, factoring the CUDA-specific functionality behind a backend `.so`. That shows up as the "Build PyTorch/XLA CUDA Plugin" job in CI. It appears that our cloud build jobs build the CUDA plugin but do not upload it to GCS (discussed earlier in #8876). In any case, our GPU CI infra uses the CUDA plugin path and does not bake CUDA support into the native PyTorch/XLA `.so`. Therefore, to fix GPU CI, I made the CUDA plugin depend on the right version of JAX.