
Refactoring attention.py part 1 #1542

Merged

Conversation

KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Mar 6, 2025

Description

attention.py has grown to more than 8,500 lines of code, which motivates this refactor. This PR is part 1 of a two-part effort to split attention.py into submodules for ease of development and testing.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Create a new module, dot_product_attention, containing the submodules inference.py, rope.py, and utils.py, and move the appropriate methods from attention.py into these three submodules. The details of these methods are in the list below for those interested (an import sketch follows the list):

List of Functions Moved/Added
  • rope.py (~200 lines of code)
    class RotaryPositionEmbedding(torch.nn.Module)
    class FusedRoPEFunc(torch.autograd.Function):
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    def apply_rotary_pos_emb(

  • inference.py (estimated to be ~1000 lines of code when finished)
    class InferenceParams: # pylint: disable=too-few-public-methods

  • utils.py (~1500 lines of code)
    class AttentionParams: amenable to a different class; unsure whether it belongs in utils.py
    def get_attention_backend(
    def get_full_mask(
    def get_alibi(
    def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    def _get_full_cu_seqlens(
    def pack_tensor(
    def pack_2_tensors(
    def pack_3_tensors(
    def unpack_tensor(
    def unpack_2_tensors(
    def unpack_3_tensors(
    class PackTensors(torch.autograd.Function):
    class UnpackTensor(torch.autograd.Function):
    def get_qkv_layout(
    def check_set_window_size(
    def get_attention_quantizers - Added new function
    Create AttentionLogging class - Added new class
    Create FlashAttentionUtils class - Added new class
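
For orientation, the sketch below shows how the relocated pieces can be imported under the new layout. The package path matches the imports in this PR's diff; the aliases and the exact set of importable names are illustrative.

# Illustrative imports under the new dot_product_attention layout; aliases are
# arbitrary, and the available names follow the list above.
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)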

Notable function signature changes (comments have been added for these in the respective functions):

  1. get_attention_backend() - This no longer populates the global _attention_backends cache; the responsibility for populating _attention_backends now rests with the caller of this function (see the sketch below this list).
  2. get_alibi() - This now accepts the global _alibi_cache as a function parameter from the caller and reads from and writes to that cache from within the function.
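
A hypothetical sketch of the new calling pattern described above; the signatures, return values, cache keying, and get_alibi() parameter names shown here are assumptions for illustration, not the PR's exact code.

import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils

_attention_backends = {}  # global cache, now maintained by the caller
_alibi_cache = {}         # global cache, now passed in explicitly

def select_backend(attention_params):
    # get_attention_backend() no longer writes to _attention_backends itself;
    # the caller stores whatever it returns (keying scheme is illustrative).
    backend_info = dpa_utils.get_attention_backend(attention_params)
    _attention_backends["backend_info"] = backend_info
    return backend_info

# get_alibi() now takes the cache as an argument and reads/writes it internally, e.g.:
# alibi_bias = dpa_utils.get_alibi(_alibi_cache, num_heads, max_seqlen_q, max_seqlen_kv)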

TODO: The refactoring part 2 PR will create new modules/submodules for MultiHeadAttention.py and context parallelism, move some more generic PyTorch utility functions to pytorch/utils.py, and move whatever is left of attention.py into dot_product_attention.
Additionally, the part 2 PR will address any larger changes suggested during this part 1 PR's review.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Mar 6, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/refactor-pyt-attn-1 branch 2 times, most recently from 4a6ac72 to 9063a52 on March 9, 2025 19:20
@ptrendx ptrendx requested a review from cyanguwa March 10, 2025 21:45
@cyanguwa cyanguwa added the 2.2.0 label Mar 10, 2025
@cyanguwa
Collaborator

cyanguwa commented Mar 11, 2025

I wonder if it makes sense to leave those flash-attn version checks/imports as is, for now. We'd have to think of a cleaner way to do it, but I feel FlashAttentionUtils might not cut it.

Also, could you have a look at the usage of inference.py and rope.py, see if it makes more sense to keep them in dot_product_attention/ or outside that directory? Thanks!
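
For readers unfamiliar with what those checks look like, below is a generic sketch of version gating centralized in one utility class; it is not the FlashAttentionUtils implementation added in this PR, and the attribute names and version threshold are assumptions.

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

class FlashAttentionUtils:
    # Hypothetical body: probe the installed flash-attn version once and expose
    # simple feature gates instead of scattering checks across attention.py.
    try:
        flash_attn_version = Version(version("flash-attn"))
        is_installed = True
    except PackageNotFoundError:
        flash_attn_version = Version("0")
        is_installed = False
    # The threshold below is illustrative only.
    use_flash_attn_2 = is_installed and flash_attn_version >= Version("2.0.0")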

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/refactor-pyt-attn-1 branch 3 times, most recently from 2b3e300 to c65f750 on March 12, 2025 17:07
@KshitijLakhani KshitijLakhani marked this pull request as ready for review March 12, 2025 17:11
@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1 L2 L3

@KshitijLakhani KshitijLakhani changed the title from "Refactoring attention.py" to "Refactoring attention.py part 1" on Mar 12, 2025
@yaox12
Collaborator

yaox12 commented Mar 13, 2025

Also, could you have a look at the usage of inference.py and rope.py, see if it makes more sense to keep them in dot_product_attention/ or outside that directory? Thanks!

I'm wondering whether we should create a folder named something like transformer_engine.pytorch.functional (similar to torch.nn.functional) for operators (here I mean what we usually use as autograd functions, not through layers or modules), and move RoPE, softmax, permutation, and cross_entropy into it. What do you think? @cyanguwa
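
For concreteness, a hypothetical layout under that suggestion (nothing below exists in this PR; file names are illustrative):

te/pytorch:
  - functional/
     - rope.py: apply_rotary_pos_emb and related ops
     - softmax.py
     - permutation.py
     - cross_entropy.py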

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/refactor-pyt-attn-1 branch from b2fa7e2 to 8e774d8 on March 13, 2025 17:19
Collaborator

I think this probably is a good structure for our finished refactoring (maybe not now):

te/pytorch:
  - transformer.py: TELayer (~800 LOC)
  - multihead_attention.py: MHA (~800 LOC)
  - attention/
     - attention.py: DPA (~1000 LOC)
     - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
     - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
     - softmax.py: only used in attention.py IIUC
     - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3; see the sketch after this layout)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)
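
On the utils.py note about condensing the pack/unpack helpers, here is a minimal sketch of a single variadic pack_tensors, assuming the existing helpers gather rows of each tensor by an index tensor; the signature and gather-based body are assumptions rather than the PR's code.

from typing import Tuple
import torch

def pack_tensors(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Gather the rows selected by `indices` from each [total_rows, ...] tensor,
    replacing the separate 1-/2-/3-tensor packing variants with one helper."""
    packed = []
    for t in tensors:
        # Expand the (int64) index vector to match t's trailing dims so that
        # torch.gather selects whole rows along dim 0.
        idx = indices.view(-1, *([1] * (t.dim() - 1))).expand(-1, *t.shape[1:])
        packed.append(torch.gather(t, 0, idx))
    return tuple(packed)

# e.g. q_packed, k_packed, v_packed = pack_tensors(valid_indices, q, k, v)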

Collaborator Author

@KshitijLakhani KshitijLakhani Mar 13, 2025


Sounds good.
Will address it in the part 2 PR

Collaborator

I feel that if we're keeping multihead_attention.py outside, we should rename the dir structure as follows:

te/pytorch:
  - transformer.py: TELayer (~800 LOC)
  - multihead_attention.py: MHA (~800 LOC)
  - dot_product_attention/
     - dot_product_attention.py: DPA (~1000 LOC)
     - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
     - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
     - softmax.py: only used in attention.py IIUC
     - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)

Or keep multihead_attention.py inside attention/ as follows. Keeping m_h_a.py as a top-level file alongside an attention/ directory that contains attention/attention.py could be misleading.

te/pytorch:
  - transformer.py: TELayer (~800 LOC)
  - attention/
     - multihead_attention.py: MHA/GQA/MLA
     - dot_product_attention.py: DPA (~1000 LOC)
     - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
     - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
     - softmax.py: only used in attention.py IIUC
     - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)

The latter looks cleaner to me, if it doesn't break more than the former.

Collaborator

@cyanguwa cyanguwa Mar 14, 2025


The former looks better to me; the latter option's attention/ is a bit too chunky. But let's discuss this in Part 2, together with where rope.py, softmax.py, and inference.py should go.

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/refactor-pyt-attn-1 branch from 7dbe271 to b0ee442 on March 13, 2025 23:37
@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1 L2 L3

Collaborator

@sudhakarsingh27 sudhakarsingh27 left a comment


Overall lgtm since this is just setting up for Part 2


@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1 L2 L3

# See LICENSE for license information.

"""
Rotary Position Embedding implementation of different types along with hlper functions
Collaborator

Nit: "helper"

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))


# ----Helper/Util classes and methods-----
Collaborator

Nit: maybe remove these comments? L57 and L49? It doesn't look like we make these comments elsewhere in utils.py.

Collaborator

@cyanguwa cyanguwa left a comment


LGTM. After all CI passes, you can merge.

@@ -87,71 +74,45 @@
restore_from_saved,
)

# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer
Collaborator

Nit: could probably do "from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams"?

@cyanguwa
Collaborator

Also, could you have a look at the usage of inference.py and rope.py, see if it makes more sense to keep them in dot_product_attention/ or outside that directory? Thanks!

I'm wondering whether we should create a folder named something like transformer_engine.pytorch.functional (similar to torch.nn.functional) for operators (here I mean what we usually use as autograd functions, not through layers or modules), and move RoPE, softmax, permutation, and cross_entropy into it. What do you think? @cyanguwa

Let's have a think about this in Refactoring Part 2. @KshitijLakhani

KshitijLakhani and others added 4 commits March 14, 2025 13:19
Move attention logging into a separate class in pytorch/d_p_a/utils.py

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…ning info

Move versioning info out of pytorch/attention.py

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…_p_a/utils.py

Fix tests and imports for the above refactor change

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
KshitijLakhani and others added 22 commits March 14, 2025 13:19
…antizers() to d_p_a/utils.py

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
….py to d_p_a/utils.py

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…d_p_a/utils.py

Rename cumulative functions from using _cu_ to using _cumul_ to differentiate them from the CUDA cu* call naming convention
Rename tensor packing methods with a leading underscore to mark them as internal to the file

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
….py to it

Modify tests and other files to import InferenceParams correctly

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

Modify docs api for InferenceParams

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Use attn_log instead of att_log

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

Fix lint error

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/refactor-pyt-attn-1 branch from 9654cb8 to d0bed1c on March 14, 2025 20:22
@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1 L2 L3

@KshitijLakhani KshitijLakhani merged commit 3733947 into NVIDIA:main Mar 14, 2025
11 checks passed