[TPU] Reduce compilation time & Upgrade PyTorch XLA version (vllm-pro…

WoosukKwon authored Jul 27, 2024
1 parent f954d07 commit fad5576
Showing 6 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240713"
+ARG NIGHTLY_DATE="20240726"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
9 changes: 8 additions & 1 deletion docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240713"
+$ export DATE="+20240726"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. note::

Since TPU relies on XLA, which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each shape.
Compilation may take 20~30 minutes on the first run.
However, it drops to ~5 minutes afterwards because the XLA graphs are cached on disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).


.. tip::

If you encounter the following error:
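The note added above describes the mechanism that makes the XLA cache effective: vLLM pads each batch to one of a small, fixed set of input shapes so the compiler only ever sees static shapes, and the resulting graphs are cached on disk. Below is a minimal sketch of that bucketization idea; the bucket sizes and helper names are illustrative, not vLLM's actual ones (the real logic lives in vllm/worker/tpu_model_runner.py).

import os

# Illustrative bucket sizes -- vLLM's real buckets may differ.
_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32, 64, 128]
_SEQ_LEN_BUCKETS = [16, 32, 64, 128, 256, 512, 1024, 2048]


def _round_up(value: int, buckets: list) -> int:
    """Return the smallest bucket >= value, so only a bounded set of
    static shapes ever reaches the XLA compiler."""
    for bucket in buckets:
        if value <= bucket:
            return bucket
    return buckets[-1]


def padded_shape(batch_size: int, seq_len: int) -> tuple:
    # Requests are padded up to the bucketed shape; the XLA graph compiled
    # for that shape is then reused for any smaller request.
    return (_round_up(batch_size, _BATCH_BUCKETS),
            _round_up(seq_len, _SEQ_LEN_BUCKETS))


if __name__ == "__main__":
    # A batch of 3 requests with 100 tokens each hits the (4, 128) graph.
    print(padded_shape(3, 100))

    # Compiled graphs persist here, so later runs skip most recompilation.
    cache_dir = os.getenv("VLLM_XLA_CACHE_PATH",
                          os.path.expanduser("~/.cache/vllm/xla_cache"))
    print("XLA cache dir:", cache_dir)

Because the bucket set is bounded, the first run pays for at most one XLA compilation per bucket; subsequent runs reuse the on-disk cache, which is why the documented warm-up drops from 20~30 minutes to about 5.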
1 change: 0 additions & 1 deletion vllm/attention/backends/pallas.py
@@ -3,7 +3,6 @@

import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
3 changes: 2 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
@@ -6,6 +6,7 @@

if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
from torch_xla._internal import pjrt


@@ -20,7 +21,7 @@ def __init__(self, group: ProcessGroup):
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
-xm._init_world_size_ordinal()
+xr._init_world_size_ordinal()

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
15 changes: 13 additions & 2 deletions vllm/worker/tpu_model_runner.py
@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ def load_model(self) -> None:
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
-xm_tp_rank = xm.get_ordinal()
+xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
@@ -146,7 +147,17 @@ def load_model(self) -> None:
xm.wait_device_ops()

model = ModelWrapper(model)
-self.model = torch.compile(model, backend="openxla", fullgraph=True)
+# NOTE(woosuk): There are two stages of compilation: torch.compile and
+# XLA compilation. Setting dynamic=True can reduce the torch.compile
+# overhead by reusing the FX graph for different shapes.
+# However, the XLA graph will still require static shapes and needs to
+# be re-compiled for every different shape. This overhead is inevitable
+# in the first run, but can be skipped afterwards as we cache the XLA
+# graphs on disk (VLLM_XLA_CACHE_PATH).
+self.model = torch.compile(model,
+                           backend="openxla",
+                           fullgraph=True,
+                           dynamic=True)

def _dummy_run(
self,
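The workaround in load_model above temporarily patches vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank so that, only while the embedding weights are being loaded, the rank the loader sees is the XLA runtime's ordinal (xr.global_ordinal()). For readers unfamiliar with that pattern, here is a minimal, self-contained unittest.mock.patch sketch; the helper names below are placeholders, not vLLM's real functions.

from unittest.mock import patch


def get_rank() -> int:
    """Stand-in for the rank helper that a weight loader consults."""
    return 0


def load_embedding_shard() -> int:
    # The loader picks which shard to read based on the rank helper.
    return get_rank()


def xla_assigned_rank() -> int:
    # Stand-in for xr.global_ordinal(); the real call queries torch_xla.
    return 3


# Only while the patch is active does the loader see the XLA-assigned rank.
with patch(f"{__name__}.get_rank", xla_assigned_rank):
    assert load_embedding_shard() == 3

# Outside the block the original helper is restored.
assert load_embedding_shard() == 0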
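The new NOTE distinguishes the two compilation layers on TPU: torch.compile (Dynamo/FX tracing) and XLA's own per-shape graph compilation. dynamic=True only addresses the first layer, letting Dynamo reuse one shape-polymorphic FX graph instead of re-tracing for every bucket, while XLA still compiles an executable per concrete shape (cached via VLLM_XLA_CACHE_PATH). A rough way to observe the Dynamo side in isolation is a toy counting backend; this sketch uses CPU tensors and a custom backend rather than vLLM's openxla setup.

import torch


def counting_backend(gm, example_inputs):
    # A trivial torch.compile backend: count how many FX graphs Dynamo
    # hands us, then just run them eagerly.
    counting_backend.calls += 1
    return gm.forward


counting_backend.calls = 0


@torch.compile(backend=counting_backend, dynamic=True)
def f(x):
    return torch.nn.functional.relu(x) * 2


for seq_len in (16, 32, 64):
    f(torch.randn(1, seq_len))

# With dynamic=True, Dynamo typically traces a single shape-polymorphic
# graph, so the backend is invoked far fewer times than once per shape.
# On TPU, the openxla backend would still compile one XLA executable per
# concrete (bucketed) shape; dynamic=True only removes the re-tracing cost.
print("FX graphs compiled:", counting_backend.calls)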
1 change: 0 additions & 1 deletion vllm/worker/tpu_worker.py
@@ -3,7 +3,6 @@

import torch
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401
import torch_xla.runtime as xr

import vllm.envs as envs
