diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index 5940986240..bad320fe5c 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -95,7 +95,7 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, i
         device=q.device,
         **block_mask_kwargs,
     )
-    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale)
+    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True)
     return out_ref.transpose(1, 2).contiguous()
@@ -809,7 +809,7 @@ def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None):
     # Use flex_attention directly without torch.compile for backward tests
     # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32)
-    out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask)
+    out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True)
     dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref)

     # Transpose back to BSHD
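
For context, a minimal standalone sketch (not from the test file) of what enable_gqa=True changes in torch.nn.attention.flex_attention, assuming a recent PyTorch (2.5+) where the flag is available; the shapes and variable names below are illustrative assumptions:

    # Minimal sketch: with GQA-shaped inputs, the query tensor carries more heads
    # than key/value, and enable_gqa=True tells flex_attention to broadcast each
    # KV head across its group of query heads instead of rejecting the mismatch.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, Hq, Hkv, S, D = 2, 8, 2, 128, 64   # 8 query heads share 2 KV heads (group size 4)
    q = torch.randn(B, Hq, S, D)
    k = torch.randn(B, Hkv, S, D)
    v = torch.randn(B, Hkv, S, D)

    # Without enable_gqa=True this head-count mismatch is an error; with it,
    # the eager reference path repeats each KV head over its query-head group.
    out = flex_attention(q, k, v, enable_gqa=True)
    print(out.shape)  # torch.Size([2, 8, 128, 64])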