@@ -17,7 +17,7 @@ FROM ${BASEIMAGE}
1717ARG  MODE
1818ENV  MODE=$MODE
1919
20- RUN  echo "Installing GRPO dependencies (vLLM, tpu-common, tunix ) with MODE=${MODE}" 
20+ RUN  echo "Installing GRPO dependencies (vLLM, tpu-inference ) with MODE=${MODE}" 
2121
2222#  Uninstall existing jax to avoid conflicts
2323#  RUN pip uninstall -y jax jaxlib libtpu
@@ -27,23 +27,27 @@ RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MOD
2727RUN  --mount=type=cache,target=/root/.cache/pip pip install \
2828    aiohttp==3.12.15\
2929    keyring \
30-     keyrings.google-artifactregistry-auth \
30+     keyrings.google-artifactregistry-auth
31+ 
32+ RUN  --mount=type=cache,target=/root/.cache/pip pip install \
3133    numba==0.61.2
3234
35+ #  RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
3336#  --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
3437
3538#  Copy *only* the dependency definition files.
36- #  This assumes vllm and tpu_commons  are in the build context, copied from the parent directory.
39+ #  This assumes vllm and tpu-inference  are in the build context, copied from the parent directory.
3740COPY  vllm/requirements/tpu.txt /tmp/
3841COPY  vllm/requirements/build.txt /tmp/
3942COPY  vllm/requirements/common.txt /tmp/
40- COPY  tpu_commons /requirements.txt /tmp/
43+ COPY  tpu-inference /requirements.txt /tmp/
4144
4245#  Run the full dependency installation.
4346#  This entire layer is cached and will *only* be rebuilt if
4447#  these .txt files change.
45- RUN  --mount=type=cache,target=/root/.cache/pip bash -c ' \ 
48+ RUN  --mount=type=cache,target=/root/.cache/pip bash -c ' \    
4649    # Set the target device so pip installs the right JAX/libtpu 
50+     # Install tpu-inference dependencies 
4751    export VLLM_TARGET_DEVICE="tpu" && \ 
4852    pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ 
4953        --extra-index-url https://pypi.org/simple/ \ 
@@ -55,16 +59,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
5559        --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ 
5660        --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' 
5761
62+     #  Install tpu-inference dependencies
63+ RUN   --mount=type=cache,target=/root/.cache/pip bash -c ' \ 
64+         pip install -r /tmp/requirements.txt --no-cache-dir --pre \ 
65+         --extra-index-url https://pypi.org/simple/ \ 
66+         --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ 
67+         --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ 
68+         --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ 
69+         --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ 
70+         --find-links https://storage.googleapis.com/libtpu-releases/index.html \ 
71+         --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ 
72+         --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' 
73+ 
5874#  --- STAGE 3: Install Project Source Code ---
5975
6076#  Now, copy the full source code. This invalidates cache frequently,
6177#  but the next step is fast.
6278COPY  vllm /vllm/
63- COPY  tpu_commons /tpu_commons/
79+ COPY  tpu-inference /tpu-inference/
80+ COPY  tunix /tunix
81+ 
6482
6583#  Install in editable mode. This is lightning-fast because all
6684#  dependencies were installed and cached in STAGE 2.
67- RUN  --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu"  pip install -e /vllm/ -e /tpu_commons/
85+ RUN  --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu"  pip install -e /vllm/
86+ RUN  --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
87+ 
88+ RUN  --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
89+ #  RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
6890
6991RUN  if [ "$MODE"  = "grpo-experimental"  ]; then \
7092    echo "MODE=grpo-experimental: Re-installing JAX/libtpu" ; \
0 commit comments