1515ARG BASEIMAGE
1616FROM ${BASEIMAGE}
1717ARG MODE
18-
1918ENV MODE=$MODE
2019
2120RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
@@ -24,53 +23,52 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
2423# Disabled: uninstall existing jax to avoid conflicts (the command below is commented out)
2524# RUN pip uninstall -y jax jaxlib libtpu
2625
27- RUN pip install aiohttp==3.12.15
26+ # --- STAGE 1: Install Static Dependencies ---
27+ # Packages *not* defined in the project dependency files. keyring + keyrings.google-artifactregistry-auth let pip authenticate with Google Artifact Registry.
28+ RUN --mount=type=cache,target=/root/.cache/pip pip install \
29+ aiohttp==3.12.15 \
30+ keyring \
31+ keyrings.google-artifactregistry-auth \
32+ numba==0.61.2
2833
29- # Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
30- RUN pip install keyring keyrings.google-artifactregistry-auth
34+ # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
3135
32- RUN pip install numba==0.61.2
36+ # Stage just the requirement manifests, so edits to the source trees do not invalidate the install layer below.
37+ # vllm/ and tpu_commons/ must already be present in the build context (copied from the parent directory) for these paths to resolve.
38+ COPY vllm/requirements/tpu.txt \
39+ vllm/requirements/build.txt \
40+ vllm/requirements/common.txt \
41+ tpu_commons/requirements.txt /tmp/
3342
34- COPY vllm /vllm
35- RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36- --extra-index-url https://pypi.org/simple/ \
37- --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38- --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40- --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41- --find-links https://storage.googleapis.com/libtpu-releases/index.html \
42- --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43- --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
43+ # Run the full dependency installation. This layer is reused from the build cache
44+ # until one of these .txt files (or an earlier instruction, e.g. BASEIMAGE/MODE) changes.
45+ # --no-cache-dir is deliberately omitted: it would disable the pip cache that the
46+ # cache mount on this instruction persists between builds.
47+ RUN --mount=type=cache,target=/root/.cache/pip \
48+ VLLM_TARGET_DEVICE="tpu" \
49+ pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --pre \
50+ --extra-index-url https://pypi.org/simple/ \
51+ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
52+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
53+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
54+ --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
55+ --find-links https://storage.googleapis.com/libtpu-releases/index.html \
56+ --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
57+ --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
4458
45- # Install tpu-commons from local source
46- COPY tpu_commons /tpu_commons
47- RUN pip install -e /tpu_commons --no-cache-dir --pre \
48- --extra-index-url https://pypi.org/simple/ \
49- --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
59+ # --- STAGE 3: Install Project Source Code ---
5160
52- # # Install vLLM for Jax and TPUs from the artifact registry
53- # RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
54- # --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
55- # --extra-index-url https://pypi.org/simple/ \
56- # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
57- # --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
58- # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
59- # --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
60- # --find-links https://storage.googleapis.com/libtpu-releases/index.html \
61- # --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
62- # --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
63- # vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
61+ # Copy the complete vllm and tpu_commons trees from the build context. Any source
62+ # change rebuilds from this point on; the dependency layers above stay cached.
63+ COPY vllm /vllm/
64+ COPY tpu_commons /tpu_commons/
6465
65- # # Install tpu-commons from the artifact registry
66- # RUN pip install --no-cache-dir --pre \
67- # --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
68- # --extra-index-url https://pypi.org/simple/ \
69- # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
70- # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
71- # tpu-commons==0.1.2
66+ # Editable install of both copied source trees. Dependencies are expected to be present already from STAGE 2.
67+ # NOTE(review): no --pre / extra index options are passed here, so anything pip still has to resolve comes from its default index - confirm.
68+ RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
7269
7370RUN if [ "$MODE" = "post-training-experimental" ]; then \
71+ echo "MODE=grpo-experimental: Re-installing JAX/libtpu" ; \
7472 pip uninstall -y jax jaxlib libtpu && \
7573 pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
7674 pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
0 commit comments