diff --git a/base_requirements/requirements.txt b/base_requirements/requirements.txt
index bc3e68345..1d3c99f27 100644
--- a/base_requirements/requirements.txt
+++ b/base_requirements/requirements.txt
@@ -39,5 +39,5 @@ tiktoken
 tokamax
 transformers
 qwix
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh
index 0ef64ad55..8cfdcf31f 100644
--- a/docker_build_dependency_image.sh
+++ b/docker_build_dependency_image.sh
@@ -142,17 +142,15 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
     exit 1
   fi

-  # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../tpu_commons .
-  # # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../vllm .
-
-  # rsync -a --exclude='__pycache__' ../tunix .
-
-  # # The cleanup is set to run even if the build fails to remove the copied directory.
-  # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
+  # To install from local paths, we copy tunix, vllm, and tpu-inference into the build context, excluding __pycache__.
+  # This assumes tunix, vllm, and tpu-inference are sibling directories of the current one (maxtext).
+  echo "Copying local tunix, vllm, and tpu-inference directories into the build context..."
+  rsync -a --exclude='__pycache__' ../tunix .
+  rsync -a --exclude='__pycache__' ../tpu-inference .
+  rsync -a --exclude='__pycache__' ../vllm .
+
+  # The cleanup runs even if the build fails, so the copied directories are removed.
+  trap "echo 'Cleaning up copied directories...' && rm -rf ./tunix ./tpu-inference ./vllm" EXIT INT TERM

   docker build \
     --network host \
diff --git a/maxtext_grpo_dependencies_split.Dockerfile b/maxtext_grpo_dependencies_split.Dockerfile
new file mode 100644
index 000000000..dd8ecd452
--- /dev/null
+++ b/maxtext_grpo_dependencies_split.Dockerfile
@@ -0,0 +1,83 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
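+
+# This Dockerfile installs the GRPO stack (vLLM, tpu-inference, tunix) from
+# local checkouts in three stages, so the expensive dependency install below
+# stays cached while the frequently edited source copies rebuild quickly.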
+
+ARG BASEIMAGE
+FROM ${BASEIMAGE}
+ARG MODE
+ENV MODE=$MODE
+
+RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
+
+# --- STAGE 1: Install Static Dependencies ---
+# Install packages that are *not* defined in the project dependency files.
+RUN --mount=type=cache,target=/root/.cache/pip pip install \
+    aiohttp==3.12.15 \
+    keyring \
+    keyrings.google-artifactregistry-auth
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install \
+    numba==0.61.2
+
+# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
+
+# Copy *only* the dependency definition files.
+# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
+COPY vllm/requirements/tpu.txt /tmp/
+COPY vllm/requirements/build.txt /tmp/
+COPY vllm/requirements/common.txt /tmp/
+COPY tpu-inference/requirements.txt /tmp/
+
+# Run the full dependency installation for vLLM and tpu-inference.
+# This entire layer is cached and will *only* be rebuilt if
+# these .txt files change.
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+    # Set the target device so pip installs the right JAX/libtpu
+    export VLLM_TARGET_DEVICE="tpu" && \
+    pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
+
+# --- STAGE 3: Install Project Source Code ---
+
+# Now, copy the full source code. This invalidates the cache frequently,
+# but the next step is fast.
+COPY vllm /vllm/
+COPY tpu-inference /tpu-inference/
+COPY tunix /tunix
+
+# Install in editable mode. This is fast because all
+# dependencies were installed and cached in STAGE 2.
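+# VLLM_TARGET_DEVICE="tpu" is the same build switch used in STAGE 2; setting it
+# again here makes the editable vLLM install build the TPU variant rather than the default.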
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
+RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
+
+RUN if [ "$MODE" = "grpo-experimental" ]; then \
+    echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
+    pip uninstall -y jax jaxlib libtpu && \
+    pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
+    pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
+    fi
diff --git a/maxtext_post_training_dependencies.Dockerfile b/maxtext_post_training_dependencies.Dockerfile
index 277c1fe92..f902752d0 100644
--- a/maxtext_post_training_dependencies.Dockerfile
+++ b/maxtext_post_training_dependencies.Dockerfile
@@ -15,15 +15,11 @@
 ARG BASEIMAGE
 FROM ${BASEIMAGE}
 ARG MODE
-
 ENV MODE=$MODE

 RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"

-# Uninstall existing jax to avoid conflicts
-RUN pip uninstall -y jax jaxlib libtpu
-
 RUN pip install aiohttp==3.12.15

 # Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.

@@ -31,9 +27,12 @@ RUN pip install keyring keyrings.google-artifactregistry-auth

 RUN pip install numba==0.61.2

-# Install vLLM for Jax and TPUs from the artifact registry
-RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
+COPY tunix /tunix
+RUN pip install -e /tunix --no-cache-dir
+
+
+COPY vllm /vllm
+RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
     --extra-index-url https://pypi.org/simple/ \
     --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
     --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
@@ -41,18 +40,18 @@
     --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
     --find-links https://storage.googleapis.com/libtpu-releases/index.html \
     --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-    vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
-# Install tpu-commons from the artifact registry
-RUN pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
+COPY tpu-inference /tpu-inference
+RUN pip install -e /tpu-inference --no-cache-dir --pre \
     --extra-index-url https://pypi.org/simple/ \
     --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    tpu-commons==0.1.2
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
 
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
+    echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
     pip uninstall -y jax jaxlib libtpu && \
     pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
     pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
 fi
diff --git a/requirements.txt b/requirements.txt
index 36471cf55..0f4050a47 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,5 +39,5 @@ tensorflow
 tiktoken
 tokamax>=0.0.4
 transformers
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt
index 19f262c3b..5141e4333 100644
--- a/requirements_with_jax_ai_image.txt
+++ b/requirements_with_jax_ai_image.txt
@@ -3,7 +3,7 @@
 datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
 flax>=0.11.0
 google-api-python-client
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
 grain[parquet]>=0.2.12
 jaxtyping
 jsonlines
diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
index def23ec94..f50661f6a 100644
--- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
+++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
@@ -142,20 +142,24 @@
 
 # ====== Input Checkpoint directory =====
-MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/"
+MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
+# MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
 
 # ====== Checkpoint directory =====
-LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
+LOG_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
+# LOG_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
-if not os.path.exists(LOG_DIR):
-  os.makedirs(LOG_DIR)
+if not epath.Path(LOG_DIR).exists():
+  epath.Path(LOG_DIR).mkdir(parents=True)
 
 # ===== Profiling =====
-PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/"
+PROFILE_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
+# PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
 if not epath.Path(PROFILE_DIR).exists():
   epath.Path(PROFILE_DIR).mkdir(parents=True)
 
 # ====== Checkpoint saving ======
-CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/"
+CKPT_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"
f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/" +# CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/" if not epath.Path(CKPT_DIR).exists(): epath.Path(CKPT_DIR).mkdir(parents=True) @@ -195,11 +199,12 @@ # ====== Training ====== BATCH_SIZE = 1 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results. -# NUM_BATCHES = 3738 -NUM_BATCHES = 4 # 200 +NUM_BATCHES = 3738 +# NUM_BATCHES = 4 # 200 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). -NUM_TEST_BATCHES = 5 # 200 +NUM_TEST_BATCHES = 330 +# NUM_TEST_BATCHES = 5 # 200 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. NUM_EPOCHS = 1 # can potentially train for more epochs diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py index f994e7356..557e8ce9f 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py @@ -142,7 +142,7 @@ # ====== Input Checkpoint directory ===== -MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/" +MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1-8b-Instruct/scanned-pathways/0/items" # ====== Checkpoint directory ===== LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/" @@ -150,12 +150,12 @@ os.makedirs(LOG_DIR) # ===== Profiling ===== -PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/" +PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/profile_dir/{run_id}/profiles_llama3/" if not epath.Path(PROFILE_DIR).exists(): epath.Path(PROFILE_DIR).mkdir(parents=True) # ====== Checkpoint saving ====== -CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/" +CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/ckpt_save_dir/{run_id}/ckpts_llama3/" if not epath.Path(CKPT_DIR).exists(): epath.Path(CKPT_DIR).mkdir(parents=True) @@ -195,11 +195,12 @@ # ====== Training ====== BATCH_SIZE = 1 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results. -# NUM_BATCHES = 3738 -NUM_BATCHES = 4 # 200 +NUM_BATCHES = 3738 +# NUM_BATCHES = 4 # 200 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). -NUM_TEST_BATCHES = 5 # 200 +NUM_TEST_BATCHES = 330 +# NUM_TEST_BATCHES = 5 # 200 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. NUM_EPOCHS = 1 # can potentially train for more epochs diff --git a/src/install_maxtext_extra_deps/extra_deps_from_github.txt b/src/install_maxtext_extra_deps/extra_deps_from_github.txt index 676f2e58e..9f7bf08af 100644 --- a/src/install_maxtext_extra_deps/extra_deps_from_github.txt +++ b/src/install_maxtext_extra_deps/extra_deps_from_github.txt @@ -1,2 +1,2 @@ -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip