block-sparse backward SM90 #2136
Conversation
Force-pushed from 5742bb0 to 632f364
Adds block-sparse support to the SM90 backward pass:
- Block-sparse iteration with process_tile, get_block_sparse_iteration_info_bwd
- m_block_safe clamping for loads when subtile_factor > 1
- Zero-fill for KV tiles with no Q blocks
- dQaccum_store with blocksparse_tensors parameter
- bwd_subtile_factor=2 for SM90 block sparsity (matches BlockMask 128 granularity)
- Tile size m_block_size=64 when using block sparsity

stack-info: PR: Dao-AILab#2136, branch: drisspg/stack/7
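To make the iteration scheme in the list above concrete, here is a minimal plain-Python sketch of the backward block-sparse loop, not the actual CUTE DSL kernel. The tensor names (q_block_count, q_block_index) and the callbacks process_tile / zero_fill_dkv are assumptions used only for illustration; the real kernel operates on BlockMask-derived tensors inside the SM90 backward pass.

    # Pseudocode sketch of the block-sparse backward iteration (assumed names).
    def bwd_kv_tile(n_block, q_block_count, q_block_index, num_m_blocks,
                    bwd_subtile_factor, process_tile, zero_fill_dkv):
        """Process one KV (n) tile of the backward pass with block sparsity."""
        num_active = q_block_count[n_block]
        if num_active == 0:
            # No Q block touches this KV tile: zero-fill dK/dV for the tile.
            zero_fill_dkv(n_block)
            return
        for i in range(num_active):
            # Each mask block spans bwd_subtile_factor m-tiles (128 = 2 * 64 here).
            mask_block = q_block_index[n_block, i]
            for sub in range(bwd_subtile_factor):
                m_block = mask_block * bwd_subtile_factor + sub
                # Clamp so a partially out-of-range subtile still loads safely.
                m_block_safe = min(m_block, num_m_blocks - 1)
                process_tile(m_block_safe, n_block)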
flash_attn/cute/interface.py (outdated diff)

    use_block_sparsity = block_sparse_tensors is not None
    ...
    # For SM90 with block sparsity, use tile_m=64 with subtile_factor=2 to match
This was mostly about finding the GCD between an m_block_size that would fit, the base block_m of 128 from the forward pass, and the block-sparse block size, for subtiling.
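As a rough illustration of that point (not the interface.py code; BLOCK_MASK_GRANULARITY and SM90_BWD_TILE_M are hypothetical names), the choice amounts to picking an m tile that fits the SM90 backward kernel while dividing the 128-row BlockMask granularity evenly:

    import math

    BLOCK_MASK_GRANULARITY = 128   # BlockMask rows per block (fwd base block_m)
    SM90_BWD_TILE_M = 64           # assumed largest m tile that fits the SM90 bwd kernel

    # The subtile factor is how many bwd m-tiles make up one BlockMask block.
    m_block_size = math.gcd(BLOCK_MASK_GRANULARITY, SM90_BWD_TILE_M)  # -> 64
    bwd_subtile_factor = BLOCK_MASK_GRANULARITY // m_block_size       # -> 2

    assert m_block_size * bwd_subtile_factor == BLOCK_MASK_GRANULARITY

With these values each BlockMask block maps to exactly two 64-row backward tiles, which is why the PR fixes bwd_subtile_factor=2 and m_block_size=64.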
flash_attn/cute/interface.py (outdated diff)

      expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
          batch_size, num_head, seqlen_q, seqlen_k,
    -     m_block_size, n_block_size, subtile_factor,
    +     m_block_size, n_block_size, bwd_subtile_factor,
nb: bwd_subtile_factor is always 2 for now; a follow-up could make it larger to allow smaller tile sizes.
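For readers unfamiliar with what these expected shapes describe, here is a hypothetical stand-in for get_block_sparse_expected_shapes_bwd, written only to show how the subtile factor could enter the shape computation. It is an assumption, not the library's implementation; the actual function may organize the count/index tensors differently.

    from math import ceil

    def expected_block_sparse_shapes_bwd_sketch(
        batch_size, num_head, seqlen_q, seqlen_k,
        m_block_size, n_block_size, bwd_subtile_factor,
    ):
        """Hypothetical sketch: shapes of the per-KV-tile count/index tensors.

        The backward pass iterates over KV (n) tiles; the mask granularity in m
        is m_block_size * bwd_subtile_factor (64 * 2 = 128 here).
        """
        num_n_blocks = ceil(seqlen_k / n_block_size)
        num_m_blocks = ceil(seqlen_q / (m_block_size * bwd_subtile_factor))
        # Per (batch, head, kv-block): how many Q blocks are active ...
        count_shape = (batch_size, num_head, num_n_blocks)
        # ... and which Q blocks they are.
        index_shape = (batch_size, num_head, num_n_blocks, num_m_blocks)
        return count_shape, index_shape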
Force-pushed from 632f364 to 246cde5
Force-pushed from 246cde5 to 2cc732e
Force-pushed from 2cc732e to 4e91b34
Force-pushed from da4b3e8 to 7be008a
Force-pushed from 7be008a to d592b8d
Force-pushed from d592b8d to d0f91aa
Force-pushed from d0f91aa to edd5b15
Stacked PRs:
block-sparse backward SM90