Simplify the attention function
- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KV_CACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).
danieldk committed Oct 8, 2024
1 parent 0da4df4 commit c56df2d
Showing 21 changed files with 243 additions and 302 deletions.
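
Before the per-file diffs, a minimal self-contained sketch of the call shape the commit message describes. This is illustrative only, not TGI's actual implementation: the shapes and the USE_PAGED_CACHE flag are made up, the flag standing in for the real backend selection. The point is that the function is keyword-only and receives both the raw key/value tensors and the KVCache, so the backend decides which pair to attend over instead of the caller consulting PREFILL_IN_KV_CACHE.

from dataclasses import dataclass
import torch

@dataclass
class KVCache:
    key: torch.Tensor    # (total_slots, kv_heads, head_dim) in this toy
    value: torch.Tensor

USE_PAGED_CACHE = True   # stands in for "a backend that can attend over the paged cache"

def attention(*, query, key, value, kv_cache, softmax_scale):
    # Keyword-only (`*`) so the various Tensor arguments cannot be swapped by position.
    # Backends with block-table support read K/V from the cache; a flash-attention v1
    # style fallback uses the raw key/value tensors instead.
    k, v = (kv_cache.key, kv_cache.value) if USE_PAGED_CACHE else (key, value)
    scores = torch.einsum("qhd,khd->hqk", query, k) * softmax_scale
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

q = torch.randn(4, 2, 8)            # (tokens, heads, head_dim)
k = v = torch.randn(4, 2, 8)
cache = KVCache(key=k, value=v)     # in this toy the cache already holds the batch's K/V
out = attention(query=q, key=k, value=v, kv_cache=cache, softmax_scale=8 ** -0.5)
# attention(q, k, v, cache, 0.35) would raise TypeError: every argument is keyword-only.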
4 changes: 0 additions & 4 deletions server/text_generation_server/layers/attention/__init__.py
@@ -8,23 +8,20 @@
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
@@ -40,7 +37,6 @@
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
126 changes: 49 additions & 77 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,4 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@@ -38,8 +39,7 @@ def reshape_and_cache(

def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@@ -80,7 +80,7 @@ def paged_attention(

return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
@@ -98,8 +98,8 @@ def paged_attention(
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@@ -135,8 +135,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -168,8 +168,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -218,52 +218,42 @@ def paged_attention(

SUPPORTS_WINDOWING = V2

if ATTENTION == "flashinfer":

def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)

return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
query.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)

elif V2:

def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
elif V2:
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
kv_cache.key,
kv_cache.value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@@ -284,19 +274,7 @@ def attention(
None,
)[0]

else:

def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
else:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
@@ -305,36 +283,36 @@ def attention(
raise NotImplementedError("softcap is only available with flash attn v2")

# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
if key.shape[1] != query.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
if key.shape[1] == 1:
key = key.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
original_shape = key.shape
key = (
key.unsqueeze(2)
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
if value.shape[1] != query.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
if value.shape[1] == 1:
value = value.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
original_shape = value.shape
value = (
value.unsqueeze(2)
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)

out = torch.empty_like(q)
out = torch.empty_like(query)
flash_attn_cuda.fwd(
q,
k,
v,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
Expand All @@ -351,13 +329,7 @@ def attention(
return out


# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
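
The flash-attention v1 branch above still needs key/value to have as many heads as query, hence the MQA/GQA expansion kept inside the unified function. A self-contained toy of that expansion, with made-up shapes (5 tokens, 8 query heads, 2 KV heads):

import torch

query = torch.randn(5, 8, 4)   # (tokens, query_heads, head_dim)
key = torch.randn(5, 2, 4)     # (tokens, kv_heads, head_dim)

if key.shape[1] != query.shape[1]:
    if key.shape[1] == 1:
        # MQA: one shared KV head is broadcast across all query heads.
        key = key.expand(-1, query.shape[1], -1)
    else:
        # GQA: each KV head is repeated for the query heads in its group.
        original_shape = key.shape
        key = (
            key.unsqueeze(2)                                     # (5, 2, 1, 4)
            .expand(-1, -1, query.shape[1] // key.shape[1], -1)  # (5, 2, 4, 4)
            .reshape(original_shape[0], -1, original_shape[2])   # (5, 8, 4)
        )

assert key.shape == (5, 8, 4)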
26 changes: 13 additions & 13 deletions server/text_generation_server/layers/attention/ipex.py
@@ -1,31 +1,33 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
out = torch.empty_like(query)

# We do not need to check window_size_left here (not supported); it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -56,8 +58,7 @@ def reshape_and_cache(

def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@@ -69,8 +70,8 @@ def paged_attention(
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@@ -83,7 +84,6 @@ def paged_attention(


__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
9 changes: 4 additions & 5 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -3,7 +3,6 @@
import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache


class KVCache:
@@ -24,10 +23,8 @@ def __init__(
):
"""Construct the key-value cache for a layer."""

if (
dtype == torch.float8_e5m2
and (ATTENTION != "flashinfer"
or SYSTEM != "cuda")
if dtype == torch.float8_e5m2 and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
@@ -118,4 +115,6 @@ def store(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
from text_generation_server.layers.attention import reshape_and_cache

reshape_and_cache(key, value, key_cache, value_cache, slots)
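
Two details worth noting here: the `reshape_and_cache` import moves from module level into the branch of `store` that uses it, presumably to break the import cycle with `layers.attention.__init__` (which imports `KVCache` from this module), and the other branch writes each token's key/value into the cache by flat slot index. A self-contained illustration of that slot write, with a made-up cache layout:

import torch

key_cache = torch.zeros(4, 2, 2, 3)   # made-up: (num_blocks, block_size, kv_heads, head_dim)
shape = key_cache.shape
key = torch.randn(3, 2, 3)            # three new tokens to store
slots = torch.tensor([0, 5, 6])       # flat slot index = block * block_size + offset

# Viewing the cache as (num_blocks * block_size, kv_heads, head_dim) lets every token
# be written to its assigned slot in a single indexed assignment.
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
assert torch.equal(key_cache.view(-1, shape[-2], shape[-1])[slots], key)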