@@ -324,8 +324,12 @@ def fused_attn_fwd(
     )

     if return_max_score:
-        qkv_format = qkv_layout.replace("3","").replace("2","").split("_")[0]
-        print(f"dev {torch.cuda.current_device()} qkv_format: {qkv_format}, cu_seqlens_q: {cu_seqlens_q}, cu_seqlens_kv: {cu_seqlens_kv}, cu_seqlens_q_padded: {cu_seqlens_q_padded}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded}")
+        qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+        print(
+            f"dev {torch.cuda.current_device()} qkv_format: {qkv_format}, cu_seqlens_q:"
+            f" {cu_seqlens_q}, cu_seqlens_kv: {cu_seqlens_kv}, cu_seqlens_q_padded:"
+            f" {cu_seqlens_q_padded}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded}"
+        )
         # print(f"dev {torch.cuda.current_device()} q: {q.shape}, k: {k.shape}, v: {v.shape}")
         # print(f"dev {torch.cuda.current_device()} output_tensors[0] is nan: {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]: {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()}, output_tensors[0].max(): {output_tensors[0].max()}")
         # print(f"dev {torch.cuda.current_device()} output_tensors[1] is nan: {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]: {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()}, output_tensors[1].max(): {output_tensors[1].max()}")
@@ -335,7 +339,10 @@ def fused_attn_fwd(
             stats = output_tensors[1] + torch.log(output_tensors[2])
             zero_indices_1 = (output_tensors[1] == 0).nonzero()
             zero_indices_2 = (output_tensors[2] == 0).nonzero()
-            print(f"dev {torch.cuda.current_device()} zero_indices_1: {zero_indices_1}, zero_indices_2: {zero_indices_2}")
+            print(
+                f"dev {torch.cuda.current_device()} zero_indices_1: {zero_indices_1},"
+                f" zero_indices_2: {zero_indices_2}"
+            )
             if torch.cuda.current_device() == 0 and not os.path.exists("output_tensors1.pt"):
                 torch.save(output_tensors[1], "output_tensors1.pt")
                 torch.save(output_tensors[2], "output_tensors2.pt")
@@ -344,16 +351,39 @@ def fused_attn_fwd(
             # Max [tq, h, 1] -> max_score [h]
             max_score = torch.amax(output_tensors[1], dim=(0, 2)).to(dtype=output_tensors[0].dtype)
             print(f"dev {torch.cuda.current_device()} max_score: {max_score}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[0] is nan: {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]: {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()}, output_tensors[0].max(): {output_tensors[0].max()}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[1] is nan: {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]: {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()}, output_tensors[1].max(): {output_tensors[1].max()}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[2] is nan: {torch.isnan(output_tensors[2]).sum()}, output_tensors[2]: {output_tensors[2].shape}, output_tensors[2].min(): {output_tensors[2].min()}, output_tensors[2].max(): {output_tensors[2].max()}")
-            print(f"dev {torch.cuda.current_device()} stats is nan: {torch.isnan(stats).sum()}, stats: {stats.shape}, stats.min(): {stats.min()}, stats.max(): {stats.max()}")
-            print(f"dev {torch.cuda.current_device()} max_score is nan: {torch.isnan(max_score).sum()}, max_score: {max_score.shape}")
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[0] is nan:"
+                f" {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]:"
+                f" {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()},"
+                f" output_tensors[0].max(): {output_tensors[0].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[1] is nan:"
+                f" {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]:"
+                f" {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()},"
+                f" output_tensors[1].max(): {output_tensors[1].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[2] is nan:"
+                f" {torch.isnan(output_tensors[2]).sum()}, output_tensors[2]:"
+                f" {output_tensors[2].shape}, output_tensors[2].min(): {output_tensors[2].min()},"
+                f" output_tensors[2].max(): {output_tensors[2].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} stats is nan: {torch.isnan(stats).sum()},"
+                f" stats: {stats.shape}, stats.min(): {stats.min()}, stats.max(): {stats.max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} max_score is nan:"
+                f" {torch.isnan(max_score).sum()}, max_score: {max_score.shape}"
+            )
         else:
             # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
             stats = output_tensors[1] + torch.log(output_tensors[2])
             # Max [b, h, sq, 1] -> max_score [h]
-            max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(dtype=output_tensors[0].dtype)
+            max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(
+                dtype=output_tensors[0].dtype
+            )
         aux_ctx_tensors = [stats]
         aux_ctx_tensors.extend(output_tensors[3:])
         return output_tensors[0], aux_ctx_tensors, max_score
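Note: for readers unfamiliar with the Max/Sum_Exp auxiliary outputs that the changed lines consume, here is a minimal standalone sketch of the identity they rely on. This is not the fused kernel itself; `scores` and all shapes are illustrative assumptions standing in for the attention logits the kernel reduces internally.

import torch

# Illustrative shapes only: batch, heads, query length, key length.
b, h, sq, skv = 2, 4, 8, 8
scores = torch.randn(b, h, sq, skv)  # stand-in for the attention logits

# Stand-ins for the kernel's auxiliary outputs in bshd/sbhd layout:
row_max = scores.amax(dim=-1, keepdim=True)                      # Max     [b, h, sq, 1]
sum_exp = torch.exp(scores - row_max).sum(dim=-1, keepdim=True)  # Sum_Exp [b, h, sq, 1]

# stats = Max + log(Sum_Exp) is the numerically stable per-row logsumexp
# of the softmax, which is what the diff computes from output_tensors[1:3].
stats = row_max + torch.log(sum_exp)
assert torch.allclose(stats, torch.logsumexp(scores, dim=-1, keepdim=True), atol=1e-5)

# max_score collapses Max over batch, sequence, and the trailing singleton,
# leaving one value per attention head, matching dim=(0, 2, 3) in the else branch.
max_score = torch.amax(row_max, dim=(0, 2, 3))  # [h]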