@@ -28,23 +28,27 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
 RUN --mount=type=cache,target=/root/.cache/pip pip install \
     aiohttp==3.12.15 \
     keyring \
-    keyrings.google-artifactregistry-auth \
+    keyrings.google-artifactregistry-auth
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install \
     numba==0.61.2
 
+# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
 # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
 
 # Copy *only* the dependency definition files.
-# This assumes vllm and tpu_commons are in the build context, copied from the parent directory.
+# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
 COPY vllm/requirements/tpu.txt /tmp/
 COPY vllm/requirements/build.txt /tmp/
 COPY vllm/requirements/common.txt /tmp/
-COPY tpu_commons/requirements.txt /tmp/
+COPY tpu-inference/requirements.txt /tmp/
 
 # Run the full dependency installation.
 # This entire layer is cached and will *only* be rebuilt if
 # these .txt files change.
-RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     # Set the target device so pip installs the right JAX/libtpu
+    # Install tpu-inference dependencies
     export VLLM_TARGET_DEVICE="tpu" && \
     pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
     --extra-index-url https://pypi.org/simple/ \
@@ -56,16 +60,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
     --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
     --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
 
+# Install tpu-inference dependencies
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
+    pip install -r /tmp/requirements.txt --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
+
 # --- STAGE 3: Install Project Source Code ---
 
 # Now, copy the full source code. This invalidates cache frequently,
 # but the next step is fast.
 COPY vllm /vllm/
-COPY tpu_commons /tpu_commons/
+COPY tpu-inference /tpu-inference/
+COPY tunix /tunix
+
 
 # Install in editable mode. This is lightning-fast because all
 # dependencies were installed and cached in STAGE 2.
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
+RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
+# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
 
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     echo "MODE=grpo-experimental: Re-installing JAX/libtpu" ; \
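
A minimal sketch of how an image could be built from this Dockerfile. It assumes MODE is declared as a build ARG earlier in the file, that the build runs from the parent directory containing the vllm/, tpu-inference/, and tunix/ checkouts (as the COPY steps above expect), and that the Dockerfile path and image tag below are hypothetical placeholders. BuildKit must be enabled for the --mount=type=cache pip cache mounts to work.

    # Hypothetical invocation; adjust the -f path and -t tag to the actual repo layout.
    DOCKER_BUILDKIT=1 docker build \
      --build-arg MODE=post-training-experimental \
      -f tpu-inference/docker/Dockerfile.post_training \
      -t post-training-tpu:dev \
      .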