Skip to content

Commit 5c4090b

Browse files
nitins17maxtext authors
authored andcommitted
Update usage of cuda12_pip to cuda12 and cuda12_local to cuda12-local when installing jax
PiperOrigin-RevId: 748000088
1 parent 4e8ce0e commit 5c4090b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

setup.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ elif [[ $MODE == "nightly" ]]; then
163163
# Install jax-nightly
164164
if [[ -n "$JAX_VERSION" ]]; then
165165
echo "Installing jax-nightly, jaxlib-nightly ${JAX_VERSION}"
166-
python3 -m pip install -U --pre jax==${JAX_VERSION} jaxlib==${JAX_VERSION} jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
166+
python3 -m pip install -U --pre jax==${JAX_VERSION} jaxlib==${JAX_VERSION} jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
167167
else
168168
echo "Installing latest jax-nightly, jaxlib-nightly"
169-
python3 -m pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
169+
python3 -m pip install -U --pre jax jaxlib jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
170170
fi
171171
# Install Transformer Engine
172172
export NVTE_FRAMEWORK=jax

0 commit comments

Comments
 (0)