Commit 6e6d2f4

grpo runs of new hardware

1 parent: a82a1b4

File tree

4 files changed: +44 -20 lines


docker_build_dependency_image.sh

Lines changed: 6 additions & 5 deletions

@@ -142,14 +142,15 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
     exit 1
   fi

-  # To install from local paths, we copy vllm and tpu_commons into the build context.
-  # This assumes vllm and tpu_commons are sibling directories to the current one (maxtext).
-  echo "Copying local vllm and tpu_commons directories into the build context..."
-  rsync -a --exclude='__pycache__' ../tpu_commons .
+  # To install from local paths, we copy vllm and tpu-inference into the build context.
+  # This assumes vllm and tpu-inference are sibling directories to the current one (maxtext).
+  echo "Copying local vllm and tpu-inference directories into the build context..."
+  rsync -a --exclude='__pycache__' ../tunix .
+  rsync -a --exclude='__pycache__' ../tpu-inference .
   rsync -a --exclude='__pycache__' ../vllm .

   # The cleanup is set to run even if the build fails to remove the copied directories.
-  trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu_commons ./vllm" EXIT INT TERM
+  trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM

   docker build \
     --network host \
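
The `trap ... EXIT INT TERM` line is what makes the cleanup run whether the build succeeds, fails, or is interrupted. For readers more at home in Python, a minimal sketch of the same copy-into-context-then-always-clean-up pattern (the function and paths are illustrative, not part of this commit):

# Illustrative Python equivalent of the script's copy + trap-cleanup pattern.
import shutil
import subprocess
from pathlib import Path

def build_with_copied_context(siblings=("tunix", "tpu-inference", "vllm")):
  copied = []
  try:
    for name in siblings:
      dst = Path(".") / name
      shutil.copytree(Path("..") / name, dst, ignore=shutil.ignore_patterns("__pycache__"))
      copied.append(dst)
    subprocess.run(["docker", "build", "--network", "host", "."], check=True)
  finally:
    # Like the shell trap: runs on success, failure, or KeyboardInterrupt.
    for dst in copied:
      shutil.rmtree(dst, ignore_errors=True)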

maxtext_post_training_dependencies.Dockerfile

Lines changed: 28 additions & 6 deletions

@@ -28,23 +28,27 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
 RUN --mount=type=cache,target=/root/.cache/pip pip install \
     aiohttp==3.12.15\
     keyring \
-    keyrings.google-artifactregistry-auth \
+    keyrings.google-artifactregistry-auth
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install \
     numba==0.61.2

+# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
 # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---

 # Copy *only* the dependency definition files.
-# This assumes vllm and tpu_commons are in the build context, copied from the parent directory.
+# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
 COPY vllm/requirements/tpu.txt /tmp/
 COPY vllm/requirements/build.txt /tmp/
 COPY vllm/requirements/common.txt /tmp/
-COPY tpu_commons/requirements.txt /tmp/
+COPY tpu-inference/requirements.txt /tmp/

 # Run the full dependency installation.
 # This entire layer is cached and will *only* be rebuilt if
 # these .txt files change.
-RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     # Set the target device so pip installs the right JAX/libtpu
+    # Install tpu-inference dependencies
     export VLLM_TARGET_DEVICE="tpu" && \
     pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
     --extra-index-url https://pypi.org/simple/ \

@@ -56,16 +60,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
     --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

+# Install tpu-inference dependencies
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+    pip install -r /tmp/requirements.txt --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
+
 # --- STAGE 3: Install Project Source Code ---

 # Now, copy the full source code. This invalidates cache frequently,
 # but the next step is fast.
 COPY vllm /vllm/
-COPY tpu_commons /tpu_commons/
+COPY tpu-inference /tpu-inference/
+COPY tunix /tunix
+

 # Install in editable mode. This is lightning-fast because all
 # dependencies were installed and cached in STAGE 2.
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
+RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
+# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/

 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \

src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py

Lines changed: 3 additions & 3 deletions

@@ -195,12 +195,12 @@
 # ====== Training ======
 BATCH_SIZE = 1
 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
-# NUM_BATCHES = 3738
-NUM_BATCHES = 4 # 200
+NUM_BATCHES = 3738
+# NUM_BATCHES = 4 # 200
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 NUM_TEST_BATCHES = 330
-NUM_TEST_BATCHES = 5 # 200
+# NUM_TEST_BATCHES = 5 # 200

 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1 # can potentially train for more epochs
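
This flips the demo back from smoke-test values to the full-dataset settings: NUM_BATCHES = 3738 and NUM_TEST_BATCHES = 330 become active, with the small values kept as comments. The "max. of 330 (if batch size is 4)" ceiling is just the test-split size divided by the batch size, rounding up for a final partial batch; a sketch of the arithmetic, assuming a 1319-example test split (GSM8K's test-set size, which these numbers are consistent with):

import math

TEST_EXAMPLES = 1319  # assumed test-split size (matches GSM8K)
BATCH_SIZE = 4

# 1319 / 4 = 329.75, so 330 batches if the last, partial batch is kept.
print(math.ceil(TEST_EXAMPLES / BATCH_SIZE))  # -> 330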

src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py

Lines changed: 7 additions & 6 deletions

@@ -142,20 +142,20 @@


 # ====== Input Checkpoint directory =====
-MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/"
+MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1-8b-Instruct/scanned-pathways/0/items"

 # ====== Checkpoint directory =====
 LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
 if not os.path.exists(LOG_DIR):
   os.makedirs(LOG_DIR)

 # ===== Profiling =====
-PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/"
+PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/profile_dir/{run_id}/profiles_llama3/"
 if not epath.Path(PROFILE_DIR).exists():
   epath.Path(PROFILE_DIR).mkdir(parents=True)

 # ====== Checkpoint saving ======
-CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/"
+CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/ckpt_save_dir/{run_id}/ckpts_llama3/"

 if not epath.Path(CKPT_DIR).exists():
   epath.Path(CKPT_DIR).mkdir(parents=True)

@@ -195,11 +195,12 @@
 # ====== Training ======
 BATCH_SIZE = 1
 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
-# NUM_BATCHES = 3738
-NUM_BATCHES = 4 # 200
+NUM_BATCHES = 3738
+# NUM_BATCHES = 4 # 200
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
-NUM_TEST_BATCHES = 5 # 200
+NUM_TEST_BATCHES = 330
+# NUM_TEST_BATCHES = 5 # 200

 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1 # can potentially train for more epochs
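
Besides the same NUM_BATCHES/NUM_TEST_BATCHES flip as in the 70B demo, the placeholder paths are replaced here with concrete gs:// locations. No other change is needed because the existence checks already go through etils' epath, which accepts Cloud Storage URIs as well as local paths, whereas the os.path/os.makedirs calls used for LOG_DIR are local-only. A minimal sketch of that distinction, with a hypothetical bucket name:

import os
from etils import epath

# os.path / os.makedirs handle local filesystem paths only.
os.makedirs("/tmp/grpo_logs/", exist_ok=True)

# epath.Path works for both local paths and gs:// URIs, so the same
# mkdir logic covers placeholders and real buckets alike.
ckpt_dir = epath.Path("gs://example-bucket/grpo/ckpts/")  # hypothetical bucket
if not ckpt_dir.exists():
  ckpt_dir.mkdir(parents=True)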
