Simplify the attention function
- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the `PREFILL_IN_KV_CACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).
danieldk committed Oct 7, 2024
1 parent 0da4df4 commit 9581f20
Showing 21 changed files with 216 additions and 246 deletions.
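To make the commit message concrete, here is a minimal, self-contained sketch of the kwargs-only entry point the diff below introduces. The body is a placeholder and the `ATTENTION` flag, tensor shapes, and scale are illustrative assumptions; only the keyword-only style and the argument names (a subset of the real signature) are taken from the diff.

import torch

ATTENTION = "flashinfer"  # illustrative; the real value is picked at import time


def attention(
    *,  # keyword-only, so the many Tensor arguments cannot be mixed up
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    softmax_scale: float,
    causal: bool = True,
) -> torch.Tensor:
    # One definition for every backend: the backend choice is an internal
    # branch rather than a separate module-level function per backend.
    if ATTENTION == "flashinfer":
        pass  # would call prefill_with_paged_kv_state.get().forward(...)
    else:
        pass  # would call the flash-attention v2/v1 kernels
    return torch.empty_like(query)  # placeholder output for this sketch


# Callers always pass both the fresh key/value tensors and the cache tensors,
# so they no longer consult a PREFILL_IN_KV_CACHE constant to choose between them.
q, k, v = (torch.randn(4, 8, 64) for _ in range(3))
out = attention(
    query=q, key=k, value=v,
    key_cache=k, value_cache=v,
    softmax_scale=0.125,
)
print(out.shape)  # torch.Size([4, 8, 64])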
4 changes: 0 additions & 4 deletions server/text_generation_server/layers/attention/__init__.py
@@ -8,23 +8,20 @@
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
@@ -40,7 +37,6 @@
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
103 changes: 38 additions & 65 deletions server/text_generation_server/layers/attention/cuda.py
@@ -218,50 +218,41 @@ def paged_attention(

SUPPORTS_WINDOWING = V2

if ATTENTION == "flashinfer":

def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)

return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
query.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)

elif V2:

def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
elif V2:
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
query,
key_cache,
value_cache,
out,
@@ -284,19 +275,7 @@ def attention(
None,
)[0]

else:

def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
else:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
@@ -305,36 +284,36 @@ def attention(
raise NotImplementedError("softcap is only available with flash attn v2")

# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
if key.shape[1] != query.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
if key.shape[1] == 1:
key = key.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
original_shape = key.shape
key = (
key.unsqueeze(2)
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
if value.shape[1] != query.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
if value.shape[1] == 1:
value = value.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
original_shape = value.shape
value = (
value.unsqueeze(2)
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)

out = torch.empty_like(q)
out = torch.empty_like(query)
flash_attn_cuda.fwd(
q,
k,
v,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -351,13 +330,7 @@ def attention(
return out


# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
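As context for the cuda.py changes above: the flash-attention v1 fallback expands the key/value heads so they match the query head count. Below is a small standalone illustration of that grouped-attention reshape; the token count, head counts, and head dimension are assumptions made up for the example.

import torch

# Grouped-query attention layout: fewer key/value heads than query heads.
tokens, q_heads, kv_heads, head_dim = 5, 8, 2, 64
query = torch.randn(tokens, q_heads, head_dim)
key = torch.randn(tokens, kv_heads, head_dim)

# The same reshape the v1 path performs: repeat each key/value head
# q_heads // kv_heads times so `key` ends up with one head per query head.
original_shape = key.shape
key = (
    key.unsqueeze(2)
    .expand(-1, -1, q_heads // kv_heads, -1)
    .reshape(original_shape[0], -1, original_shape[2])
)
assert key.shape == query.shape  # both are (5, 8, 64)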
15 changes: 8 additions & 7 deletions server/text_generation_server/layers/attention/ipex.py
@@ -5,11 +5,13 @@
from typing import Optional

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(
q: torch.Tensor,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
@@ -19,13 +21,13 @@ def attention(
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -83,7 +85,6 @@ def paged_attention(


__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
6 changes: 2 additions & 4 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -24,10 +24,8 @@ def __init__(
):
"""Construct the key-value cache for a layer."""

if (
dtype == torch.float8_e5m2
and (ATTENTION != "flashinfer"
or SYSTEM != "cuda")
if dtype == torch.float8_e5m2 and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
71 changes: 31 additions & 40 deletions server/text_generation_server/layers/attention/rocm.py
@@ -16,8 +16,6 @@
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

PREFILL_IN_KV_CACHE = False

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
@@ -227,29 +225,33 @@ def paged_attention(


SUPPORTS_WINDOWING = False
if ENGINE == "ck":

def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):


def attention(
*,
query,
key: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
value: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")

out = torch.empty_like(q)
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -270,30 +272,19 @@ def attention(
None,
)[0]

elif ENGINE == "triton":
from .flash_attn_triton import triton_attention

def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention

if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")

out = torch.empty_like(q)
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@@ -304,11 +295,11 @@
)
return output

else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")


__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
@@ -38,7 +38,6 @@
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@@ -296,12 +295,14 @@ def forward(
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key,
value=value,
key_cache=kv_cache.key,
value_cache=kv_cache.value,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else: