
Commit 1909796

GRPO runs on new hardware

committed
1 parent 1dc6045 commit 1909796

File tree

4 files changed (+45, -21 lines)


docker_build_dependency_image.sh

Lines changed: 6 additions & 5 deletions
@@ -142,14 +142,15 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
     exit 1
   fi
 
-  # To install from local paths, we copy vllm and tpu_commons into the build context.
-  # This assumes vllm and tpu_commons are sibling directories to the current one (maxtext).
-  echo "Copying local vllm and tpu_commons directories into the build context..."
-  rsync -a --exclude='__pycache__' ../tpu_commons .
+  # To install from local paths, we copy vllm and tpu-inference into the build context.
+  # This assumes vllm and tpu-inference are sibling directories to the current one (maxtext).
+  echo "Copying local vllm and tpu-inference directories into the build context..."
+  rsync -a --exclude='__pycache__' ../tunix .
+  rsync -a --exclude='__pycache__' ../tpu-inference .
   rsync -a --exclude='__pycache__' ../vllm .
 
   # The cleanup is set to run even if the build fails to remove the copied directories.
-  trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu_commons ./vllm" EXIT INT TERM
+  trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM
 
   docker build \
     --network host \
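The copy step assumes tunix, tpu-inference, and vllm are checked out as siblings of the maxtext directory (note that the new trap removes only tpu-inference and vllm, not the copied tunix). A minimal preflight sketch, hypothetical and not part of this commit, that verifies the layout before running the build script:

# Hypothetical preflight check: confirm the sibling checkouts the build
# script rsyncs into the build context actually exist. Run from the
# maxtext checkout root.
from pathlib import Path

siblings = ("tunix", "tpu-inference", "vllm")
parent = Path.cwd().resolve().parent
missing = [name for name in siblings if not (parent / name).is_dir()]
if missing:
    raise SystemExit(f"missing sibling checkouts: {missing}")
print("all sibling checkouts present:", ", ".join(siblings))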

maxtext_grpo_dependencies.Dockerfile

Lines changed: 29 additions & 7 deletions
@@ -17,7 +17,7 @@ FROM ${BASEIMAGE}
 ARG MODE
 ENV MODE=$MODE
 
-RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
+RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
 
 # Uninstall existing jax to avoid conflicts
 # RUN pip uninstall -y jax jaxlib libtpu
@@ -27,23 +27,27 @@ RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MOD
 RUN --mount=type=cache,target=/root/.cache/pip pip install \
     aiohttp==3.12.15\
     keyring \
-    keyrings.google-artifactregistry-auth \
+    keyrings.google-artifactregistry-auth
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install \
     numba==0.61.2
 
+# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
 # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
 
 # Copy *only* the dependency definition files.
-# This assumes vllm and tpu_commons are in the build context, copied from the parent directory.
+# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
 COPY vllm/requirements/tpu.txt /tmp/
 COPY vllm/requirements/build.txt /tmp/
 COPY vllm/requirements/common.txt /tmp/
-COPY tpu_commons/requirements.txt /tmp/
+COPY tpu-inference/requirements.txt /tmp/
 
 # Run the full dependency installation.
 # This entire layer is cached and will *only* be rebuilt if
 # these .txt files change.
-RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     # Set the target device so pip installs the right JAX/libtpu
+    # Install tpu-inference dependencies
     export VLLM_TARGET_DEVICE="tpu" && \
     pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
     --extra-index-url https://pypi.org/simple/ \
@@ -55,16 +59,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
     --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
 
+# Install tpu-inference dependencies
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+    pip install -r /tmp/requirements.txt --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
+
 # --- STAGE 3: Install Project Source Code ---
 
 # Now, copy the full source code. This invalidates cache frequently,
 # but the next step is fast.
 COPY vllm /vllm/
-COPY tpu_commons /tpu_commons/
+COPY tpu-inference /tpu-inference/
+COPY tunix /tunix
+
 
 # Install in editable mode. This is lightning-fast because all
 # dependencies were installed and cached in STAGE 2.
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
+RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
+# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
 
 RUN if [ "$MODE" = "grpo-experimental" ]; then \
     echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \

src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py

Lines changed: 3 additions & 3 deletions
@@ -195,12 +195,12 @@
 # ====== Training ======
 BATCH_SIZE = 1
 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
-# NUM_BATCHES = 3738
-NUM_BATCHES = 4 # 200
+NUM_BATCHES = 3738
+# NUM_BATCHES = 4 # 200
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 NUM_TEST_BATCHES = 330
-NUM_TEST_BATCHES = 5 # 200
+# NUM_TEST_BATCHES = 5 # 200
 
 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1 # can potentially train for more epochs
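Restoring NUM_BATCHES = 3738 changes the run from a 4-batch smoke test to a full pass over the dataset. A small illustration of the resulting step budget; the MAX_STEPS relation is an assumption read off the comment "Increase `NUM_BATCHES` and `MAX_STEPS`", not code shown in the diff:

# Illustrative only: values come from the diff; the MAX_STEPS formula is
# a plausible reading of the file's comment, not the demo's actual code.
BATCH_SIZE = 1
NUM_BATCHES = 3738
NUM_EPOCHS = 1
MAX_STEPS = NUM_BATCHES * NUM_EPOCHS  # one optimizer step per batch
print(f"{MAX_STEPS} steps covering {NUM_BATCHES * BATCH_SIZE} prompts")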

src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py

Lines changed: 7 additions & 6 deletions
@@ -142,20 +142,20 @@
 
 
 # ====== Input Checkpoint directory =====
-MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/"
+MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1-8b-Instruct/scanned-pathways/0/items"
 
 # ====== Checkpoint directory =====
 LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
 if not os.path.exists(LOG_DIR):
   os.makedirs(LOG_DIR)
 
 # ===== Profiling =====
-PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/"
+PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/profile_dir/{run_id}/profiles_llama3/"
 if not epath.Path(PROFILE_DIR).exists():
   epath.Path(PROFILE_DIR).mkdir(parents=True)
 
 # ====== Checkpoint saving ======
-CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/"
+CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/ckpt_save_dir/{run_id}/ckpts_llama3/"
 
 if not epath.Path(CKPT_DIR).exists():
   epath.Path(CKPT_DIR).mkdir(parents=True)
@@ -195,11 +195,12 @@
 # ====== Training ======
 BATCH_SIZE = 1
 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
-# NUM_BATCHES = 3738
-NUM_BATCHES = 4 # 200
+NUM_BATCHES = 3738
+# NUM_BATCHES = 4 # 200
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
-NUM_TEST_BATCHES = 5 # 200
+NUM_TEST_BATCHES = 330
+# NUM_TEST_BATCHES = 5 # 200
 
 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1 # can potentially train for more epochs
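The new defaults hard-code the author's test buckets, so anyone rerunning the demo needs to point these paths at a writable bucket. A minimal sketch of the same epath bootstrap the demo performs, with placeholder bucket and run names; only the epath calls mirror the diff:

# Minimal sketch of the directory bootstrap above, assuming the etils
# package and GCS credentials are available. The bucket name and run_id
# are placeholders, not values from the commit.
from etils import epath

run_id = "example-run"
PROFILE_DIR = f"gs://your-bucket/grpo/profile_dir/{run_id}/profiles_llama3/"
CKPT_DIR = f"gs://your-bucket/grpo/ckpt_save_dir/{run_id}/ckpts_llama3/"

for path in (PROFILE_DIR, CKPT_DIR):
    # epath treats gs:// URIs like local paths, so mkdir works on buckets.
    if not epath.Path(path).exists():
        epath.Path(path).mkdir(parents=True)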
