Skip to content

Commit 8e774d8

Browse files
nit: Resolve lint errors
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
1 parent 86a206c commit 8e774d8

File tree

4 files changed

+14
-3
lines changed

4 files changed

+14
-3
lines changed

transformer_engine/pytorch/attention.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from packaging.version import Version as PkgVersion
1818

1919
import torch
20-
import torch.nn.functional as F
2120

2221
import transformer_engine_torch as tex
2322
from transformer_engine.pytorch.utils import (

transformer_engine/pytorch/dot_product_attention/inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# See LICENSE for license information.
44

5+
"""
6+
Inference classes for attention
7+
"""
8+
59

610
class InferenceParams: # pylint: disable=too-few-public-methods
711
"""

transformer_engine/pytorch/dot_product_attention/rope.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#
33
# See LICENSE for license information.
44

5+
"""
6+
Rotary Position Embedding implementation of different types along with helper functions
7+
"""
58
from typing import Optional, Tuple, Union
69
import torch
710
import transformer_engine_torch as tex

transformer_engine/pytorch/dot_product_attention/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#
33
# See LICENSE for license information.
44

5+
"""
6+
Utils/Helper classes and methods for attention
7+
"""
58
import math
69
import os
710
from typing import Any, Dict, List, Optional, Tuple, Union, TypeAlias
@@ -60,7 +63,6 @@ class AttentionLogging:
6063
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
6164
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
6265
_stream_handler = logging.StreamHandler()
63-
# TODO: Move fa_logger to FAUtils
6466
fa_logger = logging.getLogger(__name__)
6567

6668
@staticmethod
@@ -1348,7 +1350,10 @@ def _unpack_3_tensors(
13481350

13491351

13501352
class PackTensors(torch.autograd.Function):
1351-
1353+
"""
1354+
Autograd function to pack a tensor.
1355+
"""
1356+
13521357
@staticmethod
13531358
def forward(
13541359
ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]

0 commit comments

Comments
 (0)