tpu_inference/kernels/ragged_paged_attention/v3/kernel.py: 58 changes (52 additions, 6 deletions)
@@ -35,6 +35,7 @@ def ref_ragged_paged_attention(
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
distribution: jax.Array, # i32[3]
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
@@ -56,6 +57,7 @@ def ref_ragged_paged_attention(
page_indices,
cu_q_lens,
distribution,
attention_sink,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
@@ -143,7 +145,18 @@ def ref_ragged_paged_attention(
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)
attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)

if attention_sink is not None:
reshaped_attention_sink = attention_sink.reshape(
actual_num_q_heads, 1, 1)
reshaped_attention_sink = jnp.repeat(reshaped_attention_sink,
q_len,
axis=1)
attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
attn = attn[..., 1:]
else:
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)

out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
if v_scale is not None:
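
The reference change above prepends one sink logit per query head as an extra column of the attention matrix, applies softmax, then drops that column, so the retained weights no longer sum to 1. A standalone sketch (illustrative shapes and names, not the kernel's) shows this is equivalent to simply adding exp(sink) to the softmax denominator:

import jax
import jax.numpy as jnp

def sink_softmax(logits, sink):
    # logits: [num_heads, q_len, kv_len] raw scores, sink: [num_heads]
    sink_col = jnp.broadcast_to(sink[:, None, None], (*logits.shape[:2], 1))
    full = jax.nn.softmax(jnp.concatenate([sink_col, logits], axis=-1), axis=-1)
    return full[..., 1:]  # drop the sink column, keep the deflated weights

def sink_softmax_direct(logits, sink):
    # Same result without materializing the extra column.
    m = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), sink[:, None, None])
    p = jnp.exp(logits - m)
    return p / (jnp.sum(p, axis=-1, keepdims=True) + jnp.exp(sink[:, None, None] - m))

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 5))
sink = jnp.array([0.5, -1.0])
assert jnp.allclose(sink_softmax(logits, sink),
                    sink_softmax_direct(logits, sink), atol=1e-6)
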
@@ -236,6 +249,7 @@ def _ragged_paged_attention_kernel(
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
attention_sink_ref, # [actual_num_kv_heads, num_q_heads_per_kv_head, head_dim]
# Output
o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
updated_kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -371,7 +385,15 @@ def load_with_init(ref, init_val):
s = soft_cap * jnp.tanh(s / soft_cap)
s += jnp.where(mask, mask_value, 0.0)
s_rowmax = jnp.max(s, axis=1, keepdims=True)
m_prev = load_with_init(head_m_ref, -jnp.inf)

if attention_sink_ref is not None:
sinks = attention_sink_ref[kv_head_idx]
actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
else:
m_prev = load_with_init(head_m_ref, -jnp.inf)

m_curr = jnp.maximum(m_prev, s_rowmax)
head_m_ref[...] = m_curr
p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
@@ -382,7 +404,7 @@ def load_with_init(ref, init_val):

p_rowsum = jnp.sum(p, axis=1, keepdims=True)
exp_m_diff = jnp.exp(m_prev - m_curr)
l_prev = load_with_init(head_l_ref, 0.0)
l_prev = load_with_init(head_l_ref, 1.0 if attention_sink_ref is not None else 0.0)
l_curr = exp_m_diff * l_prev + p_rowsum
head_l_ref[...] = l_curr
o_prev = load_with_init(head_acc_ref, 0.0)
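
Inside the Pallas kernel the sink is folded into the flash-attention running statistics instead of adding a column: on the first KV block the running max is seeded with the sink logit and the running sum with 1.0 (which equals exp(sink - m) when m == sink), and later blocks rescale that term through the usual exp(m_prev - m_curr) factor. A minimal single-row sketch of that update rule (block sizes and names are illustrative):

import jax.numpy as jnp

def streamed_denominator(logit_blocks, sink):
    m, l = sink, jnp.float32(1.0)           # first-block init: l = exp(sink - m) = 1
    for s in logit_blocks:                  # one block of raw scores at a time
        m_new = jnp.maximum(m, jnp.max(s))
        l = jnp.exp(m - m_new) * l + jnp.sum(jnp.exp(s - m_new))
        m = m_new
    return m, l                             # unnormalized softmax denominator is l * exp(m)

scores = jnp.arange(8.0) / 3.0
sink = jnp.float32(0.7)
m, l = streamed_denominator(jnp.split(scores, 4), sink)
assert jnp.allclose(l * jnp.exp(m), jnp.sum(jnp.exp(scores)) + jnp.exp(sink))
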
@@ -960,6 +982,7 @@ def prepare_inputs(
Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
v: jax.
Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads],
):
max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
actual_num_kv_heads = k.shape[1]
@@ -995,7 +1018,13 @@ def prepare_inputs(
.swapaxes(0, 1))
# TODO(kyuyeunk, chengjiyao): Add kv quantization here.
kv = merge_kv(k, v)
return q, kv

if attention_sink is not None:
attention_sink = attention_sink.reshape(
(-1, num_q_heads_per_kv_head, 1))
attention_sink = jnp.repeat(attention_sink, head_dim, -1)

return q, kv, attention_sink
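
prepare_inputs now also reshapes the flat per-query-head sink vector into the per-KV-head layout the kernel indexes with kv_head_idx, and repeats it along head_dim so it can be used directly as the first-block initial value of the head_m_ref accumulator. A shape-only sketch (all sizes are assumptions):

import jax.numpy as jnp

actual_num_q_heads, num_q_heads_per_kv_head, head_dim = 8, 4, 128
attention_sink = jnp.zeros((actual_num_q_heads,), jnp.float32)   # one logit per q head

sink = attention_sink.reshape((-1, num_q_heads_per_kv_head, 1))  # [num_kv_heads, q_per_kv, 1]
sink = jnp.repeat(sink, head_dim, -1)                            # [num_kv_heads, q_per_kv, head_dim]
assert sink.shape == (actual_num_q_heads // num_q_heads_per_kv_head,
                      num_q_heads_per_kv_head, head_dim)
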


def prepare_outputs(
@@ -1033,6 +1062,7 @@ def dynamic_validate_inputs(
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
distribution: jax.Array, # i32[3]
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
@@ -1060,6 +1090,7 @@ def dynamic_validate_inputs(
page_indices,
cu_q_lens,
distribution,
attention_sink,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
@@ -1123,6 +1154,7 @@ def static_validate_inputs(
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
distribution: jax.Array, # i32[3]
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
@@ -1155,6 +1187,15 @@ def static_validate_inputs(
raise ValueError(
f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
)
if attention_sink is not None:
if attention_sink.shape[0] != q.shape[1]:
raise ValueError(
f"Expected {attention_sink.shape[0]=} to be equal to"
f" {q.shape[1]=} (num_q_heads).")
if attention_sink.dtype != jnp.float32:
raise ValueError(
f"Expected {attention_sink.dtype=} to be equal to {jnp.float32=}."
)
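
The new static checks pin down the expected interface: attention_sink is a rank-1 float32 array with one entry per query head. A minimal argument that passes them (num_q_heads is an assumption):

import jax.numpy as jnp

num_q_heads = 32
attention_sink = jnp.zeros((num_q_heads,), dtype=jnp.float32)  # f32[actual_num_q_heads]
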

actual_head_dim = q.shape[2]
actual_num_q_heads = q.shape[1]
@@ -1278,6 +1319,7 @@ def ragged_paged_attention(
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
distribution: jax.Array, # i32[3]
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
@@ -1338,6 +1380,7 @@ def ragged_paged_attention(
page_indices,
cu_q_lens,
distribution,
attention_sink,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
@@ -1356,7 +1399,7 @@ def ragged_paged_attention(
actual_num_kv_heads = k.shape[1]

actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
q, kv = prepare_inputs(q, k, v)
q, kv, attention_sink = prepare_inputs(q, k, v, attention_sink)
(
_,
max_num_tokens,
@@ -1395,6 +1438,8 @@ def ragged_paged_attention(
pl.BlockSpec(memory_space=pltpu.HBM),
pl.BlockSpec(memory_space=pltpu.HBM),
pl.BlockSpec(memory_space=pltpu.HBM),
None if attention_sink is None else pl.BlockSpec(
memory_space=pltpu.VMEM)
]
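
Unlike q, kv, and the KV cache, which stay in HBM and are streamed in blocks, the prepared sink tensor is small enough to be pinned in VMEM for the whole kernel, and its spec entry is simply None when no sink is passed. A rough footprint estimate (all sizes are assumptions):

num_kv_heads, q_heads_per_kv_head, head_dim = 8, 4, 128
sink_bytes = num_kv_heads * q_heads_per_kv_head * head_dim * 4  # float32
print(f"attention_sink occupies ~{sink_bytes / 1024:.0f} KiB of VMEM")  # ~16 KiB
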

out_specs = [
@@ -1493,7 +1538,8 @@ def ragged_paged_attention(
name=scope_name,
))

output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
attention_sink)
return (
prepare_outputs(output, actual_num_q_heads_per_kv_head,
actual_head_dim),
@@ -391,7 +391,10 @@ def load_with_init(ref, init_val):
lax.broadcasted_iota(jnp.int32, s.shape, 0) //
num_q_heads_per_kv_head)
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)

mask = q_span < k_span
if sliding_window is not None:
mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

if soft_cap is not None:
s = soft_cap * jnp.tanh(s / soft_cap)
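
The hunk above, from a second file changed in this PR, extends the causal mask with a sliding-window term: a key column is masked either because it lies in the future (q_span < k_span) or because it has fallen out of the window (q_span - sliding_window >= k_span), so each query attends to at most sliding_window positions ending at itself. A toy worked example (absolute positions, sizes are illustrative):

import jax.numpy as jnp

sliding_window = 4
q_span = jnp.array([[10]])                        # one query at position 10
k_span = jnp.arange(16)[None, :]                  # key positions 0..15
mask = q_span < k_span                            # causal: future keys masked
mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)  # window: keys <= 6 masked
print(jnp.where(~mask[0])[0])                     # -> [ 7  8  9 10]
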