[Cute, Bwd, Sm100] Add varlen for sm100 bwd #2150
Conversation
flash_attn/cute/flash_bwd_sm100.py
if const_expr(self.qhead_per_kvhead > 1):
self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
self.use_tma_store = not (self.qhead_per_kvhead == 1 and self.is_varlen_k)
is that meant to be an or? what is the logic here?
You're right to call this out; I only need to avoid the TMA store for dK/dV when cu_seqlens_k is given. I'll change this.
In general:
- varlen k is the condition for using the varlen scheduler
- varlen q is a condition for the process-tile check, since the number of m-blocks processed may then be 0 if that batch's length is 0 (see the sketch below)
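As a minimal sketch of the second point (hypothetical names and structure, not the PR's actual code): under varlen q the per-batch m-block count is computed from the actual sequence length and can be zero, in which case the tile is skipped.

    # Illustrative sketch only; names are hypothetical.
    def num_m_blocks(batch_idx, m_block_size, cu_seqlens_q=None, seqused_q=None, seqlen_q_static=None):
        # Resolve this batch's actual query length under varlen q.
        if seqused_q is not None:
            seqlen_q = seqused_q[batch_idx]
        elif cu_seqlens_q is not None:
            seqlen_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx]
        else:
            seqlen_q = seqlen_q_static
        # Ceil-division; may be 0 for an empty batch, and the tile is then skipped.
        return (seqlen_q + m_block_size - 1) // m_block_size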
Though we also disable the TMA store for seqused_q alone in the forward kernel, so there will be an inconsistency here (albeit with rarely used settings).
To address your other implicit question: since we use a special padded intermediate tensor for the TMA reduce-add of the dK/dV accumulator in the GQA case, we are free to use the TMA store there without worrying about the usual problem of overwriting another batch's outputs. So it should be 'and' and not 'or'.
For the same reason we can use TMA reduce-add for the dQ accumulator.
so the idea is: if we have gqa we post-process, so we can always use tma, even with seqlen_k. and we can use it without seqlen_k. so the only case where we need to not use tma is not seqlen_k and not gqa.
Yes (assuming you mean no TMA = cu_seqlens_k and MHA). I also tried using the postprocess with MHA and cu_seqlens_k to allow the TMA store (which is why I had separated out the dKV_postprocess boolean), but that was slightly slower.
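Summarizing the thread as a rough sketch (names follow the quoted snippet; per the discussion the final condition keys on cu_seqlens_k rather than is_varlen_k, and the exact code is in the PR):

    # Illustrative summary of the discussion, not the exact PR code.
    def use_tma_store_dkv(qhead_per_kvhead, cu_seqlens_k):
        gqa = qhead_per_kvhead > 1
        varlen_k = cu_seqlens_k is not None
        # GQA: dK/dV go through a padded intermediate accumulator via TMA reduce-add
        # and are post-processed, so a TMA store cannot clobber a neighboring batch.
        # MHA with cu_seqlens_k: dK/dV are written directly into the packed varlen
        # layout, where a full-tile TMA store could overwrite the next batch's rows,
        # so fall back to a non-TMA (predicated) store.
        return not ((not gqa) and varlen_k)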

We add varlen support to the sm100 backward pass and expose this capability through the flash_attn_varlen_func API. A head-to-head benchmark against the non-varlen sm100 backward with equal sequence lengths shows minimal overhead.
To fix an alignment issue when loading the padded LSE in the backward kernel, we change the padded offsets to FA3 style, e.g.:
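Roughly (an illustration only; the block size of 128 and the exact rounding are assumptions here, the real formula is in the diff), FA3-style padded offsets give each batch's LSE segment a block-aligned start in a padded buffer, so aligned loads never straddle a batch boundary:

    # Rough illustration only; the block size and rounding in the PR may differ.
    def lse_offset_padded(cu_seqlens_q, batch_idx, block_m=128):
        def round_up(x, m):
            return (x + m - 1) // m * m
        # Block-aligned start for this batch, with one block of slack reserved
        # per preceding batch in the padded LSE buffer.
        return round_up(cu_seqlens_q[batch_idx] + batch_idx * block_m, block_m)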