Unify attention arguments more
danieldk committed Oct 11, 2024
1 parent 6c96ef7 commit 8627c16
Showing 3 changed files with 22 additions and 6 deletions.
11 changes: 9 additions & 2 deletions server/text_generation_server/layers/attention/cuda.py
@@ -230,13 +230,16 @@ def attention(
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
softcap: Optional[float] = None,
):
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)

if softcap is None:
softcap = 0.0

return prefill_with_paged_kv_state.get().forward(
query.contiguous(),
causal=causal,
@@ -250,6 +253,10 @@ def attention(
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")

if softcap is None:
softcap = 0.0

return flash_attn_2_cuda.varlen_fwd(
query,
kv_cache.key,
@@ -280,7 +287,7 @@ def attention(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
raise NotImplementedError("softcap is not available in flash attn v1")

# Flash attention v1 requires q, k and v to have the same number of heads
if key.shape[1] != query.shape[1]:
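On the CUDA path, softcap is now accepted as Optional[float] and mapped to 0.0 before the kernel call, since both flashinfer and flash-attn v2 treat 0.0 as "no logit softcapping". A minimal sketch of that normalization pattern, assuming a standalone helper (the function name is illustrative and not part of the repository):

from typing import Optional


def normalize_softcap(softcap: Optional[float]) -> float:
    # flashinfer and flash-attn v2 interpret 0.0 as "softcapping disabled",
    # so None from the unified interface is translated to 0.0 here.
    return 0.0 if softcap is None else softcap


# The keyword can now always be passed, whether softcapping is used or not.
assert normalize_softcap(None) == 0.0
assert normalize_softcap(30.0) == 30.0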
12 changes: 9 additions & 3 deletions server/text_generation_server/layers/attention/ipex.py
@@ -16,11 +16,14 @@ def attention(
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
@@ -66,6 +69,9 @@ def paged_attention(
max_s: int,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
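The IPEX backend has no softcap support, so attention and paged_attention now take the unified keyword but reject a non-None value up front instead of silently ignoring it. A short sketch of that guard pattern, with a hypothetical helper name that is not part of the repository:

from typing import Optional


def reject_softcap(softcap: Optional[float], backend: str = "IPEX") -> None:
    # Fail loudly rather than silently dropping a requested logit softcap.
    if softcap is not None:
        raise NotImplementedError(f"softcap is not available in {backend}")


reject_softcap(None)  # no-op: softcapping was not requested
# reject_softcap(30.0) would raise NotImplementedError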
5 changes: 4 additions & 1 deletion server/text_generation_server/layers/attention/rocm.py
@@ -238,14 +238,17 @@ def attention(
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
softcap: Optional[float] = None,
):
if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")

out = torch.empty_like(query)

if softcap is None:
softcap = 0.0

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
query,
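With softcap declared as Optional[float] = None on the CUDA, IPEX and ROCm implementations alike, calling code can forward the argument unconditionally and rely on each backend to normalize it (None to 0.0) or reject it. The call-site shape below is a simplified stand-in under that assumption, not the repository's actual model code:

from typing import Any, Callable, Optional


def call_attention(
    attention_fn: Callable[..., Any],
    query: Any,
    softcap: Optional[float] = None,
    **kwargs: Any,
) -> Any:
    # The same keyword works on every backend; models that use attention logit
    # softcapping (for example Gemma 2) pass a float, all others pass None.
    return attention_fn(query, softcap=softcap, **kwargs)


if __name__ == "__main__":
    fake_attention = lambda q, softcap=None, **kw: (q, softcap)
    print(call_attention(fake_attention, "q"))                # ('q', None)
    print(call_attention(fake_attention, "q", softcap=30.0))  # ('q', 30.0)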
