diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index 7902d02e306..a590b9b6bfd 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -214,12 +214,13 @@ def _make_attention_mask(
     return mask
 
 
-def use_sdp_causal(head_dim, query_states, logits_soft_cap):
+def use_sdp_causal(head_dim, query_states, logits_soft_cap, attn_type):
     return (
         (logits_soft_cap != 0  # for gemma model
          or head_dim in [-1, 64, 80, 96, 128])  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and attn_type is AttentionType.DECODER
     )
 
 def use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap):
@@ -344,10 +345,9 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert k_scale == 1.0 and v_scale == 1.0
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
                                       "IpexAttnBackendImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
@@ -355,7 +355,8 @@
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        if kv_cache is not None:
+        if kv_cache is not None and attn_type == AttentionType.DECODER:
+            # Only update the KV cache for decoder self-attention.
             if self.using_gqa_kernel:
                 key_cache, value_cache = self.split_kv_cache_ipexllm(
                     kv_cache, self.num_kv_heads, self.head_size)
@@ -401,6 +402,11 @@
             assert query.shape[0] == num_prefill_tokens
             assert decode_query.shape[0] == num_decode_tokens
 
+        # is_causal is True only when no explicit mask is set;
+        # encoder-only (bidirectional) attention always disables it.
+        is_causal = not self.need_mask
+        if attn_type == AttentionType.ENCODER_ONLY:
+            is_causal = False
         if prefill_meta := attn_metadata.prefill_metadata:
             assert prefill_meta.seq_lens is not None
@@ -445,7 +451,7 @@
                     pdropout=0.0,
                     softmax_scale=self.scale,
                     zero_tensors=False,
-                    is_causal=True,
+                    is_causal=is_causal,
                     return_softmax=False,
                     gen_=None,
                     logits_soft_cap=self.logits_soft_cap)
@@ -462,7 +468,7 @@
                 for seq_len, mask in zip(prefill_meta.seq_lens,
                                          prefill_meta.attn_bias):
                     end = start + seq_len
-                    if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap):
+                    if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap, attn_type):
                         import xe_addons
                         if mask is not None:
                             mask = mask.unsqueeze(0)
@@ -490,7 +496,7 @@
                             value[None, :, start:end, :],
                             attn_mask=mask,
                             dropout_p=0.0,
-                            is_causal=not self.need_mask,
+                            is_causal=is_causal,
                             scale=self.scale).squeeze(0).movedim(
                                 query.dim() - 2, 0)
                         output[start:end, :, :] = sub_out
diff --git a/vllm/worker/xpu_pooling_model_runner.py b/vllm/worker/xpu_pooling_model_runner.py
new file mode 100644
index 00000000000..d3427bcdfee
--- /dev/null
+++ b/vllm/worker/xpu_pooling_model_runner.py
@@ -0,0 +1,135 @@
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import torch
+
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.multimodal import MultiModalKwargs
+from vllm.pooling_params import PoolingParams
+from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
+                           SequenceGroupMetadata)
+# from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
+#                                           ModelInputForCPUBuilder)
+from vllm.worker.xpu_model_runner import ModelInputForXPU, XPUModelRunnerBase, ModelInputForXPUBuilder
+
+
+@dataclasses.dataclass(frozen=True)
+class ModelInputForXPUWithPoolingMetadata(ModelInputForXPU):
+    """
+    Used by the XPUPoolingModelRunner.
+    """
+    pooling_metadata: Optional["PoolingMetadata"] = None
+
+
+class XPUPoolingModelRunner(
+        XPUModelRunnerBase[ModelInputForXPUWithPoolingMetadata]):
+    _model_input_cls: Type[ModelInputForXPUWithPoolingMetadata] = (
+        ModelInputForXPUWithPoolingMetadata)
+    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForXPUWithPoolingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
+        if num_steps > 1:
+            raise ValueError(
+                "XPU worker does not support multi-step execution.")
+
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        # Use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # The `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        # TODO: check if we need float16...
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+            for _ in range(num_layers)
+        ]
+
+        model_executable = self.model
+        cross_enc_kwargs = {}
+        # if model_input.token_type_ids is not None:
+        #     cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
+        execute_model_kwargs = {
+            "input_ids":
+            model_input.input_tokens,
+            "positions":
+            model_input.input_positions,
+            "kv_caches":
+            kv_caches,
+            "attn_metadata":
+            model_input.attn_metadata,
+            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
+                                         device=self.device),
+            **cross_enc_kwargs,
+            "intermediate_tensors":
+            intermediate_tensors,
+        }
+
+        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+            hidden_states = model_executable(**execute_model_kwargs)
+
+        # Only perform pooling in the driver worker.
+        if not self.is_driver_worker:
+            return []
+
+        return [
+            self.model.pooler(hidden_states=hidden_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self,
+            tensor_dict: Dict[str,
+                              Any]) -> ModelInputForXPUWithPoolingMetadata:
+        return ModelInputForXPUWithPoolingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        )
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForXPUWithPoolingMetadata:
+        assert seq_group_metadata_list is not None
+        model_input = self._prepare_model_input_tensors(
+            seq_group_metadata_list, finished_requests_ids)
+        # Prepare PoolingMetadata.
+        assert model_input.seq_lens is not None
+        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
+                                                 model_input.seq_lens)
+
+        return dataclasses.replace(model_input,
+                                   virtual_engine=virtual_engine,
+                                   pooling_metadata=pooling_metadata)
+
+    def _prepare_pooling(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        prompt_lens: List[int],
+    ) -> PoolingMetadata:
+        """Prepare PoolingMetadata for the sequence group metadata list."""
+        seq_groups: List[Tuple[List[int], PoolingParams]] = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            pooling_params = seq_group_metadata.pooling_params
+            seq_groups.append((seq_ids, pooling_params))
+
+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
+        pooling_metadata = PoolingMetadata(
+            seq_groups=seq_groups,
+            seq_data=seq_data,
+            prompt_lens=prompt_lens,
+        )
+
+        return pooling_metadata
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index fb7962dfebd..ce4ca0a86c0 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -1,7 +1,7 @@
 """A XPU worker class."""
 import gc
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 # import intel_extension_for_pytorch  # noqa: F401
 # TODO: handle case for oneccl_bindings for dual cards
@@ -18,7 +18,8 @@
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 from vllm.worker.worker_base import WorkerBase
-from vllm.worker.xpu_model_runner import XPUModelRunner
+from vllm.worker.xpu_model_runner import XPUModelRunner, XPUModelRunnerBase
+from vllm.worker.xpu_pooling_model_runner import XPUPoolingModelRunner
 
 logger = init_logger(__name__)
 
@@ -55,8 +56,12 @@ def __init__(
         if parallel_config and is_driver_worker:
             assert rank % parallel_config.tensor_parallel_size == 0, \
                 "Driver worker should be rank 0 of tensor parallel group."
+        ModelRunnerClass: Type[XPUModelRunnerBase] = XPUModelRunner
+        model_config = self.model_config
+        if model_config.task == "embed":
+            ModelRunnerClass = XPUPoolingModelRunner
 
-        self.model_runner = XPUModelRunner(  # type: ignore
+        self.model_runner = ModelRunnerClass(  # type: ignore
             vllm_config=vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
@@ -64,7 +69,7 @@
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
-        self.gpu_cache: Optional[List[List[torch.Tensor]]]
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
     def init_device(self) -> None:
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
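Background note on the ipex_attn.py change (illustrative, not part of the patch): encoder-only models need bidirectional attention, so every query position must attend to the whole sequence. That is why is_causal is forced to False for AttentionType.ENCODER_ONLY and why use_sdp_causal() now only fires for decoder attention. A minimal sketch with plain PyTorch SDPA (not the IPEX/xe_addons kernels used above):

import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # decoder-style
bidir = F.scaled_dot_product_attention(q, k, v, is_causal=False)   # encoder-only style

# Only the last query position sees the full context in the causal case,
# so the two results differ almost everywhere.
print(torch.allclose(causal, bidir))  # False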
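Usage sketch (illustrative, not part of the patch): with task="embed", XPUWorker now selects XPUPoolingModelRunner, whose execute_model() returns PoolerOutput from model.pooler() instead of sampled tokens. The model name below is an arbitrary example of a pooling model, and the exact API surface (llm.embed() vs. the older llm.encode(), output field names) depends on the vLLM version this fork tracks:

from vllm import LLM

# task="embed" routes XPUWorker to XPUPoolingModelRunner (see xpu_worker.py above).
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed", device="xpu")

# One pooled embedding vector is produced per prompt.
outputs = llm.embed(["The quick brown fox jumps over the lazy dog."])
for out in outputs:
    print(len(out.outputs.embedding))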