
Conversation

@jayhshah (Collaborator) commented Jan 8, 2026

We add varlen support to the sm100 backward pass and expose it through the flash_attn_varlen_func API.
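
For context, a minimal sketch of how the new path would be exercised from Python, assuming the FA2/FA3-style flash_attn_varlen_func signature (packed (total_tokens, nheads, headdim) tensors plus cu_seqlens/max_seqlen arguments); the exact import path for the sm100 CuTe-DSL build may differ:

# Sketch only: assumes the FA2/FA3-style flash_attn_varlen_func signature and
# a standard import path; the sm100 (CuTe DSL) build may expose it differently.
import itertools
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 16, 128
seqlens = [512, 2048, 8192, 1024]  # ragged batch
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seqlens)),
                          dtype=torch.int32, device="cuda")
total = sum(seqlens)

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16,
                requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
out.sum().backward()  # exercises the new sm100 varlen backward on Blackwell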

Head-to-head benchmarks against the non-varlen sm100 backward at equal sequence lengths show minimal overhead:

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16 ###
FA Python fwd: 1.445ms, 1522.0 TFLOPS
FA Python varlen fwd: 1.485ms, 1481.1 TFLOPS
FA Python bwd: 4.461ms, 1232.2 TFLOPS
FA Python varlen bwd: 4.577ms, 1201.0 TFLOPS
FA Python bwd (deterministic): 6.007ms, 915.1 TFLOPS
FA Python varlen bwd (deterministic): 6.020ms, 913.3 TFLOPS

### headdim = 128, causal = True, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16 ###
FA Python fwd: 0.775ms, 1419.4 TFLOPS
FA Python varlen fwd: 0.792ms, 1388.1 TFLOPS
FA Python bwd: 2.447ms, 1123.4 TFLOPS
FA Python varlen bwd: 2.542ms, 1081.4 TFLOPS
FA Python bwd (deterministic): 2.923ms, 940.4 TFLOPS
FA Python varlen bwd (deterministic): 3.012ms, 912.6 TFLOPS

To fix an alignment issue when loading the padded LSE in the backward kernel, we change the padded offsets to FA3 style, e.g.:

# Each batch index adds m_block_size of slack on top of the varlen offset.
padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
if cutlass.const_expr(self.arch >= 90):
    # FA3 style: round down to a block multiple so the per-batch base stays aligned.
    padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
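
As a purely illustrative check of why the rounding matters (hypothetical m_block_size and cu_seqlens values, not taken from the PR): each batch's base into the padded LSE/accum buffer ends up block-aligned, and consecutive batches' slices never overlap.

# Illustrative only: hypothetical block size and offsets, not kernel code.
m_block_size = 128
cu_seqlens_q = [0, 200, 500, 1300]  # ragged offsets for a 3-batch input

def padded_offset(batch_idx):
    padded = cu_seqlens_q[batch_idx] + batch_idx * m_block_size
    # arch >= 90 path from the snippet above: round down to a block multiple
    return padded // m_block_size * m_block_size

for b in range(3):
    start, nxt = padded_offset(b), padded_offset(b + 1)
    seqlen_b = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
    assert start % m_block_size == 0   # block-aligned base for the padded LSE loads
    assert nxt - start >= seqlen_b     # slices for consecutive batches never overlap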

@jayhshah force-pushed the jshah/sm100-varlen-bwd branch from 3b798d3 to 3d5f721 on January 8, 2026 08:17
@kiddyboots216

varlen_fwd (and bwd) training matches FA2 on Blackwell

@jayhshah force-pushed the jshah/sm100-varlen-bwd branch from 10ccba2 to 63147ed on January 8, 2026 18:51
@jayhshah requested a review from tridao on January 8, 2026 19:07
@jayhshah force-pushed the jshah/sm100-varlen-bwd branch from 63147ed to 7b7c045 on January 8, 2026 19:17
@v0i0 self-requested a review on January 9, 2026 19:26
if const_expr(self.qhead_per_kvhead > 1):
self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
self.use_tma_store = not (self.qhead_per_kvhead == 1 and self.is_varlen_k)
Collaborator:
is that meant to be an or? what is the logic here?

@jayhshah (Collaborator, Author):
You're right to call this out; I only need to skip the TMA store of dK/dV for cu_seqlens_k. I'll change this.

In general:

  1. varlen k is the condition for using the varlen scheduler
  2. varlen q is a condition for the process-tile check, since the number of m blocks processed may be 0 when the query length for that batch is 0 (see the sketch below)
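
To make point 2 concrete, a hypothetical sketch (made-up names, not the kernel's actual variables) of the kind of check it implies:

# Hypothetical illustration of point 2: with varlen q, a batch may have length
# 0, so the per-batch work must tolerate num_m_blocks == 0 and skip the tile.
def num_m_blocks(seqlen_q, m_block_size):
    return (seqlen_q + m_block_size - 1) // m_block_size  # ceil division

def process_batch(seqlen_q, m_block_size=128):
    if num_m_blocks(seqlen_q, m_block_size) == 0:
        return  # zero-length batch: no m blocks to process for this tile
    for m_block in range(num_m_blocks(seqlen_q, m_block_size)):
        ...  # accumulate dK/dV contributions from this m block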

@jayhshah (Collaborator, Author):
Though we also disable the TMA store for seqused_q only in the forward kernel, so there will be an inconsistency here (albeit with rarely used settings).

@jayhshah (Collaborator, Author) commented Jan 9, 2026:

To address your other implicit question: since we use a special padded intermediate tensor for the TMA reduce-add of the dK/dV accum when GQA is used, we are free to use the TMA store there without worrying about the usual problem of overwriting other batches' outputs. So it should be 'and' and not 'or'.

For the same reason we can use TMA reduce add for dQ accum.
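
To picture the padded-intermediate idea, here is a rough host-side sketch; all names, shapes, and the Python framing are illustrative assumptions, not the kernel's actual buffers:

# Illustrative sketch only: a per-batch, block-padded fp32 accumulator for
# dK/dV so that TMA reduce-add for one batch can never touch a neighbor's
# rows; a postprocess then copies the valid rows back into the packed output.
import torch

def round_up(x, m):
    return (x + m - 1) // m * m

def alloc_padded_dkv_accum(seqlens_k, nheads_kv, headdim, n_block_size=128):
    padded_offsets = [0]
    for s in seqlens_k:
        padded_offsets.append(padded_offsets[-1] + round_up(s, n_block_size))
    accum = torch.zeros(padded_offsets[-1], nheads_kv, headdim,
                        dtype=torch.float32, device="cuda")
    return accum, padded_offsets

def postprocess_dkv(accum, padded_offsets, seqlens_k, dk_packed):
    # Copy only the valid (unpadded) rows of each batch into the packed gradient.
    start = 0
    for b, s in enumerate(seqlens_k):
        valid = accum[padded_offsets[b]:padded_offsets[b] + s]
        dk_packed[start:start + s] = valid.to(dk_packed.dtype)
        start += s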

Collaborator:
so the idea is: if we have gqa we post-process, so we can always use tma, even with seqlen_k. and we can use it without seqlen_k. so the only case where we need to not use tma is not seqlen_k and not gqa.

@jayhshah (Collaborator, Author) commented Jan 9, 2026:

Yes (assuming you mean no TMA = cu_seqlens_k and MHA). I also tried using the postprocess with MHA and cu_seqlens_k to allow for the TMA store (hence the separate dKV_postprocess boolean), but that was slightly slower.
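
For readers following along, a minimal sketch of the condition as resolved in this thread, mirroring the names in the quoted diff; it is illustrative, not necessarily the final code in the PR:

# Resolved logic (sketch): only the MHA + cu_seqlens_k combination has to fall
# back from the TMA store of dK/dV; with GQA the padded intermediate
# accumulator makes the TMA store / TMA reduce-add safe, and seqused_k alone
# no longer disables it.
is_mha = self.qhead_per_kvhead == 1
self.use_tma_store = not (is_mha and mCuSeqlensK is not None)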

@jayhshah merged commit ed6a82f into main on Jan 9, 2026
1 check passed