@@ -319,7 +319,7 @@ def debug_print(msg, *args):
   debug_print("[RPA debug] q_len={}", q_len)
   debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def flash_attention(
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
       k,  # [bkv_sz, head_dim]
       v,  # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def flash_attention(
     assert k.dtype == v.dtype
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
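One detail worth noting in the hunk above: `load_with_init` masks every read of the running-stats refs because, on the first KV block (`bkv_idx == 0`), those refs still hold stale values from the previous grid step, so the read must yield the init value rather than requiring a separate initialization pass. After the split, each half carries its own copy of the helper since each touches a different set of refs. A tiny standalone rendition of the idiom (plain arrays instead of Pallas refs; the names here are illustrative, not the kernel's):

```python
import jax.numpy as jnp

# Stand-in for a ref still holding values from the previous grid step.
stale = jnp.array([[3.0, 7.0]])

def load_with_init(ref, init_val, bkv_idx):
  # First KV block: pretend the ref reads as the init value.
  # Later blocks: read the accumulated state as-is.
  return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val), ref)

print(load_with_init(stale, 0.0, bkv_idx=0))  # [[0. 0.]] -- fresh init
print(load_with_init(stale, 0.0, bkv_idx=1))  # [[3. 7.]] -- carried state
```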
@@ -376,15 +375,32 @@ def load_with_init(ref, init_val):
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 0.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      v,  # [bkv_sz, head_dim]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                       ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
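The split is sound because the online-softmax recurrence factors cleanly: the QK half needs only `q` and `k` to update the running max `m` and denominator `l`, and the only values the PV half consumes from it are the unnormalized probabilities `p` and the rescale factor `exp_m_diff`. Below is a minimal plain-JAX sketch of that two-step recurrence, checked against full softmax attention. It is not the kernel's Pallas code: refs become explicitly carried arrays, and scaling (`k_scale`/`v_scale`), masking, and the per-head layout are omitted; all names are illustrative.

```python
import jax
import jax.numpy as jnp

def step1(q, k, m_prev, l_prev):
  # QK scores for this KV block, plus the running max/denominator update.
  s = q @ k.T                                    # [n, bkv]
  m_curr = jnp.maximum(m_prev, jnp.max(s, axis=1, keepdims=True))
  p = jnp.exp(s - m_curr)                        # unnormalized probabilities
  exp_m_diff = jnp.exp(m_prev - m_curr)          # rescale factor for old state
  l_curr = exp_m_diff * l_prev + jnp.sum(p, axis=1, keepdims=True)
  return p, exp_m_diff, m_curr, l_curr

def step2(p, exp_m_diff, v, acc_prev):
  # Rescale the old accumulator, then add this block's PV contribution.
  return exp_m_diff * acc_prev + p @ v

def attend(q, k, v, bkv=4):
  n, d = q.shape
  m = jnp.full((n, 1), -jnp.inf)                 # running max
  l = jnp.zeros((n, 1))                          # running denominator
  acc = jnp.zeros((n, d))                        # running numerator
  for start in range(0, k.shape[0], bkv):
    kb, vb = k[start:start + bkv], v[start:start + bkv]
    p, exp_m_diff, m, l = step1(q, kb, m, l)
    acc = step2(p, exp_m_diff, vb, acc)
  return acc / l                                 # final normalization

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(attend(q, k, v), ref, atol=1e-5)
```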
@@ -839,6 +855,11 @@ def update_cur_bkv_to_cache():
 
     # Flash attention with cur bkv and bq
     # NOTE: kv_packing is divided by 2 because k and v are packed together.
+    prev_bq_shape_0 = None
+    prev_kv_head_bv = None
+    prev_kv_head_idx = None
+    prev_kv_head_p = None
+    prev_kv_head_exp_m_diff = None
     heads_per_load = max(1, kv_packing // 2)
     for kv_head_start in range(0, actual_num_kv_heads,
                                heads_per_load):
@@ -850,21 +871,53 @@ def update_cur_bkv_to_cache():
       )
       assert len(bkv_lst) == heads_per_load
       for i in range(heads_per_load):
-        kv_head_idx = kv_head_start + i
-        if kv_head_idx >= actual_num_kv_heads:
+        cur_kv_head_idx = kv_head_start + i
+        if cur_kv_head_idx >= actual_num_kv_heads:
           break
-        bq = load_bq(bq_sem_idx,
-                     kv_head_idx,
-                     actual_bq_sz=actual_bq_sz)
+
+        cur_kv_head_bq = load_bq(bq_sem_idx,
+                                 cur_kv_head_idx,
+                                 actual_bq_sz=actual_bq_sz)
         bk, bv = bkv_lst[i]
-        flash_attention(
-            bq,
-            bk,
-            bv,
-            bq_idx=bq_idx,
-            bkv_idx=bkv_idx,
-            kv_head_idx=kv_head_idx,
-        )
+        # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+        # and `flash_attention_step2_pv` to pipeline the computation.
+        # `step2_pv` for the previous KV head, which depends on the softmax
+        # output, is overlapped with `step1_qk_softmax` for the current KV
+        # head, reducing overall wait times.
+        cur_kv_head_p, cur_kv_head_exp_m_diff = (
+            flash_attention_step1_qk_softmax(
+                cur_kv_head_bq,
+                bk,
+                bv,
+                bq_idx=bq_idx,
+                bkv_idx=bkv_idx,
+                kv_head_idx=cur_kv_head_idx,
+            ))
+        if prev_bq_shape_0 is not None:
+          flash_attention_step2_pv(
+              prev_bq_shape_0,
+              prev_kv_head_bv,
+              prev_kv_head_p,
+              prev_kv_head_exp_m_diff,
+              bkv_idx=bkv_idx,
+              kv_head_idx=prev_kv_head_idx,
+          )
+        prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+        prev_kv_head_bv = bv
+        prev_kv_head_p = cur_kv_head_p
+        prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+        prev_kv_head_idx = cur_kv_head_idx
+
+    # Execute pv of the last attention head.
+    assert prev_bq_shape_0 is not None
+    flash_attention_step2_pv(
+        prev_bq_shape_0,
+        prev_kv_head_bv,
+        prev_kv_head_p,
+        prev_kv_head_exp_m_diff,
+        bkv_idx=bkv_idx,
+        kv_head_idx=prev_kv_head_idx,
+    )
 
   lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
 
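Taken together, the loop is a one-deep software pipeline: each iteration issues `step1_qk_softmax` for the current head, then retires `step2_pv` for the previous one, and the epilogue after the loop drains the last head. Here is a schematic of that pattern in plain Python with stand-in stage functions; the real overlap comes from executing the two heads' independent work concurrently, which this sketch only models by reordering the calls.

```python
def pipelined(items, stage1, stage2):
  prev = None  # (item, stage1 output) carried into the next iteration
  for item in items:
    out = stage1(item)    # independent work for the current item...
    if prev is not None:
      stage2(*prev)       # ...interleaved with dependent work for the previous
    prev = (item, out)
  if prev is not None:
    stage2(*prev)         # epilogue: drain the last in-flight item

log = []
pipelined(
    [0, 1, 2],
    stage1=lambda i: (log.append(f"step1[{i}]"), i)[1],
    stage2=lambda i, out: log.append(f"step2[{i}]"),
)
print(log)  # ['step1[0]', 'step1[1]', 'step2[0]', 'step1[2]', 'step2[1]', 'step2[2]']
```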