14 changes: 8 additions & 6 deletions tpu_inference/models/common/model_loader.py
@@ -5,7 +5,6 @@
import torch
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.ops.mappings import j2t_dtype
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model_loader
@@ -19,14 +18,14 @@
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
apply_qwix_on_abstract_model, apply_qwix_quantization,
load_random_weights_into_qwix_abstract_model)
from tpu_inference.utils import to_jax_dtype, to_torch_dtype

logger = init_logger(__name__)

_MODEL_REGISTRY = {}

# Architectures that prefer "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
# These architectures are listed here because they have better performance with the
# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
# List of architectures that are preferred to use "vllm" implementation over
# "flax_nnx" implementation due to various factors such as performance.
_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
{"GptOssForCausalLM"})

@@ -216,6 +215,9 @@ def get_flax_model(
mesh: Mesh,
is_draft_model: bool = False,
) -> nnx.Module:
model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
vllm_config.model_config.dtype = model_dtype

if is_draft_model:
model_class = _get_model_architecture(
vllm_config.speculative_config.draft_model_config.hf_config)
@@ -323,6 +325,8 @@ def get_vllm_model(
rng: jax.Array,
mesh: Mesh,
):
model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
vllm_config.model_config.dtype = model_dtype
from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper

model = VllmModelWrapper(
@@ -371,8 +375,6 @@ def get_model(
logger.warning(error_msg)

# Fall back to the vLLM model and updating the dtype accordingly
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)
return get_vllm_model(vllm_config, rng, mesh)
case "vllm":
return get_vllm_model(vllm_config, rng, mesh)
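Editor's note: the churn in this file replaces per-call-site torchax conversions (j2t_dtype) with the shared to_jax_dtype / to_torch_dtype helpers from tpu_inference.utils, so both get_flax_model and get_vllm_model normalize vllm_config.model_config.dtype up front. A minimal sketch of what such helpers could look like — the names match the imports above, but the mapping tables and error behavior are assumptions, not the repository's actual implementation:

    import jax.numpy as jnp
    import torch

    # Explicit two-way tables; the real helpers may instead delegate to torchax.
    _TORCH_TO_JAX = {
        torch.bfloat16: jnp.bfloat16,
        torch.float16: jnp.float16,
        torch.float32: jnp.float32,
    }
    _STR_TO_JAX = {
        "bfloat16": jnp.bfloat16,
        "float16": jnp.float16,
        "float": jnp.float32,
        "float32": jnp.float32,
    }
    _JAX_TO_TORCH = {v: k for k, v in _TORCH_TO_JAX.items()}


    def to_jax_dtype(dtype):
        """Normalize a str / torch.dtype / JAX dtype to a JAX dtype."""
        if isinstance(dtype, torch.dtype):
            if dtype not in _TORCH_TO_JAX:
                raise ValueError(f"Unsupported torch dtype: {dtype}")
            return _TORCH_TO_JAX[dtype]
        if isinstance(dtype, str):
            if dtype not in _STR_TO_JAX:
                raise ValueError(f"Unsupported dtype string: {dtype!r}")
            return _STR_TO_JAX[dtype]
        return dtype  # already a JAX dtype; pass through unchanged


    def to_torch_dtype(dtype):
        """Normalize a str / JAX dtype / torch.dtype to a torch.dtype."""
        if isinstance(dtype, torch.dtype):
            return dtype
        jax_dtype = to_jax_dtype(dtype)
        if jax_dtype not in _JAX_TO_TORCH:
            raise ValueError(f"Unsupported dtype: {dtype}")
        return _JAX_TO_TORCH[jax_dtype]
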
24 changes: 1 addition & 23 deletions tpu_inference/platforms/tpu_platform.py
@@ -12,7 +12,6 @@
from tpu_inference import envs
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
from tpu_inference.utils import to_jax_dtype, to_torch_dtype

if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
@@ -145,35 +144,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if compilation_config.backend == "":
compilation_config.backend = "openxla"

# If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
impl = envs.MODEL_IMPL_TYPE

# NOTE(xiang): convert dtype to jnp.dtype
# NOTE(wenlong): skip this logic for mm model preprocessing
# For mm model preprocessors, it may need the output dtype to be torch.
# In order to avoid a PR to vLLM, we postpone the dtype checking during
# tpu_worker initialization
if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
model_dtype = vllm_config.model_config.dtype
try:
dtype = to_jax_dtype(model_dtype)
except ValueError:
logger.warning(f"{model_dtype=} is not supported. "
"Falling back to jnp.bfloat16")
dtype = jnp.bfloat16
if impl == "vllm":
dtype = to_torch_dtype(dtype)
vllm_config.model_config.dtype = dtype

# TODO(cuiq): remove this dependency.
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
if min_page_size > cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
"Increase the page size from %s to %s to avoid SMEM OOM",
cache_config.block_size,
min_page_size,
)
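Editor's note: the block removed here converted model_config.dtype during platform config checking and silently fell back to jnp.bfloat16 when conversion raised ValueError; after this change the conversion happens at model-load time in model_loader.py instead. If a caller still wanted the old lenient behavior, a small wrapper over the helper would suffice — a sketch only, assuming the to_jax_dtype semantics outlined above; not part of this PR:

    import logging

    import jax.numpy as jnp

    logger = logging.getLogger(__name__)


    def to_jax_dtype_or_bf16(model_dtype):
        # Mirror the removed fallback: unsupported dtypes become jnp.bfloat16.
        try:
            return to_jax_dtype(model_dtype)
        except ValueError:
            logger.warning("model dtype %s is not supported, falling back to jnp.bfloat16",
                           model_dtype)
            return jnp.bfloat16
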
6 changes: 2 additions & 4 deletions tpu_inference/runner/tpu_runner.py
@@ -13,7 +13,6 @@
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec
from torchax.ops.mappings import t2j_dtype
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -65,7 +64,7 @@
StructuredDecodingManager
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
from tpu_inference.utils import (device_array, make_optimized_mesh,
time_function, to_torch_dtype)
time_function, to_jax_dtype, to_torch_dtype)

logger = init_logger(__name__)

@@ -1718,8 +1717,7 @@ def _sync_weights(
shard=shard)

def get_intermediate_tensor_spec(self, num_tokens: int):
impl = envs.MODEL_IMPL_TYPE
jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
jax_dtype = to_jax_dtype(self.dtype)
num_padded_tokens = runner_utils.get_padded_token_len(
self.num_tokens_paddings, num_tokens)
sharding = NamedSharding(self.mesh, PartitionSpec())
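Editor's note: the one-line change in get_intermediate_tensor_spec works because to_jax_dtype accepts whatever form self.dtype takes on either backend — already a JAX dtype on the flax_nnx path, a torch.dtype on the vllm path — so the branch on MODEL_IMPL_TYPE is no longer needed. A hedged usage sketch, under the same assumptions as the helper sketch above:

    import jax.numpy as jnp
    import torch

    # All three input forms normalize to the same JAX dtype (assumed behavior).
    assert to_jax_dtype(jnp.bfloat16) == jnp.bfloat16    # flax_nnx path: pass-through
    assert to_jax_dtype(torch.bfloat16) == jnp.bfloat16  # vllm path: converted
    assert to_jax_dtype("bfloat16") == jnp.bfloat16      # raw config string
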
22 changes: 0 additions & 22 deletions tpu_inference/worker/tpu_worker.py
@@ -6,7 +6,6 @@
from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import jaxlib
import jaxtyping
import vllm.envs as vllm_envs
@@ -37,12 +36,6 @@

logger = init_logger(__name__)

_DTYPE: dict[str, jnp.dtype] = {
"bfloat16": jnp.bfloat16,
"float": jnp.float32,
"float32": jnp.float32,
}


@dataclass
class PPConfig:
@@ -77,21 +70,6 @@ def __init__(
ip: str = "localhost",
prev_worker_ip: str = "localhost",
):
# If we use vLLM's model implementation in PyTorch, we should set it
# with torch version of the dtype.
impl = envs.MODEL_IMPL_TYPE
if impl != "vllm": # vllm-pytorch implementation does not need this conversion

# NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
if not isinstance(vllm_config.model_config.dtype, str):
logger.warning(
"The model dtype is not properly set for JAX backend. "
"Overwriting it to jnp.bfloat16")
vllm_config.model_config.dtype = jnp.bfloat16
else:
vllm_config.model_config.dtype = _DTYPE.get(
vllm_config.model_config.dtype, jnp.bfloat16)

self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config