Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/RunTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
device_name: v4-8
cloud_runner: linux-x86-n2-16-buildkit
build_mode: jax_ai_image
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1

gpu_image:
needs: prelim
Expand Down
23 changes: 19 additions & 4 deletions maxtext_jax_ai_image.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Base image to build on. Override with
#   --build-arg JAX_AI_IMAGE_BASEIMAGE=<image>
# The default keeps the FROM line valid when no build argument is supplied,
# which is what the BuildKit InvalidDefaultArgInFrom check requires.
# NOTE(review): the default is the TPU image referenced by the CI workflow;
# confirm it is the desired fallback for local builds.
ARG JAX_AI_IMAGE_BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1

# JAX AI Base Image
FROM $JAX_AI_IMAGE_BASEIMAGE

Check warning on line 4 in maxtext_jax_ai_image.Dockerfile

View workflow job for this annotation

GitHub Actions / tpu_image / Build and upload image (v4-8)

Default value for global ARG results in an empty or invalid base image name

InvalidDefaultArgInFrom: Default value for ARG $JAX_AI_IMAGE_BASEIMAGE results in empty or invalid base image name More info: https://docs.docker.com/go/dockerfile/rule/invalid-default-arg-in-from/

Check warning on line 4 in maxtext_jax_ai_image.Dockerfile

View workflow job for this annotation

GitHub Actions / gpu_image / Build and upload image (a100-40gb-4)

Default value for global ARG results in an empty or invalid base image name

InvalidDefaultArgInFrom: Default value for ARG $JAX_AI_IMAGE_BASEIMAGE results in empty or invalid base image name More info: https://docs.docker.com/go/dockerfile/rule/invalid-default-arg-in-from/
ARG JAX_AI_IMAGE_BASEIMAGE

ARG COMMIT_HASH
Expand Down Expand Up @@ -49,10 +49,18 @@

# Install google-tunix for TPU devices, skip for GPU
# On TPU builds, install google-tunix and then (re)install JAX:
#   - base image name contains "nightly": replace the bundled jax/jaxlib/libtpu
#     with the latest pre-release wheels;
#   - otherwise: pin jax[tpu] to 0.7.0.
# GPU builds skip this step entirely.
# A POSIX `case` is used for the substring match so the step also works when
# the RUN shell is /bin/sh rather than bash.
RUN if [ "$DEVICE" = "tpu" ]; then \
        echo "TPU device detected. Installing google-tunix."; \
        python3 -m pip install 'google-tunix>=0.1.2'; \
        # TODO: Once tunix stops pinning jax 0.7.1, remove our 0.7.0 version pin (b/450286600)
        case "$JAX_AI_IMAGE_BASEIMAGE" in \
            *nightly*) \
                echo "Nightly image detected. Uninstalling base JAX and installing pre-release."; \
                pip uninstall -y jax jaxlib libtpu; \
                pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ;; \
            *) \
                echo "Non-nightly image. Installing JAX 0.7.0."; \
                # Quoted so the shell never treats [tpu] as a glob pattern.
                python3 -m pip install 'jax[tpu]==0.7.0' ;; \
        esac; \
    fi

# Now copy the remaining code (source files that may change frequently)
COPY . .
Expand All @@ -68,7 +76,14 @@
fi

# Run the script shipped in the base image to generate the manifest file.
# The script directory differs between base images, so pick /jax-ai-image
# when it exists and fall back to /jax-stable-stack otherwise. If the chosen
# script is missing, bash exits non-zero and the build fails.
RUN if [ -d "/jax-ai-image" ]; then \
        echo "Found /jax-ai-image directory. Running with 'jax-ai-image' path."; \
        manifest_dir=/jax-ai-image; \
    else \
        echo "/jax-ai-image not found. Running with 'jax-stable-stack' path."; \
        manifest_dir=/jax-stable-stack; \
    fi; \
    # COMMIT_HASH is quoted so an empty or unusual value stays a single argument.
    bash "$manifest_dir/generate_manifest.sh" PREFIX=maxtext COMMIT_HASH="$COMMIT_HASH"


# Install (editable) MaxText.
# If /tmp/venv_created exists, its last line is used as a virtualenv directory
# and that environment's activate script is sourced first. It must be sourced
# with `.`; executing it as a command cannot change this shell's environment.
# When the marker or the activate script is absent, pip runs unactivated.
# NOTE(review): sourcing makes pip install into the venv — confirm that is
# where the rest of the dependencies live.
RUN if [ -f /tmp/venv_created ]; then \
        venv_activate="$(tail -n1 /tmp/venv_created)/bin/activate"; \
        if [ -f "$venv_activate" ]; then . "$venv_activate"; fi; \
    fi; \
    pip install --no-dependencies -e .
Loading