Save size in scalar scratch for bo and bq #1201
base: main
Changes from all commits
18e55a4
c2b91a7
4920fb7
90f185c
754250d
@@ -230,8 +230,9 @@ def _ragged_paged_attention_kernel(
     # TODO(jevinjiang): merge these into one so we can save SMEM.
     distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
-    bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+    bo_ids_ref,  # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    bq_fetch_ids_ref,  # [2] (bq_sem_0_sz, bq_sem_1_sz)
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
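For readers skimming the diff: the per-semaphore DMA sizes now live in the tail of bo_ids_ref and in the new two-element bq_fetch_ids_ref. A tiny, purely illustrative mapping of those slots (helper names are mine, not from the kernel):

```python
# Illustrative helpers spelling out the slot layout implied by the comments above.
#
# bo_ids_ref[6]:       0/1 = seq_idx per bo semaphore, 2/3 = bo_idx per bo
#                      semaphore, 4/5 = copy size per bo semaphore.
# bq_fetch_ids_ref[2]: 0/1 = copy size per bq semaphore.


def bo_sz_slot(bo_sem_idx: int) -> int:
  # Matches bo_ids_ref[bo_sem_idx + 4] in the kernel changes below.
  return bo_sem_idx + 4


def bq_sz_slot(bq_sem_idx: int) -> int:
  # Matches bq_fetch_ids_ref[bq_sem_idx] in the kernel changes below.
  return bq_sem_idx
```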
@@ -562,50 +563,89 @@ def loop_body(i, states):
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
     vmem_ref = bq_x2_ref.at[bq_sem_idx]
-    q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
-    q_end = cu_q_lens_ref[seq_idx + 1]
-    sz = jnp.minimum(bq_sz, q_end - q_len_start)

     debug_print(
         "[RPA debug]"
         f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
     debug_print("[RPA debug] seq_idx={}", seq_idx)
     debug_print("[RPA debug] bq_idx={}", bq_idx)
     debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
-    debug_print("[RPA debug] q_len_start={}", q_len_start)
-    debug_print("[RPA debug] q_end={}", q_end)
-    debug_print("[RPA debug] sz={}", sz)

-    _async_copy(
-        q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-        vmem_ref.at[:, pl.ds(0, sz)],
-        sem,
-        wait,
-    )
+    if not wait:
+      # Calculate sz and store it in scratch
+      q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+      q_end = cu_q_lens_ref[seq_idx + 1]
+      sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+      # Store sz in scratch for later use
+      bq_fetch_ids_ref[bq_sem_idx] = sz
+
+      debug_print("[RPA debug] q_len_start={}", q_len_start)
+      debug_print("[RPA debug] q_end={}", q_end)
+      debug_print("[RPA debug] sz={}", sz)
+
+      _async_copy(
+          q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+          vmem_ref.at[:, pl.ds(0, sz)],
+          sem,
+          wait,
+      )
+    else:
+      # Retrieve sz from scratch instead of recalculating
+      sz = bq_fetch_ids_ref[bq_sem_idx]
Collaborator:
The tuned block sizes definitely need to be retuned and updated. I understand you may not have the autotune script, but please write a benchmarking script, even with the same block sizes; we want to see perf across different block sizes and different models. I am very strict about this in Google-internal kernel development as well. We don't want to just check in the code without really understanding how much it can bring for different models (shapes) and block sizes.

Collaborator:
Even appending the throughput change on different models is acceptable. Thanks.
+
+      debug_print("[RPA debug] sz (from scratch)={}", sz)
+      dst = vmem_ref.at[:, pl.ds(0, sz)]
+      _async_copy(
+          dst,
+          dst,
+          sem,
+          wait,
+      )

   def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
     sem = sems.at[2, bo_sem_idx]
     vmem_ref = bo_x2_ref.at[bo_sem_idx]
-    q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
-    q_end = cu_q_lens_ref[seq_idx + 1]
-    sz = jnp.minimum(bq_sz, q_end - q_len_start)

     debug_print(
         "[RPA debug]"
         f" -----------{'wait' if wait else 'start'}_send_bo-----------")
     debug_print("[RPA debug] seq_idx={}", seq_idx)
     debug_print("[RPA debug] bo_idx={}", bo_idx)
     debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
-    debug_print("[RPA debug] q_len_start={}", q_len_start)
-    debug_print("[RPA debug] q_end={}", q_end)
-    debug_print("[RPA debug] sz={}", sz)

-    _async_copy(
-        vmem_ref.at[:, pl.ds(0, sz)],
-        o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-        sem,
-        wait,
-    )
+    if not wait:
+      # Calculate sz and store it in scratch
+      q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
+      q_end = cu_q_lens_ref[seq_idx + 1]
+      sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+      # Store sz in scratch for later use
+      bo_ids_ref[bo_sem_idx + 4] = sz
+
+      debug_print("[RPA debug] q_len_start={}", q_len_start)
+      debug_print("[RPA debug] q_end={}", q_end)
+      debug_print("[RPA debug] sz={}", sz)
+
+      _async_copy(
+          vmem_ref.at[:, pl.ds(0, sz)],
+          o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+          sem,
+          wait,
+      )
+    else:
+      # Retrieve sz from scratch instead of recalculating
+      sz = bo_ids_ref[bo_sem_idx + 4]
+
+      debug_print("[RPA debug] sz (from scratch)={}", sz)
+
+      dst = o_hbm_ref.at[:, pl.ds(0, sz)]
+      _async_copy(
+          dst,
+          dst,
+          sem,
+          wait,
+      )

   def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
     return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
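Why the wait branch passes the same slice twice: assuming _async_copy wraps Pallas TPU's make_async_copy and either starts the DMA or waits on its semaphore, the wait only needs a descriptor whose size matches the copy issued earlier. Saving sz in scalar scratch at start time and reading it back at wait time keeps the two sides consistent without reloading cu_q_lens_ref. A standalone sketch of that pattern (names and shapes are illustrative, not the kernel's):

```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _async_copy_sketch(src, dst, sem, wait):
  # Assumed behavior of the kernel's _async_copy helper: build a DMA
  # descriptor, then either kick it off or block on its semaphore.
  copy = pltpu.make_async_copy(src, dst, sem)
  if wait:
    copy.wait()
  else:
    copy.start()


def fetch_block(hbm_ref, vmem_ref, sz_scratch_ref, sem, start, full_sz, *, wait=False):
  if not wait:
    sz_scratch_ref[0] = full_sz  # remember the in-flight copy size
    _async_copy_sketch(hbm_ref.at[pl.ds(start, full_sz)],
                       vmem_ref.at[pl.ds(0, full_sz)], sem, wait)
  else:
    sz = sz_scratch_ref[0]  # reuse the size recorded when the copy started
    dst = vmem_ref.at[pl.ds(0, sz)]
    # Only the descriptor's size has to match for the wait, so the same
    # slice can stand in for both source and destination.
    _async_copy_sketch(dst, dst, sem, wait)
```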
@@ -1445,10 +1485,12 @@ def ragged_paged_attention(
           distribution,
           # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
           jnp.zeros((3, ), jnp.int32),
-          # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
-          jnp.full((4, ), -1, jnp.int32),
+          # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
+          jnp.full((6, ), -1, jnp.int32),
           # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
           jnp.full((6, ), -1, jnp.int32),
+          # (bq_sem_0_sz, bq_sem_1_sz)
+          jnp.full((2, ), -1, jnp.int32),
       )

   scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1487,8 +1529,8 @@ def ragged_paged_attention(
                   dtype=kv_cache.dtype),
           ],
           input_output_aliases={
-              7: 0,
-              9: 1
+              8: 0,
+              10: 1
           },
           name=scope_name,
       ))
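The alias keys shift because the new (2,)-shaped bq_fetch_ids array is passed as one more positional operand ahead of the aliased KV-cache buffers, so the inputs aliased to outputs 0 and 1 move from positions 7/9 to 8/10 (my reading of the diff; the full operand list is not shown here). For reference, a minimal, generic example of what input_output_aliases means for pl.pallas_call; everything in it is illustrative and unrelated to this kernel:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def accumulate_kernel(x_ref, acc_ref, o_ref):
  # The output buffer is aliased with `acc`, so this is an in-place update.
  o_ref[...] = acc_ref[...] + x_ref[...]


x = jnp.ones((8, 128), jnp.float32)
acc = jnp.zeros((8, 128), jnp.float32)

out = pl.pallas_call(
    accumulate_kernel,
    out_shape=jax.ShapeDtypeStruct(acc.shape, acc.dtype),
    # Key = positional index of the input, value = output index it aliases.
    # Inserting a new input before `acc` would bump this key by one, which is
    # exactly the 7 -> 8 / 9 -> 10 shift in the hunk above.
    input_output_aliases={1: 0},
    interpret=True,  # run anywhere for illustration
)(x, acc)
print(out[0, 0])  # 1.0
```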
Collaborator:
nit: just call it bq_ids_ref