Skip to content

Commit a82a1b4

Browse files
committed
split vllm dependencies and code
1 parent e2a4405 commit a82a1b4

File tree

2 files changed

+43
-49
lines changed

2 files changed

+43
-49
lines changed

docker_build_dependency_image.sh

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,14 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
142142
exit 1
143143
fi
144144

145-
# To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
146-
# This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
145+
# To install from local paths, we copy vllm and tpu_commons into the build context.
146+
# This assumes vllm and tpu_commons are sibling directories to the current one (maxtext).
147+
echo "Copying local vllm and tpu_commons directories into the build context..."
147148
rsync -a --exclude='__pycache__' ../tpu_commons .
148-
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
149-
# This assumes vllm is a sibling directory to the current one (maxtext).
150149
rsync -a --exclude='__pycache__' ../vllm .
151150

152-
# rsync -a --exclude='__pycache__' ../tunix .
153-
154-
# The cleanup is set to run even if the build fails to remove the copied directory.
155-
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
156-
trap "rm -rf ./tpu_commons ./vllm " EXIT INT TERM
151+
# The cleanup is set to run even if the build fails to remove the copied directories.
152+
trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu_commons ./vllm" EXIT INT TERM
157153

158154
docker build \
159155
--network host \

maxtext_post_training_dependencies.Dockerfile

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
ARG BASEIMAGE
1616
FROM ${BASEIMAGE}
1717
ARG MODE
18-
1918
ENV MODE=$MODE
2019

2120
RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
@@ -24,53 +23,52 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
2423
# Uninstall existing jax to avoid conflicts
2524
# RUN pip uninstall -y jax jaxlib libtpu
2625

27-
RUN pip install aiohttp==3.12.15
26+
# --- STAGE 1: Install Static Dependencies ---
27+
# Install any packages *not* defined in your project dependency files
28+
RUN --mount=type=cache,target=/root/.cache/pip pip install \
29+
aiohttp==3.12.15\
30+
keyring \
31+
keyrings.google-artifactregistry-auth \
32+
numba==0.61.2
2833

29-
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
30-
RUN pip install keyring keyrings.google-artifactregistry-auth
34+
# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
3135

32-
RUN pip install numba==0.61.2
36+
# Copy *only* the dependency definition files.
37+
# This assumes vllm and tpu_commons are in the build context, copied from the parent directory.
38+
COPY vllm/requirements/tpu.txt /tmp/
39+
COPY vllm/requirements/build.txt /tmp/
40+
COPY vllm/requirements/common.txt /tmp/
41+
COPY tpu_commons/requirements.txt /tmp/
3342

34-
COPY vllm /vllm
35-
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36-
--extra-index-url https://pypi.org/simple/ \
37-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
43+
# Run the full dependency installation.
44+
# This entire layer is cached and will *only* be rebuilt if
45+
# these .txt files change.
46+
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
47+
# Set the target device so pip installs the right JAX/libtpu
48+
export VLLM_TARGET_DEVICE="tpu" && \
49+
pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
50+
--extra-index-url https://pypi.org/simple/ \
51+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
52+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
53+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
54+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
55+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
56+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
57+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
4458

45-
# Install tpu-commons from local source
46-
COPY tpu_commons /tpu_commons
47-
RUN pip install -e /tpu_commons --no-cache-dir --pre \
48-
--extra-index-url https://pypi.org/simple/ \
49-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
59+
# --- STAGE 3: Install Project Source Code ---
5160

52-
# # Install vLLM for Jax and TPUs from the artifact registry
53-
# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
54-
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
55-
# --extra-index-url https://pypi.org/simple/ \
56-
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
57-
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
58-
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
59-
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
60-
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
61-
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
62-
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
63-
# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
61+
# Now, copy the full source code. This invalidates cache frequently,
62+
# but the next step is fast.
63+
COPY vllm /vllm/
64+
COPY tpu_commons /tpu_commons/
6465

65-
# # Install tpu-commons from the artifact registry
66-
# RUN pip install --no-cache-dir --pre \
67-
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
68-
# --extra-index-url https://pypi.org/simple/ \
69-
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
70-
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
71-
# tpu-commons==0.1.2
66+
# Install in editable mode. This is lightning-fast because all
67+
# dependencies were installed and cached in STAGE 2.
68+
RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
7269

7370
RUN if [ "$MODE" = "post-training-experimental" ]; then \
71+
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
7472
pip uninstall -y jax jaxlib libtpu && \
7573
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
7674
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \

0 commit comments

Comments
 (0)