Dev/split #1 (Draft)

wants to merge 19 commits into base: mlperf_features

Changes from all commits
47 changes: 33 additions & 14 deletions benchmarks/benchmark_throughput.py
@@ -6,6 +6,7 @@
import time
from typing import List, Optional

import pandas as pd
import torch
import uvloop
from PIL import Image
@@ -150,9 +151,10 @@ def run_vllm(
use_beam_search = False

if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
for _ in range(2):
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
@@ -202,16 +204,26 @@ async def run_vllm_async(
max_tokens=request.expected_output_len,
))

generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start
for _ in range(2):
generators = []
start_time = []
latencies = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
start_time.append(time.perf_counter())
latencies.append([])
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
lat = time.perf_counter() - start_time[i]
latencies[i].append(lat)
end = time.perf_counter()
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
for lat in latencies])
return end - start, (first_latency, next_latency)


def run_hf(
@@ -335,7 +347,7 @@ def main(args: argparse.Namespace):
for request in requests)
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
elapsed_time, (first_latency, next_latency) = uvloop.run(
run_vllm_async(
requests,
args.n,
@@ -345,6 +357,7 @@
else:
elapsed_time = run_vllm(requests, args.n,
EngineArgs.from_cli_args(args))
first_latency, next_latency = None, None
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -366,6 +379,12 @@
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
if first_latency is not None:
latency_breakdown = "\nFirst token latency(msecs):\n"
latency_breakdown += f"{first_latency.describe()}"
latency_breakdown += "\nNext token latency(msecs):\n"
latency_breakdown += f"{next_latency.describe()}"
print(f"{latency_breakdown if first_latency is not None else ''}")

# Output JSON results if specified
if args.output_json:
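Note on the latency measurement above: the benchmark loop runs generation twice, so the reported elapsed time and latencies come from the second pass (presumably after warm-up), and each request records the completion time of every streamed output relative to its own submission. A minimal standalone sketch of the same post-processing, assuming that per-request timestamp structure; the max(len(lat) - 1, 1) guard for single-output requests is an addition here, not part of the diff:

import pandas as pd

def latency_breakdown(latencies):
    # latencies[i] holds the per-output completion times (seconds) of request i,
    # measured from that request's submission.
    first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
    next_latency = pd.Series([(lat[-1] - lat[0]) / max(len(lat) - 1, 1) * 1000
                              for lat in latencies])
    return first_latency.describe(), next_latency.describe()

# Example: two requests with three streamed outputs each.
first, nxt = latency_breakdown([[0.12, 0.15, 0.18], [0.30, 0.36, 0.42]])
print(first)
print(nxt)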
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1fd1fcf
18 changes: 14 additions & 4 deletions vllm/attention/backends/hpu_attn.py
@@ -174,6 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

@@ -208,8 +209,18 @@ def forward(
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
block_indices = kwargs.get('block_indices', None)
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
if block_indices is None:
block_indices = attn_metadata.block_indices
if block_offsets is None:
block_offsets = attn_metadata.block_offsets
if seq_lens_tensor is None:
seq_lens_tensor = attn_metadata.seq_lens_tensor
if attn_bias is None: # This is the case for prompt run
attn_bias = attn_metadata.attn_bias
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
@@ -235,7 +246,6 @@
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
@@ -256,7 +266,7 @@
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
valid_seq_lengths=seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
else:
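The HPU attention forward above now accepts **kwargs and prefers explicitly passed block_indices, block_offsets, seq_lens_tensor and attn_bias over the corresponding attn_metadata fields. A hypothetical helper (not part of the diff) that expresses the same fallback once instead of four times:

def _kwarg_or_metadata(kwargs, attn_metadata, name):
    # Prefer a value supplied by the caller; otherwise fall back to attn_metadata.
    value = kwargs.get(name, None)
    return value if value is not None else getattr(attn_metadata, name)

# block_indices = _kwarg_or_metadata(kwargs, attn_metadata, 'block_indices')
# seq_lens_tensor = _kwarg_or_metadata(kwargs, attn_metadata, 'seq_lens_tensor')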
39 changes: 10 additions & 29 deletions vllm/attention/layer.py
@@ -128,37 +128,18 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
**kwargs,
) -> torch.Tensor:

if self.use_direct_call:
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, attn_type,
self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type,
**kwargs)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
7 changes: 6 additions & 1 deletion vllm/config.py
@@ -737,6 +737,8 @@ class CacheConfig:
prefix caching enabled.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
split_qk_v: Whether to split qk and v calculations.
split_gate_up: Whether to split gate and up calculations.
"""

def __init__(
Expand All @@ -750,6 +752,8 @@ def __init__(
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
split_qk_v: bool = False,
split_gate_up: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@@ -760,7 +764,8 @@ def __init__(
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb

self.split_qk_v = split_qk_v
self.split_gate_up = split_gate_up
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
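The two new CacheConfig options are plumbed in from EngineArgs (see vllm/engine/arg_utils.py below). A sketch of constructing the config directly, assuming the remaining constructor arguments keep their existing names; the concrete values are placeholders:

from vllm.config import CacheConfig

cache_config = CacheConfig(block_size=128,
                           gpu_memory_utilization=0.9,
                           swap_space=4,
                           cache_dtype="auto",
                           is_attention_free=False,
                           split_qk_v=True,
                           split_gate_up=True)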
12 changes: 12 additions & 0 deletions vllm/distributed/utils.py
@@ -57,6 +57,18 @@ def split_tensor_along_last_dim(

return tensor_list

def split_tensor_along_x_dim(
tensor: torch.Tensor,
dim: int,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
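    """Split a tensor along dimension ``dim`` into ``num_partitions`` equal chunks."""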
dim_size = divide(tensor.size()[dim], num_partitions)
tensor_list = torch.split(tensor, dim_size, dim=dim)
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
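A quick usage sketch for the new helper; divide() is expected to assert that the chosen dimension is evenly divisible by num_partitions:

import torch
from vllm.distributed.utils import split_tensor_along_x_dim

x = torch.randn(2, 1024, 4096)  # [batch_size, seq_len, hidden_size]
chunks = split_tensor_along_x_dim(x, dim=1, num_partitions=2)
assert len(chunks) == 2 and chunks[0].shape == (2, 512, 4096)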
14 changes: 13 additions & 1 deletion vllm/engine/arg_utils.py
@@ -123,6 +123,8 @@ class EngineArgs:
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
split_qk_v: bool = False
split_gate_up: bool = False
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_num_prefill_seqs: Optional[int] = None
@@ -501,7 +503,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
'of GPU blocks. Used for testing preemption.')
parser.add_argument('--split-qk-v',
action='store_true',
default=EngineArgs.split_qk_v,
help='Whether to separate qk and v calculations.')
parser.add_argument('--split-gate-up',
action='store_true',
default=EngineArgs.split_gate_up,
help='Whether to separate gate and up calculations.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
@@ -1050,6 +1060,8 @@ def create_engine_config(self,
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
split_qk_v=self.split_qk_v,
split_gate_up=self.split_gate_up,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
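With the parser additions above, the new behavior can be enabled from the command line or programmatically; the model name below is only a placeholder:

# CLI: python benchmarks/benchmark_throughput.py --backend vllm --split-qk-v --split-gate-up ...
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="<your-model>",
                         split_qk_v=True,
                         split_gate_up=True)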
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -327,6 +327,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
"split_qk_v":
self.cache_config.split_qk_v,
"split_gate_up":
self.cache_config.split_gate_up,
})

if self.tokenizer:
77 changes: 71 additions & 6 deletions vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
split_tensor_along_x_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -996,13 +997,21 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
do_split: bool = False,  # should be enabled for down_proj, disabled for o_proj
split_threshold: int = 128,
split_size: int = 2):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
self.collective_func = tensor_model_parallel_all_reduce
self.do_split = do_split
self.split_threshold = split_threshold
self.split_size = split_size
self.prefix = prefix
self.skip_seq_split = False

# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
@@ -1099,13 +1108,69 @@ def forward(self, input_):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,

# input_parallel shape: [batch_size, seq_len, hidden_size // tp_size]

# split v2:
# strategy: split the input tensor along dim 1 (the sequence-length dim), but only when the
# sequence length exceeds a threshold; otherwise do not split. In particular, the decode
# phase (seq_len == 1) is never split.
# why split on dim 1:
# dim 0 is the batch size; when batch_size == 1 it cannot be split anyway.
# dim 2 (hidden_size) is already split by tensor parallelism; splitting it again would change much more.

_, seq_len, _ = input_parallel.shape
shape_total = input_parallel.shape[0] * input_parallel.shape[1] * input_parallel.shape[2]
do_split = self.do_split and seq_len > 1  # never split during decode
# NOTE: splitting a tensor that is too small does not help performance.
# The 1 * 1024 * 8192 * 3 threshold corresponds to [batch_size, seq_len, hidden_size * 3].
do_split = do_split and shape_total > 1 * 1024 * 8192 * 3 and not self.skip_seq_split

if do_split:
input_parallels = split_tensor_along_x_dim(input_parallel, 1, self.split_size)
output_parallels = []
for input_parallel in input_parallels:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_parallels.append(output)
output = torch.cat(output_parallels, dim=1)

else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel


# split v1:
# why split on dim 0:
# dim 1 (seq_len): during decode seq_len is always 1, so we cannot split on this dim.
# dim 2 (hidden_size): tensor parallelism already splits this dim; splitting it again would change much more.
# Other limitations & FIXME: VLLM_DECODE_BS_BUCKET_MIN=2 and VLLM_PROMPT_BS_BUCKET_MIN=2 must be set,
# otherwise the batch cannot be divided and split.
# Overheads:
# 1. split overhead.
# 2. append may add some overhead; it is unclear whether the output tensor needs to be ready at that point.
# 3. tensor cat overhead. Some optimization is possible here, but some copy will likely remain.
# split = 2
# input_parallels = split_tensor_along_x_dim(input_parallel, 0, split)
# output_parallels = []
# for input_parallel in input_parallels:
# output_parallel = self.quant_method.apply(self,
# input_parallel,
# bias=bias_)
# if self.reduce_results and self.tp_size > 1:
# output = tensor_model_parallel_all_reduce(output_parallel)
# else:
# output = output_parallel
# output_parallels.append(output)
# output = torch.cat(output_parallels, dim=0)

output_bias = self.bias if self.skip_bias_add else None

Expand Down
Loading
Loading
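Finally, a standalone illustration of the "split v2" path in RowParallelLinear.forward above, with a plain torch.nn.Linear standing in for quant_method.apply plus the tensor-parallel all-reduce; the divisibility check is an extra safety guard, not part of the diff:

import torch

def seq_split_forward(x: torch.Tensor,
                      linear: torch.nn.Linear,
                      split_size: int = 2) -> torch.Tensor:
    # x: [batch_size, seq_len, hidden]; decode (seq_len == 1) and non-divisible
    # sequence lengths are not split.
    _, seq_len, _ = x.shape
    if seq_len <= 1 or seq_len % split_size != 0:
        return linear(x)
    chunks = torch.split(x, seq_len // split_size, dim=1)
    return torch.cat([linear(chunk) for chunk in chunks], dim=1)

# The chunked result matches the unsplit matmul.
layer = torch.nn.Linear(64, 32)
inp = torch.randn(2, 8, 64)
assert torch.allclose(seq_split_forward(inp, layer), layer(inp), atol=1e-6)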