Add embedding models #103

Open · wants to merge 1 commit into base: 0.6.6-b17
24 changes: 15 additions & 9 deletions vllm/attention/backends/ipex_attn.py
@@ -214,12 +214,13 @@ def _make_attention_mask(
     return mask


-def use_sdp_causal(head_dim, query_states, logits_soft_cap):
+def use_sdp_causal(head_dim, query_states, logits_soft_cap, attn_type):
     return (
         (logits_soft_cap != 0  # for gemma model
          or head_dim in [-1, 64, 80, 96, 128])  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and attn_type is AttentionType.DECODER
     )

 def use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap):
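Note on the gate above: the extra `attn_type is AttentionType.DECODER` clause keeps the fused `xe_addons` causal-SDP path restricted to decoder self-attention; encoder-only (embedding) models fall through to the regular SDPA branch further down, where `is_causal` is forced to `False`. A minimal standalone sketch of the same gating logic, with a stand-in enum and a toy CPU tensor (not the vLLM classes):

```python
# Sketch only: mirrors the gate above with a stand-in AttentionType enum.
import enum

import torch


class AttentionType(enum.Enum):
    DECODER = "decoder"
    ENCODER_ONLY = "encoder_only"


def use_sdp_causal(head_dim, query_states, logits_soft_cap, attn_type):
    return ((logits_soft_cap != 0            # for gemma models
             or head_dim in [-1, 64, 80, 96, 128])
            and query_states.device.type == "xpu"
            and query_states.dtype in [torch.float, torch.half]
            and attn_type is AttentionType.DECODER)  # new: decoder-only


q = torch.randn(4, 8, 64, dtype=torch.half)  # CPU tensor, so the gate is off
print(use_sdp_causal(64, q, 0.0, AttentionType.ENCODER_ONLY))  # False
print(use_sdp_causal(64, q, 0.0, AttentionType.DECODER))       # still False on CPU
```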
@@ -344,18 +345,18 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert k_scale == 1.0 and v_scale == 1.0
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
                                       "IpexAttnBackendImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache is not None and attn_type == AttentionType.DECODER:
+            # Only update the kv_cache with decoder architecture...
             if self.using_gqa_kernel:
                 key_cache, value_cache = self.split_kv_cache_ipexllm(
                     kv_cache, self.num_kv_heads, self.head_size)
@@ -401,6 +402,11 @@ def forward(

         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
+        # If mask is not set, then is_causal=True
+        # If mask is set, then is_causal=False
+        is_causal = not self.need_mask
+        if attn_type == AttentionType.ENCODER_ONLY:
+            is_causal = False

         if prefill_meta := attn_metadata.prefill_metadata:
             assert prefill_meta.seq_lens is not None
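The new `is_causal` resolution above reads as a small truth table: an explicit attention bias (ALiBi / custom mask) already encodes the masking, and encoder-only models attend bidirectionally, so only a plain decoder prefill stays causal. A hedged sketch of the same rule as a standalone helper (names are illustrative, not part of the backend):

```python
# Illustrative helper mirroring the logic added above (not vLLM API).
def resolve_is_causal(need_mask: bool, encoder_only: bool) -> bool:
    """Return True only when the kernel itself should apply causal masking."""
    if need_mask:        # an explicit attn_bias is supplied (e.g. ALiBi)
        return False
    if encoder_only:     # embedding models attend bidirectionally
        return False
    return True          # plain decoder prefill


assert resolve_is_causal(need_mask=False, encoder_only=False) is True
assert resolve_is_causal(need_mask=True, encoder_only=False) is False
assert resolve_is_causal(need_mask=False, encoder_only=True) is False
```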
@@ -445,7 +451,7 @@ def forward(
                     pdropout=0.0,
                     softmax_scale=self.scale,
                     zero_tensors=False,
-                    is_causal=True,
+                    is_causal=is_causal,
                     return_softmax=False,
                     gen_=None,
                     logits_soft_cap=self.logits_soft_cap)
@@ -462,7 +468,7 @@ def forward(
                 for seq_len, mask in zip(prefill_meta.seq_lens,
                                          prefill_meta.attn_bias):
                     end = start + seq_len
-                    if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap):
+                    if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap, attn_type):
                         import xe_addons
                         if mask is not None:
                             mask = mask.unsqueeze(0)
@@ -490,7 +496,7 @@ def forward(
                             value[None, :, start:end, :],
                             attn_mask=mask,
                             dropout_p=0.0,
-                            is_causal=not self.need_mask,
+                            is_causal=is_causal,
                             scale=self.scale).squeeze(0).movedim(
                                 query.dim() - 2, 0)
                     output[start:end, :, :] = sub_out
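For context on why `is_causal` must be `False` for embedding models: under a causal mask, earlier tokens cannot attend to later ones, so the pooled representation would ignore most of the right-hand context. The CPU-runnable sketch below (plain PyTorch, not the IPEX kernels) shows the two settings producing different outputs for the same Q/K/V:

```python
# Demonstration only: causal vs. bidirectional scaled_dot_product_attention.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)

causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# The first token only "sees" itself under causal masking, so its output
# differs from the bidirectional case; embedding models need the latter.
print(torch.allclose(causal, bidirectional))             # False
print((causal[..., 0, :] - bidirectional[..., 0, :]).abs().max())
```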
135 changes: 135 additions & 0 deletions vllm/worker/xpu_pooling_model_runner.py
@@ -0,0 +1,135 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
# from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
#                                           ModelInputForCPUBuilder)
from vllm.worker.xpu_model_runner import ModelInputForXPU, XPUModelRunnerBase, ModelInputForXPUBuilder


@dataclasses.dataclass(frozen=True)
class ModelInputForXPUWithPoolingMetadata(ModelInputForXPU):
    """
    Used by the XPUPoolingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


class XPUPoolingModelRunner(
        XPUModelRunnerBase[ModelInputForXPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForXPUWithPoolingMetadata] = (
        ModelInputForXPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError(
                "XPUPoolingModelRunner does not support multi-step execution.")

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # TODO: check if we need float16...
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        model_executable = self.model
        cross_enc_kwargs = {}
        # if model_input.token_type_ids is not None:
        #     cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            **cross_enc_kwargs,
            "intermediate_tensors":
            intermediate_tensors,
        }

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithPoolingMetadata:
        return ModelInputForXPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   virtual_engine=virtual_engine,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata
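End-to-end, this runner is selected when the engine is configured with the embedding task (see the xpu_worker.py change below, which checks `model_config.task == "embed"`). A hedged usage sketch, assuming an XPU-enabled build of this vLLM fork; the model name, `device` argument, and output shape are illustrative:

```python
# Usage sketch (assumes an XPU-enabled vLLM build; model name is illustrative).
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed", device="xpu")

outputs = llm.encode(["The quick brown fox", "jumps over the lazy dog"])
for out in outputs:
    print(len(out.outputs.embedding))  # embedding dimension, e.g. 768
```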
13 changes: 9 additions & 4 deletions vllm/worker/xpu_worker.py
@@ -1,7 +1,7 @@
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type

# import intel_extension_for_pytorch # noqa: F401
# TODO: handle case for oneccl_bindings for dual cards
@@ -18,7 +18,8 @@
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 from vllm.worker.worker_base import WorkerBase
-from vllm.worker.xpu_model_runner import XPUModelRunner
+from vllm.worker.xpu_model_runner import XPUModelRunner, XPUModelRunnerBase
+from vllm.worker.xpu_pooling_model_runner import XPUPoolingModelRunner

 logger = init_logger(__name__)

@@ -55,16 +56,20 @@ def __init__(
         if parallel_config and is_driver_worker:
             assert rank % parallel_config.tensor_parallel_size == 0, \
                 "Driver worker should be rank 0 of tensor parallel group."
+        ModelRunnerClass: Type[XPUModelRunnerBase] = XPUModelRunner
+        model_config = self.model_config
+        if model_config.task == "embed":
+            ModelRunnerClass = XPUPoolingModelRunner

-        self.model_runner = XPUModelRunner(  # type: ignore
+        self.model_runner = ModelRunnerClass(  # type: ignore
             vllm_config=vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
-        self.gpu_cache: Optional[List[List[torch.Tensor]]]
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

     def init_device(self) -> None:
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
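A note on the dispatch above: the runner class is chosen once at worker construction, keyed on `model_config.task`, defaulting to the decoder runner. If more pooling-style tasks are added later, the same selection could be expressed as a lookup table; a hypothetical, self-contained sketch (stub classes stand in for the real runner classes, not part of this PR):

```python
# Hypothetical sketch of the task -> runner dispatch pattern used above.
from typing import Dict, Type


class XPUModelRunnerBase:                 # stub for vllm.worker.xpu_model_runner
    pass


class XPUModelRunner(XPUModelRunnerBase):         # decoder / generation runner
    pass


class XPUPoolingModelRunner(XPUModelRunnerBase):  # embedding / pooling runner
    pass


_RUNNER_BY_TASK: Dict[str, Type[XPUModelRunnerBase]] = {
    "embed": XPUPoolingModelRunner,
}


def select_runner(task: str) -> Type[XPUModelRunnerBase]:
    # Default to the decoder runner, as the __init__ above does.
    return _RUNNER_BY_TASK.get(task, XPUModelRunner)


assert select_runner("embed") is XPUPoolingModelRunner
assert select_runner("generate") is XPUModelRunner
```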