Skip to content

Commit ec3e2d4

Browse files
committed
update to flag=post-training
1 parent 3c140c1 commit ec3e2d4

File tree

2 files changed

+167
-22
lines changed

2 files changed

+167
-22
lines changed

docker_build_dependency_image.sh

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# works with any custom wheels.
2828
# bash docker_build_dependency_image.sh MODE=custom_wheels
2929

30-
# bash docker_build_dependency_image.sh MODE=post-training
30+
# bash docker_build_dependency_image.sh MODE=grpo
3131

3232
# Enable "exit immediately if any command fails" option
3333
set -e
@@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then
6868
export MODE=stable
6969
echo "Default MODE=${MODE}"
7070
export CUSTOM_JAX=0
71-
export INSTALL_POST_TRAINING=0
71+
export INSTALL_GRPO=0
7272
elif [[ ${MODE} == "custom_wheels" ]] ; then
7373
export MODE=nightly
7474
export CUSTOM_JAX=1
75-
export INSTALL_POST_TRAINING=0
76-
elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then
77-
export INSTALL_POST_TRAINING=1
75+
export INSTALL_GRPO=0
76+
elif [[ ${MODE} == "grpo" || ${MODE} == "grpo-experimental" ]] ; then
77+
export INSTALL_GRPO=1
7878
export CUSTOM_JAX=0
7979
else
8080
export CUSTOM_JAX=0
81-
export INSTALL_POST_TRAINING=0
81+
export INSTALL_GRPO=0
8282
fi
8383

8484
if [[ -z ${DEVICE} ]]; then
@@ -124,9 +124,9 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
124124
elif [[ ${MANTARAY} == "true" ]]; then
125125
echo "Building with benchmark-db"
126126
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
127-
elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then
128-
echo "Installing MaxText stable mode dependencies for Post-Training"
129-
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
127+
elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then
128+
echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE"
129+
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
130130
else
131131
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
132132
fi
@@ -136,29 +136,27 @@ else
136136
docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
137137
fi
138138

139-
if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
139+
if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
140140
if [[ ${DEVICE} != "tpu" ]] ; then
141-
echo "Error: MODE=post-training is only supported for DEVICE=tpu"
141+
echo "Error: MODE=grpo is only supported for DEVICE=tpu"
142142
exit 1
143143
fi
144144

145-
# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
146-
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
147-
# rsync -a --exclude='__pycache__' ../tpu_commons .
148-
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
149-
# # This assumes vllm is a sibling directory to the current one (maxtext).
150-
# rsync -a --exclude='__pycache__' ../vllm .
145+
# To install from local paths, we copy vllm and tpu-inference into the build context.
146+
# This assumes vllm and tpu-inference are sibling directories to the current one (maxtext).
147+
echo "Copying local vllm and tpu-inference directories into the build context..."
148+
rsync -a --exclude='__pycache__' ../tunix .
149+
rsync -a --exclude='__pycache__' ../tpu-inference .
150+
rsync -a --exclude='__pycache__' ../vllm .
151151

152-
# rsync -a --exclude='__pycache__' ../tunix .
153-
154-
# # The cleanup is set to run even if the build fails to remove the copied directory.
155-
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
152+
# The cleanup is set to run even if the build fails to remove the copied directories.
153+
trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM
156154

157155
docker build \
158156
--network host \
159157
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
160158
--build-arg MODE=${MODE} \
161-
-f ./maxtext_post_training_dependencies.Dockerfile \
159+
-f ./maxtext_grpo_dependencies.Dockerfile \
162160
-t ${LOCAL_IMAGE_NAME} .
163161
fi
164162

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
ARG BASEIMAGE
16+
FROM ${BASEIMAGE}
17+
ARG MODE
18+
ENV MODE=$MODE
19+
20+
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
RUN pip uninstall -y jax jaxlib libtpu
22+
23+
RUN pip install aiohttp==3.12.15
24+
25+
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+
RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+
RUN pip install numba==0.61.2
29+
30+
COPY tunix /tunix
31+
RUN pip install -e /tunix --no-cache-dir
32+
33+
34+
COPY vllm /vllm
35+
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+
--extra-index-url https://pypi.org/simple/ \
37+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+
COPY tpu-inference /tpu-inference
47+
RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+
--extra-index-url https://pypi.org/simple/ \
49+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+
# # Install vLLM for Jax and TPUs from the artifact registry
53+
# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
54+
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
55+
# --extra-index-url https://pypi.org/simple/ \
56+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
57+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
58+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
59+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
60+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
61+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
62+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
63+
# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
64+
65+
# # Install tpu-commons from the artifact registry
66+
# RUN pip install --no-cache-dir --pre \
67+
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
68+
# --extra-index-url https://pypi.org/simple/ \
69+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
70+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
71+
# tpu-commons==0.1.2
72+
73+
# # Uninstall existing jax to avoid conflicts
74+
# # RUN pip uninstall -y jax jaxlib libtpu
75+
76+
# # --- STAGE 1: Install Static Dependencies ---
77+
# # Install any packages *not* defined in your project dependency files
78+
# RUN --mount=type=cache,target=/root/.cache/pip pip install \
79+
# aiohttp==3.12.15\
80+
# keyring \
81+
# keyrings.google-artifactregistry-auth
82+
83+
# RUN --mount=type=cache,target=/root/.cache/pip pip install \
84+
# numba==0.61.2
85+
86+
# # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
87+
# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
88+
89+
# # Copy *only* the dependency definition files.
90+
# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
91+
# COPY vllm/requirements/tpu.txt /tmp/
92+
# COPY vllm/requirements/build.txt /tmp/
93+
# COPY vllm/requirements/common.txt /tmp/
94+
# COPY tpu-inference/requirements.txt /tmp/
95+
96+
# # Run the full dependency installation.
97+
# # This entire layer is cached and will *only* be rebuilt if
98+
# # these .txt files change.
99+
# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
100+
# # Set the target device so pip installs the right JAX/libtpu
101+
# # Install tpu-inference dependencies
102+
# export VLLM_TARGET_DEVICE="tpu" && \
103+
# pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
104+
# --extra-index-url https://pypi.org/simple/ \
105+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
106+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
107+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
108+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
109+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
110+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
111+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
112+
113+
# # Install tpu-inference dependencies
114+
# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
115+
# pip install -r /tmp/requirements.txt --no-cache-dir --pre \
116+
# --extra-index-url https://pypi.org/simple/ \
117+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
118+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
119+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
120+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
121+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
122+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
123+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
124+
125+
# # --- STAGE 3: Install Project Source Code ---
126+
127+
# # Now, copy the full source code. This invalidates cache frequently,
128+
# # but the next step is fast.
129+
# COPY vllm /vllm/
130+
# COPY tpu-inference /tpu-inference/
131+
# COPY tunix /tunix
132+
133+
134+
# # Install in editable mode. This is lightning-fast because all
135+
# # dependencies were installed and cached in STAGE 2.
136+
# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
137+
# RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
138+
139+
# RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
140+
# # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
141+
142+
RUN if [ "$MODE" = "grpo-experimental" ]; then \
143+
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
144+
pip uninstall -y jax jaxlib libtpu && \
145+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
146+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
147+
fi

0 commit comments

Comments
 (0)