diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
index f10e7962e..e9bf728ab 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -230,8 +230,9 @@ def _ragged_paged_attention_kernel(
     # TODO(jevinjiang): merge these into one so we can save SMEM.
     distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
-    bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+    bo_ids_ref,  # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    bq_fetch_ids_ref,  # [2] (bq_sem_0_sz, bq_sem_1_sz)
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -562,9 +563,6 @@ def loop_body(i, states):
     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
         vmem_ref = bq_x2_ref.at[bq_sem_idx]
-        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        sz = jnp.minimum(bq_sz, q_end - q_len_start)
 
         debug_print(
             "[RPA debug]"
@@ -572,23 +570,42 @@
         debug_print("[RPA debug] seq_idx={}", seq_idx)
         debug_print("[RPA debug] bq_idx={}", bq_idx)
         debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
-        debug_print("[RPA debug] q_len_start={}", q_len_start)
-        debug_print("[RPA debug] q_end={}", q_end)
-        debug_print("[RPA debug] sz={}", sz)
-
-        _async_copy(
-            q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-            vmem_ref.at[:, pl.ds(0, sz)],
-            sem,
-            wait,
-        )
+
+        if not wait:
+            # Calculate sz and store it in scratch
+            q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+            q_end = cu_q_lens_ref[seq_idx + 1]
+            sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+            # Store sz in scratch for later use
+            bq_fetch_ids_ref[bq_sem_idx] = sz
+
+            debug_print("[RPA debug] q_len_start={}", q_len_start)
+            debug_print("[RPA debug] q_end={}", q_end)
+            debug_print("[RPA debug] sz={}", sz)
+
+            _async_copy(
+                q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+                vmem_ref.at[:, pl.ds(0, sz)],
+                sem,
+                wait,
+            )
+        else:
+            # Retrieve sz from scratch instead of recalculating
+            sz = bq_fetch_ids_ref[bq_sem_idx]
+
+            debug_print("[RPA debug] sz (from scratch)={}", sz)
+            dst = vmem_ref.at[:, pl.ds(0, sz)]
+            _async_copy(
+                dst,
+                dst,
+                sem,
+                wait,
+            )
 
     def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
         sem = sems.at[2, bo_sem_idx]
         vmem_ref = bo_x2_ref.at[bo_sem_idx]
-        q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        sz = jnp.minimum(bq_sz, q_end - q_len_start)
 
         debug_print(
             "[RPA debug]"
@@ -596,16 +613,39 @@
         debug_print("[RPA debug] seq_idx={}", seq_idx)
         debug_print("[RPA debug] bo_idx={}", bo_idx)
         debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
-        debug_print("[RPA debug] q_len_start={}", q_len_start)
-        debug_print("[RPA debug] q_end={}", q_end)
-        debug_print("[RPA debug] sz={}", sz)
-
-        _async_copy(
-            vmem_ref.at[:, pl.ds(0, sz)],
-            o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-            sem,
-            wait,
-        )
+
+        if not wait:
+            # Calculate sz and store it in scratch
+            q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
+            q_end = cu_q_lens_ref[seq_idx + 1]
+            sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+            # Store sz in scratch for later use
+            bo_ids_ref[bo_sem_idx + 4] = sz
+
+            debug_print("[RPA debug] q_len_start={}", q_len_start)
+            debug_print("[RPA debug] q_end={}", q_end)
+            debug_print("[RPA debug] sz={}", sz)
+
+            _async_copy(
+                vmem_ref.at[:, pl.ds(0, sz)],
+                o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+                sem,
+                wait,
+            )
+        else:
+            # Retrieve sz from scratch instead of recalculating
+            sz = bo_ids_ref[bo_sem_idx + 4]
+
+            debug_print("[RPA debug] sz (from scratch)={}", sz)
+
+            dst = o_hbm_ref.at[:, pl.ds(0, sz)]
+            _async_copy(
+                dst,
+                dst,
+                sem,
+                wait,
+            )
 
     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
@@ -1445,10 +1485,12 @@ def ragged_paged_attention(
         distribution,
         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
         jnp.zeros((3, ), jnp.int32),
-        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
-        jnp.full((4, ), -1, jnp.int32),
+        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
+        jnp.full((6, ), -1, jnp.int32),
         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
         jnp.full((6, ), -1, jnp.int32),
+        # (bq_sem_0_sz, bq_sem_1_sz)
+        jnp.full((2, ), -1, jnp.int32),
     )
 
     scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1487,8 +1529,8 @@ def ragged_paged_attention(
                 dtype=kv_cache.dtype),
         ],
         input_output_aliases={
-            7: 0,
-            9: 1
+            8: 0,
+            10: 1
         },
         name=scope_name,
     ))
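Note on the pattern (it applies identically to both files): the copy size `sz` for a bq fetch or bo send is now computed only when the async copy is issued, and stashed in SMEM scratch keyed by the semaphore slot (`bq_fetch_ids_ref[bq_sem_idx]` for fetches, the two new tail slots `bo_ids_ref[bo_sem_idx + 4]` for sends). The wait path then reads the stashed value and rebuilds an identically sized copy descriptor instead of recomputing `sz` from `cu_q_lens_ref`. Below is a minimal, self-contained sketch of this issue-time-stash / wait-time-replay idea, not the kernel itself; `copy_rows`, `_copy_kernel`, `num_rows_ref`, `sz_scratch`, and the shapes are invented for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

MAX_ROWS, COLS = 16, 128


def _copy_kernel(num_rows_ref, x_hbm_ref, o_ref, buf, sz_scratch, sem):

    def issue_copy(slot):
        # Compute the ragged copy size only at issue time...
        sz = jnp.minimum(MAX_ROWS, num_rows_ref[0])
        # ...and stash it in SMEM, keyed by the buffer slot.
        sz_scratch[slot] = sz
        pltpu.make_async_copy(
            x_hbm_ref.at[pl.ds(0, sz)],
            buf.at[slot, pl.ds(0, sz)],
            sem,
        ).start()

    def wait_copy(slot):
        # Replay an identically sized descriptor from the stashed value;
        # the wait side never recomputes the size.
        sz = sz_scratch[slot]
        dst = buf.at[slot, pl.ds(0, sz)]
        pltpu.make_async_copy(dst, dst, sem).wait()

    issue_copy(0)
    wait_copy(0)
    # Rows beyond the stashed size were never copied; zero them out.
    row = jax.lax.broadcasted_iota(jnp.int32, (MAX_ROWS, COLS), 0)
    o_ref[...] = jnp.where(row < sz_scratch[0], buf[0], 0)


@jax.jit
def copy_rows(num_rows, x):  # num_rows: int32[1], x: f32[MAX_ROWS, COLS]
    return pl.pallas_call(
        _copy_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
            scratch_shapes=[
                pltpu.VMEM((2, MAX_ROWS, COLS), x.dtype),  # double buffer
                pltpu.SMEM((2, ), jnp.int32),  # per-slot stashed sizes
                pltpu.SemaphoreType.DMA,
            ],
        ),
        out_shape=jax.ShapeDtypeStruct((MAX_ROWS, COLS), x.dtype),
    )(num_rows, x)
```

Calling `copy_rows(jnp.array([5], jnp.int32), x)` copies the first five rows and zeroes the rest. The `dst, dst` wait mirrors the patch, which relies on the wait only needing a descriptor of matching size: the kernels likewise rebuild the descriptor from `bq_fetch_ids_ref[bq_sem_idx]` (or `bo_ids_ref[bo_sem_idx + 4]`) rather than touching `cu_q_lens_ref` again.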
diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
index e94ee3579..e1a14a5c5 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -246,8 +246,9 @@ def _ragged_paged_attention_kernel(
     # TODO(jevinjiang): merge these into one so we can save SMEM.
     distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
-    bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+    bo_ids_ref,  # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    bq_fetch_ids_ref,  # [2] (bq_sem_0_sz, bq_sem_1_sz)
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
@@ -619,9 +620,6 @@ def loop_body(i, states):
     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
         vmem_ref = bq_x2_ref.at[bq_sem_idx]
-        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        sz = jnp.minimum(bq_sz, q_end - q_len_start)
 
         debug_print(
             "[RPA debug]"
@@ -629,23 +627,43 @@
         debug_print("[RPA debug] seq_idx={}", seq_idx)
         debug_print("[RPA debug] bq_idx={}", bq_idx)
         debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
-        debug_print("[RPA debug] q_len_start={}", q_len_start)
-        debug_print("[RPA debug] q_end={}", q_end)
-        debug_print("[RPA debug] sz={}", sz)
-
-        _async_copy(
-            q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-            vmem_ref.at[:, pl.ds(0, sz)],
-            sem,
-            wait,
-        )
+
+        if not wait:
+            # Calculate sz and store it in scratch
+            q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+            q_end = cu_q_lens_ref[seq_idx + 1]
+            sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+            # Store sz in scratch for later use
+            bq_fetch_ids_ref[bq_sem_idx] = sz
+
+            debug_print("[RPA debug] q_len_start={}", q_len_start)
+            debug_print("[RPA debug] q_end={}", q_end)
+            debug_print("[RPA debug] sz={}", sz)
+
+            _async_copy(
+                q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+                vmem_ref.at[:, pl.ds(0, sz)],
+                sem,
+                wait,
+            )
+        else:
+            # Retrieve sz from scratch instead of recalculating
+            sz = bq_fetch_ids_ref[bq_sem_idx]
+
+            debug_print("[RPA debug] sz (from scratch)={}", sz)
+
+            dst = vmem_ref.at[:, pl.ds(0, sz)]
+            _async_copy(
+                dst,
+                dst,
+                sem,
+                wait,
+            )
 
     def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
         sem = sems.at[2, bo_sem_idx]
         vmem_ref = bo_x2_ref.at[bo_sem_idx]
-        q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        sz = jnp.minimum(bq_sz, q_end - q_len_start)
 
         debug_print(
             "[RPA debug]"
@@ -653,16 +671,39 @@
         debug_print("[RPA debug] seq_idx={}", seq_idx)
         debug_print("[RPA debug] bo_idx={}", bo_idx)
         debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
-        debug_print("[RPA debug] q_len_start={}", q_len_start)
-        debug_print("[RPA debug] q_end={}", q_end)
-        debug_print("[RPA debug] sz={}", sz)
-
-        _async_copy(
-            vmem_ref.at[:, pl.ds(0, sz)],
-            o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
-            sem,
-            wait,
-        )
+
+        if not wait:
+            # Calculate sz and store it in scratch
+            q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
+            q_end = cu_q_lens_ref[seq_idx + 1]
+            sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+            # Store sz in scratch for later use
+            bo_ids_ref[bo_sem_idx + 4] = sz
+
+            debug_print("[RPA debug] q_len_start={}", q_len_start)
+            debug_print("[RPA debug] q_end={}", q_end)
+            debug_print("[RPA debug] sz={}", sz)
+
+            _async_copy(
+                vmem_ref.at[:, pl.ds(0, sz)],
+                o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+                sem,
+                wait,
+            )
+        else:
+            # Retrieve sz from scratch instead of recalculating
+            sz = bo_ids_ref[bo_sem_idx + 4]
+
+            debug_print("[RPA debug] sz (from scratch)={}", sz)
+
+            dst = o_hbm_ref.at[:, pl.ds(0, sz)]
+            _async_copy(
+                dst,
+                dst,
+                sem,
+                wait,
+            )
 
     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
@@ -1511,10 +1552,12 @@ def ragged_paged_attention_hd64(
         distribution,
         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
         jnp.zeros((3, ), jnp.int32),
-        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
-        jnp.full((4, ), -1, jnp.int32),
+        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
+        jnp.full((6, ), -1, jnp.int32),
         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6, ), -1, jnp.int32),
+        # (bq_sem_0_sz, bq_sem_1_sz)
+        jnp.full((2, ), -1, jnp.int32),
     )
 
     scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1554,8 +1597,8 @@ def ragged_paged_attention_hd64(
                 dtype=kv_cache.dtype),
         ],
         input_output_aliases={
-            7: 0,
-            9: 1
+            8: 0,
+            10: 1
         },
         name=scope_name,
    ))
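On the `input_output_aliases` renumbering in both files: the alias keys index the flattened `pl.pallas_call` operand list, and scalar-prefetch operands count toward that numbering, so passing `bq_fetch_ids` as one more scalar operand shifts the two aliased inputs from positions 7 and 9 to 8 and 10. A hypothetical toy (names `bump`, `scale`, `extra` are invented) showing the same shift in miniature:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _bump_kernel(scale_ref, extra_ref, x_ref, o_ref):
    del extra_ref  # present only to shift the operand numbering
    # o_ref aliases x_ref's buffer, so this is an in-place update.
    o_ref[...] = x_ref[...] * scale_ref[0]


def bump(scale, extra, x):
    return pl.pallas_call(
        _bump_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=2),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # x is operand 2, after the two scalar operands; before `extra`
        # existed it was operand 1, so the alias key moved with it.
        input_output_aliases={2: 0},
    )(scale, extra, x)
```

Dropping `extra` and returning `num_scalar_prefetch` to 1 would require the alias to read `{1: 0}` again; that bookkeeping is exactly what the patch does when `7: 0, 9: 1` becomes `8: 0, 10: 1`.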