104 changes: 73 additions & 31 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -230,8 +230,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
Reviewer comment (Collaborator): nit: just call bq_ids_ref

# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -562,50 +563,89 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]
Reviewer comment (Collaborator): Definitely need to retune and update the tuned block sizes. I understand you may not have the autotune script, but please write a benchmarking script even with the same block size; we want to see perf on different block sizes and different models. I am very strict about this in Google-internal kernel development as well. We don't want to just check in the code without really understanding how much it brings for different models (shapes) and block sizes.

Reviewer comment (Collaborator): Even appending the throughput change on different models is acceptable. Thanks.


debug_print("[RPA debug] sz (from scratch)={}", sz)
dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
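
Note on the change above: the block size sz is now computed once when a copy is started, stashed in a per-semaphore scratch slot (bq_fetch_ids_ref for bq, bo_ids_ref[4:] for bo), and read back on the wait path instead of being re-derived from cu_q_lens. Below is a minimal, pure-Python model of that bookkeeping; it is not Pallas code, and the names (FetchModel, scratch_sz, BQ_SZ) are illustrative only.

import numpy as np

BQ_SZ = 256          # tokens per query block (assumed value)
NUM_SEM_SLOTS = 2    # double buffering: two copies in flight

class FetchModel:
    """Models the start/wait size bookkeeping, not the DMA itself."""

    def __init__(self, cu_q_lens):
        self.cu_q_lens = cu_q_lens
        # Plays the role of bq_fetch_ids_ref: one stashed size per sem slot.
        self.scratch_sz = np.full(NUM_SEM_SLOTS, -1, dtype=np.int32)

    def start_fetch(self, seq_idx, bq_idx, sem_idx):
        # Compute the block size once, when the copy is started...
        q_len_start = self.cu_q_lens[seq_idx] + bq_idx * BQ_SZ
        q_end = self.cu_q_lens[seq_idx + 1]
        sz = min(BQ_SZ, q_end - q_len_start)
        # ...and stash it in the slot owned by this semaphore.
        self.scratch_sz[sem_idx] = sz
        return q_len_start, sz  # these would parameterize the async copy

    def wait_fetch(self, sem_idx):
        # The wait path only needs sz; read it back instead of recomputing
        # it from cu_q_lens (saves SMEM reads and index arithmetic).
        return int(self.scratch_sz[sem_idx])

cu_q_lens = np.array([0, 100, 712], dtype=np.int32)  # two sequences
m = FetchModel(cu_q_lens)
m.start_fetch(seq_idx=1, bq_idx=1, sem_idx=0)  # full interior block of seq 1
assert m.wait_fetch(sem_idx=0) == 256
m.start_fetch(seq_idx=1, bq_idx=2, sem_idx=1)  # trailing partial block
assert m.wait_fetch(sem_idx=1) == 100

The matching start and wait must use the same slot index, which is what passing the same bq_sem_idx / bo_sem_idx to both calls guarantees in the kernel.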
@@ -1445,10 +1485,12 @@ def ragged_paged_attention(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1487,8 +1529,8 @@ def ragged_paged_attention(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
10: 1
},
name=scope_name,
))
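
To address the reviewer's request for benchmarking across block sizes and models, a minimal timing-harness sketch follows. It is deliberately not tied to the kernel's exact signature: build_call(cfg) is a placeholder the author would implement to return a zero-argument jitted callable that invokes ragged_paged_attention with the block sizes in cfg on representative model shapes, and the config keys shown are illustrative knobs, not the kernel's actual parameter names.

import time
from typing import Callable, Dict, List

import jax

def time_callable(fn: Callable[[], jax.Array], warmup: int = 3, iters: int = 20) -> float:
    """Average wall-clock seconds per call, excluding compilation/warmup."""
    for _ in range(warmup):
        jax.block_until_ready(fn())
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(fn())
    return (time.perf_counter() - start) / iters

def sweep_block_sizes(build_call: Callable[[Dict], Callable[[], jax.Array]],
                      configs: List[Dict],
                      total_q_tokens: int) -> None:
    for cfg in configs:
        fn = build_call(cfg)  # closes over q, kv_cache, lens, page tables, ...
        secs = time_callable(fn)
        print(f"{cfg}: {secs * 1e3:.3f} ms/iter, "
              f"{total_q_tokens / secs:,.0f} q-tokens/s")

# Illustrative sweep; the key names are placeholders for the real block-size knobs.
CONFIGS = [
    {"bq_sz": 128, "bkv_p": 32},
    {"bq_sz": 256, "bkv_p": 32},
    {"bq_sz": 256, "bkv_p": 64},
]
# sweep_block_sizes(build_call, CONFIGS, total_q_tokens=...)  # fill in build_call

Reporting these numbers for a couple of model shapes (different head counts and sequence-length mixes), with and without this change, would give the throughput comparison the reviewer asks for.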
105 changes: 74 additions & 31 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -246,8 +246,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
@@ -619,50 +620,90 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
@@ -1511,10 +1552,12 @@ def ragged_paged_attention_hd64(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1554,8 +1597,8 @@ def ragged_paged_attention_hd64(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
10: 1
},
name=scope_name,
))
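
The input_output_aliases update in both files ({7: 0, 9: 1} becoming {8: 0, 10: 1}) follows from inserting the new bq_fetch_ids scalar operand ahead of the aliased buffers: the dictionary keys are flat input positions, so every input after the insertion point shifts by one. Below is a minimal, standalone illustration of the aliasing mechanism in pl.pallas_call; it is unrelated to this kernel's operand layout and assumes a backend where Pallas runs.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # In-place style update: the output buffer aliases (donates) the input buffer.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},  # input index -> output index
)(x)

If another operand were added before x in this call, the key would have to change from 0 to 1, which is exactly the 7 to 8 and 9 to 10 bump in this diff.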