diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index fb035bde1..082bc3f9d 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -5,7 +5,6 @@
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader import get_model_loader
@@ -19,14 +18,14 @@ from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
     load_random_weights_into_qwix_abstract_model)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
-# Architectures that prefer "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
-# These architectures are listed here because they have better performance with the
-# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
+# List of architectures that prefer the "vllm" implementation over the
+# "flax_nnx" implementation due to factors such as performance.
 _VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
     {"GptOssForCausalLM"})
 
@@ -216,6 +215,9 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
@@ -323,6 +325,8 @@ def get_vllm_model(
     rng: jax.Array,
     mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -371,8 +375,6 @@ def get_model(
             logger.warning(error_msg)
 
             # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
             return get_vllm_model(vllm_config, rng, mesh)
         case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 1240039bf..0f2e0d2e3 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -12,7 +12,6 @@
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import AttentionBackendEnum
@@ -145,26 +144,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during
-        # tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            model_dtype = vllm_config.model_config.dtype
-            try:
-                dtype = to_jax_dtype(model_dtype)
-            except ValueError:
-                logger.warning(f"{model_dtype=} is not supported. "
-                               "Falling back to jnp.bfloat16")
-                dtype = jnp.bfloat16
-            if impl == "vllm":
-                dtype = to_torch_dtype(dtype)
-            vllm_config.model_config.dtype = dtype
-
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
         cache_config.block_size = PallasAttentionBackend.get_page_size(
@@ -172,8 +151,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
         if min_page_size > cache_config.block_size:
             logger.warning(
-                "Increase the page size from %s to %s to make sure there's"
-                "no SMEM OOM",
+                "Increase the page size from %s to %s to avoid SMEM OOM",
                 cache_config.block_size,
                 min_page_size,
             )
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index b2fd73ae6..863102ce5 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -13,7 +13,6 @@
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import t2j_dtype
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -65,7 +64,7 @@
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function, to_torch_dtype)
+                                 time_function, to_jax_dtype, to_torch_dtype)
 
 logger = init_logger(__name__)
 
@@ -1718,8 +1717,7 @@ def _sync_weights(
                 shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = envs.MODEL_IMPL_TYPE
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+        jax_dtype = to_jax_dtype(self.dtype)
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
         sharding = NamedSharding(self.mesh, PartitionSpec())
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index 9b52d43e4..092dc26f0 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -6,7 +6,6 @@
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
@@ -37,12 +36,6 @@
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 @dataclass
 class PPConfig:
@@ -77,21 +70,6 @@ def __init__(
         ip: str = "localhost",
         prev_worker_ip: str = "localhost",
     ):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
" - "Overwriting it to jnp.bfloat16") - vllm_config.model_config.dtype = jnp.bfloat16 - else: - vllm_config.model_config.dtype = _DTYPE.get( - vllm_config.model_config.dtype, jnp.bfloat16) - self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config