Refactoring attention.py part 1 #1542
Conversation
I wonder if it makes sense to leave those flash-attn version checks/imports as is, for now. We'd have to think of a cleaner way to do it, but I feel […] Also, could you have a look at the usage of […]?
/te-ci pytorch L0 L1 L2 L3
I'm wondering whether we should create a folder named something like […]
I think this probably is a good structure for our finished refactoring (maybe not now):
te/pytorch:
- transformer.py: TELayer (~800 LOC)
- multihead_attention.py: MHA (~800 LOC)
- attention/
  - attention.py: DPA (~1000 LOC)
  - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
  - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
  - softmax.py: only used in attention.py IIUC
  - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)
Sounds good.
Will address it in the part 2 PR
I feel that if we're keeping multihead_attention.py outside, we should rename the dir structure as follows:
te/pytorch:
- transformer.py: TELayer (~800 LOC)
- multihead_attention.py: MHA (~800 LOC)
- dot_product_attention/
  - dot_product_attention.py: DPA (~1000 LOC)
  - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
  - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
  - softmax.py: only used in attention.py IIUC
  - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)
Or keep multihead_attention.py inside attention/ as follows. Keeping m_h_a.py as a top-level file, in addition to attention/ as a directory with attention/attention.py as a file, could be misleading.
te/pytorch:
- transformer.py: TELayer (~800 LOC)
- attention/
  - multihead_attention.py: MHA/GQA/MLA
  - dot_product_attention.py: DPA (~1000 LOC)
  - backends.py: FusedAttention, FlashAttention, UnfusedDPA (~1200 LOC)
  - context_parallel.py: P2P, A2A, AllGather, attn_with_cp (~3300 LOC)
  - softmax.py: only used in attention.py IIUC
  - utils.py: ~1600 LOC, but can move _SplitAlongDim, _combine_tensors to te/pytorch/utils.py (they can probably be merged with noop_cat, need investigation), and can condense PackTensors/UnpackTensors related funcs (one pack_tensors func instead of 3)
  - rope.py: ~200 LOC, but might get longer (a couple of PRs in the pipeline)
  - inference.py: ~50 LOC, but will get longer (PR 1355, 800 LOC)
The latter looks cleaner to me, if it doesn't break more stuff than the former.
The former looks better to me; the latter option's attention/ is a bit too chunky. But let's discuss this in Part 2, together with where rope.py, softmax.py, and inference.py should be.
/te-ci pytorch L0 L1 L2 L3
Overall lgtm since this is just setting up for Part 2
/te-ci pytorch L0 L1 L2 L3
# See LICENSE for license information.

"""
Rotary Position Embedding implementation of different types along with hlper functions
Nit: "helper"
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))


# ----Helper/Util classes and methods-----
Nit: maybe remove these comments? L57 and L49? It doesn't look like we make these comments elsewhere in utils.py.
LGTM. After all CI passes, you can merge.
@@ -87,71 +74,45 @@
    restore_from_saved,
)

# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer
Nit: could probably do "from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams"?
Let's have a think about this in Refactoring Part 2. @KshitijLakhani
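For reference, the two import styles being weighed here, as a minimal illustration (it assumes the dot_product_attention.inference module from the diff above is importable; the assert only shows both forms name the same class):

# Style currently in the diff: module alias, attribute access at call sites.
import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer

# Style suggested in the review: import the class directly.
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

# Both names refer to the same class object.
assert dpa_infer.InferenceParams is InferenceParams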
Move attention logging into a separate class in pytorch/d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…ning info Move versioning info out of pytorch/attention.py Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…_p_a/utils.py Fix tests and imports for the above refactor change Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
…antizers() to d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
….py to d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
…d_p_a/utils.py Rename cumulative functions from using _cu_ to using _cumul_ to differentiate from the CUDA cu call protocol. Rename tensor packing methods with a leading underscore to mark them as internal to the file. Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
….py to it Modify tests and other files to import InferenceParams correctly Signed-off-by: Kshitij Janardan Lakhani <[email protected]> Modify docs api for InferenceParams Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
…to it Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Code clean up Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Code clean up Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Use attn_log instead of att_log Signed-off-by: Kshitij Janardan Lakhani <[email protected]> Fix lint error Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
/te-ci pytorch L0 L1 L2 L3
Description

attention.py has grown to 8500+ lines of code, hence the motivation to refactor it. This PR is part 1 (of a 2-part PR effort) to refactor attention.py into sub-modules for ease of development and testing.

Fixes # (issue)

Type of change

Changes

Create a new module, dot_product_attention, which contains the sub-modules inference.py, rope.py and utils.py, and move the appropriate methods from attention.py into these 3 sub-modules. The details of these methods are in a collapsed list below for those interested:

List of Functions Moved/Added
rope.py (~200 lines of code; see the sketch after this list)
class RotaryPositionEmbedding(torch.nn.Module)
class FusedRoPEFunc(torch.autograd.Function):
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
def apply_rotary_pos_emb(
inference.py (estimated to be ~1000 when finished)
class InferenceParams: # pylint: disable=too-few-public-methods
utils.py (~1500 lines of code)
class AttentionParams: Agree with a different class. Unsure of it being in utils.py
def get_attention_backend(
def get_full_mask(
def get_alibi(
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
def _get_full_cu_seqlens(
def pack_tensor(
def pack_2_tensors(
def pack_3_tensors(
def unpack_tensor(
def unpack_2_tensors(
def unpack_3_tensors(
class PackTensors(torch.autograd.Function):
class UnpackTensor(torch.autograd.Function):
def get_qkv_layout(
def check_set_window_size(
def get_attention_quantizers - Added new function
Create AttentionLogging class - Added new class
Create FlashAttentionUtils class - Added new class
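For orientation, _rotate_half in the rope.py entries above is the standard RoPE helper; a minimal sketch of the usual formulation follows (illustrative only, not necessarily the exact code moved in this PR):

import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply_rotary_pos_emb then combines this with precomputed cos/sin tables,
# roughly t * cos + _rotate_half(t) * sin; layout handling and the fused
# kernel path (FusedRoPEFunc) live in rope.py and are more involved.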
Notable function signature changes (comments have been added for these in the respective functions):

get_attention_backend() - This will no longer populate the global _attention_backends cache itself; the responsibility for populating _attention_backends now rests with the caller of this function.

get_alibi() - This will now accept the global _alibi_cache as a function parameter from the caller, and read/write the _alibi_cache from within the function.
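A rough sketch of the caller-side pattern implied by the get_attention_backend() change (the cache name _attention_backends comes from the description above; its layout and the helper shown here are assumptions for illustration, not the actual TransformerEngine code):

import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils

# Global cache now owned by the caller (e.g. attention.py); shape assumed here.
_attention_backends = {"selection": None}

def _refresh_attention_backend(attention_params):
    # After this PR, get_attention_backend() only computes the selection;
    # populating the cache is the caller's responsibility.
    selection = dpa_utils.get_attention_backend(attention_params)
    _attention_backends["selection"] = selection
    return selection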
TODO: The refactoring part 2 PR will create new modules/sub-modules for MultiHeadAttention.py and context_parallelism, while moving some more generic PyTorch utils functions to pytorch/utils.py and moving whatever is left over of attention.py into dot_product_attention. Additionally, the part 2 PR will also address any larger changes suggested as part of this part 1 PR's review process.
Checklist: