
Conversation

NeoLegends (Member) commented Dec 3, 2025

This switches the attention computation to the fused op for efficiency in training, and also because torch's scaled_dot_product_attention is automatically exported into an optimized ONNX attention op.

Open Qs:

  • Is this actually more efficient in practice?
  • I still integrate the relative positional encoding outside of the op, because the op only allows one additional summand (via the attn_mask parameter) and not an additional factor plus sum. Can this be done better? (See the sketch at the end of this description.)
  • Does this need a flag to turn it on and off (i.e. to switch back to a non-fused implementation)?

Tests pass: the output remains torch.allclose(...) to the ESPNet output even when the fused op is used.
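
A minimal sketch of the pattern described above (hypothetical tensor names and shapes, not the actual module code): the relative positional term is computed outside the fused op and can only enter it as an additive bias via attn_mask:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch B, heads H, time T, head dim D.
B, H, T, D = 2, 4, 16, 64
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

# Relative positional term, computed outside the fused op (placeholder values here).
# SDPA only accepts it as an additive bias via attn_mask, i.e. one extra summand.
rel_pos_bias = torch.randn(B, H, T, T)

# An optional key padding mask can be folded into the same additive bias.
key_padding_mask = torch.zeros(B, 1, 1, T, dtype=torch.bool)
bias = rel_pos_bias.masked_fill(key_padding_mask, float("-inf"))

# Fused attention: softmax(q k^T / sqrt(D) + bias) v
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 4, 16, 64])
```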

NeoLegends self-assigned this on Dec 3, 2025
NeoLegends force-pushed the moritz-rel-pos-conf-sdpa branch from 324fc4a to 358de26 on December 3, 2025 at 13:35
NeoLegends force-pushed the moritz-rel-pos-conf-sdpa branch from eac4313 to 58699f6 on December 3, 2025 at 13:47
NeoLegends marked this pull request as ready for review on December 3, 2025 at 13:57
NeoLegends changed the title from "MHSA: use fused SDPA for attention computation" to "Rel. pos. MHSA: use fused SDPA for attention computation" on Dec 3, 2025
NeoLegends changed the title from "Rel. pos. MHSA: use fused SDPA for attention computation" to "Rel. pos. MHSA: use fused op for attention computation" on Dec 3, 2025

albertz (Member) commented Dec 3, 2025

Did you check which SDPA backend it would select, and which one it actually uses at runtime?

I looked a bit into the selection logic. I think you can see it here:

https://github.com/pytorch/pytorch/blob/7ba4680f3755a560af81aa0f688791e367aa3609/aten/src/ATen/native/transformers/attention.cpp#L718
https://github.com/pytorch/pytorch/blob/e3f24fd73ad74c6e7176687986436956c7c18235/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L764

I think that, for example, Flash Attention will not be used, because the Flash Attention backend does not support attn_mask.
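
One way to check this (a sketch, not from this PR; it assumes a CUDA device and a PyTorch version that provides torch.nn.attention.sdpa_kernel): restrict SDPA to a single backend and see which ones reject the inputs. With a non-None attn_mask, I would expect the Flash Attention backend to be rejected and the memory-efficient or math backend to be chosen.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 4, 16, 64, device="cuda", dtype=torch.float16)
bias = torch.randn(2, 4, 16, 16, device="cuda", dtype=torch.float16)

# Restrict SDPA to one backend at a time; backends that cannot handle
# these inputs (e.g. Flash Attention with an attn_mask) raise a RuntimeError.
for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        print(backend, "works")
    except RuntimeError as err:
        print(backend, "rejected:", err)
```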
