Commit d0bed1c

nit: Fix typos, explicit imports and remove extra comments
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
1 parent 9b8a37b commit d0bed1c

3 files changed: +5 -7 lines changed


Diff for: transformer_engine/pytorch/attention.py (+4 -4)

@@ -76,7 +76,7 @@

 # Import attention utils
 import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
-import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb

@@ -5384,7 +5384,7 @@ def forward(
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
-        inference_params: Optional[dpa_infer.InferenceParams] = None,
+        inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
     ) -> torch.Tensor:
         """

@@ -5545,7 +5545,7 @@ def forward(
             to the attention score of query i and key j.
         fast_zero_fill: bool, default = `True`
             Whether to use the fast path to set output tensors to 0 or not.
-        inference_params: Optional[dpa_infer.InferenceParams], default = `None`
+        inference_params: Optional[InferenceParams], default = `None`
             Optimizes execution performance during inference by caching Keys and Values of the
             current decoding iteration. These cached values are appended to the K and V values
             computed in previous iterations, eliminating the need to recalculate them for the

@@ -6501,7 +6501,7 @@ def forward(
         window_size: Optional[Tuple[int, int]] = None,
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
-        inference_params: Optional[dpa_infer.InferenceParams] = None,
+        inference_params: Optional[InferenceParams] = None,
         rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,

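The import change above swaps a module alias for an explicit class import, so the annotations in the forward signatures and docstrings shorten from dpa_infer.InferenceParams to InferenceParams. The inference_params docstring describes key/value caching during incremental decoding: each step's K and V are appended to the values cached from previous iterations so they are not recomputed. Below is a minimal, conceptual sketch of that append-and-reuse pattern in plain PyTorch; it is not Transformer Engine's cache implementation, and the helper name, tensor shapes, and sizes are illustrative only.

import torch
from typing import Tuple

def append_to_kv_cache(
    k_cache: torch.Tensor,  # [cached_tokens, heads, head_dim]
    v_cache: torch.Tensor,
    k_step: torch.Tensor,   # [new_tokens, heads, head_dim] for the current decode step
    v_step: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Append this step's keys/values to the running cache (conceptual only)."""
    k_cache = torch.cat([k_cache, k_step], dim=0)
    v_cache = torch.cat([v_cache, v_step], dim=0)
    return k_cache, v_cache

# Each decode step computes K/V only for the new token; attention then runs
# against the full cached K/V, so earlier tokens are never recomputed.
heads, head_dim = 16, 64
k_cache = torch.empty(0, heads, head_dim)
v_cache = torch.empty(0, heads, head_dim)
for _ in range(3):  # three decoding iterations
    k_new = torch.randn(1, heads, head_dim)
    v_new = torch.randn(1, heads, head_dim)
    k_cache, v_cache = append_to_kv_cache(k_cache, v_cache, k_new, v_new)

In the library itself, this bookkeeping travels with the InferenceParams object that the diff threads through the forward calls above.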
Diff for: transformer_engine/pytorch/dot_product_attention/rope.py (+1 -1)

@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 """
-Rotary Position Embedding implementation of different types along with hlper functions
+Rotary Position Embedding implementation of different types along with helper functions
 """
 from typing import Optional, Tuple, Union
 import torch

Diff for: transformer_engine/pytorch/dot_product_attention/utils.py (-2)

@@ -46,15 +46,13 @@

 from transformer_engine.pytorch.jit import jit_fuser

-# ----Global constants----
 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))


-# ----Helper/Util classes and methods-----
 class AttentionLogging:
     """
     Manage logging for attention module

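The _NVTE_* globals kept by this hunk are read from environment variables at module import time, so the variables have to be set before transformer_engine.pytorch is imported. A minimal sketch of toggling them from Python, using only the defaults and meanings given in the diff's comments (the surrounding script context is hypothetical):

import os

# Read once at import time by dot_product_attention/utils.py, so set these
# before importing transformer_engine.pytorch.
os.environ["NVTE_DEBUG"] = "1"        # 0/1: disables/enables debug mode (default 0)
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # 0/1/2: increasingly verbose debug output (default 0)
os.environ["NVTE_FLASH_ATTN"] = "1"   # leave the flash-attention path enabled (default 1)

import transformer_engine.pytorch as te  # noqa: E402  (import after env setup)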