From 01d43dcce7a0f9916edb43efbbb1cb3fa07c1b5f Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Wed, 15 Oct 2025 20:29:19 +0000 Subject: [PATCH 1/5] Fix JAX version bugs for nightly image --- maxtext_jax_ai_image.Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index cd2dd457a0..f92326ed5d 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -51,7 +51,9 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg RUN if [ "$DEVICE" = "tpu" ]; then \ python3 -m pip install 'google-tunix>=0.1.2'; \ # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) - python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ + if [ "$MODE" = "stable_stack" ]; then \ + python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ + fi; \ fi # Now copy the remaining code (source files that may change frequently) From 6c6f25796948b8fae3a269922d00c24a231bd089 Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Wed, 15 Oct 2025 23:22:04 +0000 Subject: [PATCH 2/5] fix libtpu --- .github/workflows/RunTests.yml | 2 +- maxtext_jax_ai_image.Dockerfile | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index 3787f22074..357965c7e6 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -51,7 +51,7 @@ jobs: device_name: v4-8 cloud_runner: linux-x86-n2-16-buildkit build_mode: jax_ai_image - base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest + base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:jax0.7.0_rev1 gpu_image: needs: prelim diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index f92326ed5d..faa70b1fcd 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -50,9 +50,10 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg # Install google-tunix for TPU devices, skip for GPU RUN if [ "$DEVICE" = "tpu" ]; then \ python3 -m pip install 'google-tunix>=0.1.2'; \ + echo "tunix installed, MODE is $MODE"; \ # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) if [ "$MODE" = "stable_stack" ]; then \ - python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ + python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ fi; \ fi From a7e4132a1e0afbb3ef8557a65b29175c42fe0600 Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Wed, 15 Oct 2025 23:30:59 +0000 Subject: [PATCH 3/5] fix --- maxtext_jax_ai_image.Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index faa70b1fcd..faf85037ce 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -51,6 +51,7 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg RUN if [ "$DEVICE" = "tpu" ]; then \ python3 -m pip install 'google-tunix>=0.1.2'; \ echo "tunix installed, MODE is $MODE"; \ + echo $MODE; \ # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) if [ "$MODE" = "stable_stack" ]; then \ python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ From 1ab3f230ebfdf405562af9f83bd3e9672f1dce7f Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Thu, 16 Oct 2025 00:13:18 +0000 Subject: [PATCH 4/5] print base image --- maxtext_jax_ai_image.Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index faf85037ce..410842f588 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -52,6 +52,7 @@ RUN if [ "$DEVICE" = "tpu" ]; then \ python3 -m pip install 'google-tunix>=0.1.2'; \ echo "tunix installed, MODE is $MODE"; \ echo $MODE; \ + echo $JAX_AI_IMAGE_BASEIMAGE; \ # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) if [ "$MODE" = "stable_stack" ]; then \ python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ From b213113cb905fbc50a0ac3b2b4c53e52c8fa1a7d Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Thu, 16 Oct 2025 16:56:56 +0000 Subject: [PATCH 5/5] upgrade jax --- .github/workflows/RunTests.yml | 2 +- maxtext_jax_ai_image.Dockerfile | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index 357965c7e6..3787f22074 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -51,7 +51,7 @@ jobs: device_name: v4-8 cloud_runner: linux-x86-n2-16-buildkit build_mode: jax_ai_image - base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:jax0.7.0_rev1 + base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest gpu_image: needs: prelim diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index 410842f588..ce7f4f3701 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -50,13 +50,15 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg # Install google-tunix for TPU devices, skip for GPU RUN if [ "$DEVICE" = "tpu" ]; then \ python3 -m pip install 'google-tunix>=0.1.2'; \ - echo "tunix installed, MODE is $MODE"; \ - echo $MODE; \ echo $JAX_AI_IMAGE_BASEIMAGE; \ # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) - if [ "$MODE" = "stable_stack" ]; then \ - python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ - fi; \ + if [[ "$JAX_AI_IMAGE_BASEIMAGE" == *"nightly"* ]]; then \ + echo "Nightly image detected"; \ + python3 -m pip install --upgrade jax jaxlib; \ + else \ + echo "Non-nightly image"; \ + python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ + fi fi # Now copy the remaining code (source files that may change frequently)