
Commit b0ee442

Remove typedef FAUtils for FlashAttentionUtils
Use attn_log instead of att_log

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

Fix lint error

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
1 parent 7e646b9 commit b0ee442

File tree

4 files changed: +71 -75 lines

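The typedef being removed is a module-level alias for the FlashAttentionUtils class in transformer_engine/pytorch/dot_product_attention/utils.py. The alias removal itself is not shown in the hunks below, so the snippet here is only a hedged sketch of the pattern; the flag names are taken from the diff, but the class body and default values are assumptions.

# Hypothetical sketch only -- not the actual utils.py source.
class FlashAttentionUtils:
    """Capability flags queried by the tests and by attention.py (names from the diff)."""

    is_installed = False      # flash-attn v2 importable
    v2_plus = False           # flash-attn >= 2.0
    v2_3_plus = False         # flash-attn >= 2.3 (gates the sliding-window tests)
    v3_is_installed = False   # flash-attn v3 importable

# Before this commit a short alias existed, roughly:
#     FAUtils = FlashAttentionUtils
# After it, callers reference FlashAttentionUtils directly, as the hunks show.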

tests/pytorch/fused_attn/test_fused_attn.py

+9 -9

@@ -21,7 +21,7 @@
     _attention_backends,
 )
 from transformer_engine.pytorch.dot_product_attention.utils import (
-    FAUtils,
+    FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,

@@ -280,12 +280,12 @@ def test_dot_product_attention(
     # mannually pads and unpads the input and output of FlashAttention for testing purposes
     if (
         pad_between_seqs
-        and FAUtils.is_installed
+        and FlashAttentionUtils.is_installed
         and not (
             config.max_seqlen_q != config.max_seqlen_kv
             and config.attn_mask_type in ["causal", "padding_causal"]
         )
-        and (config.window_size[0] == -1 or FAUtils.v2_3_plus)
+        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
     ):
         flash_attn_supported = True

@@ -592,7 +592,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 }


-@pytest.mark.skipif(not FAUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())

@@ -614,7 +614,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 }


-@pytest.mark.skipif(not FAUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())

@@ -1456,7 +1456,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     ):
         pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")

-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True

@@ -1482,7 +1482,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     rtol = 5e-1
     rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,

@@ -1667,7 +1667,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True

@@ -1696,7 +1696,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     rmse_tol = 0.11
     bwd_names = ["dq", "dk", "dv"]
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    if FAUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
+    if FlashAttentionUtils.v3_is_installed and not is_training and "padding" not in config.attn_mask_type:
         _error(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
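The v2_3_plus-style flags renamed above are version gates used in the skip markers and in backend-selection checks. The snippet below is a minimal sketch of how such a gate can be computed and consumed in a test, assuming flash-attn exposes flash_attn.__version__; the real logic lives in the FlashAttentionUtils class and may differ.

import pytest
from packaging.version import Version

try:
    import flash_attn

    _fa_version = Version(flash_attn.__version__)
    _fa_installed = True
except ImportError:
    _fa_version = Version("0")
    _fa_installed = False

# Illustrative stand-ins for FlashAttentionUtils.is_installed / v2_3_plus.
fa_is_installed = _fa_installed
fa_v2_3_plus = _fa_installed and _fa_version >= Version("2.3")


@pytest.mark.skipif(not fa_v2_3_plus, reason="Flash-attn 2.3+ is required.")
def test_sliding_window_gate():
    # Placeholder body: a real test would exercise sliding-window attention here.
    assert fa_is_installed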

tests/pytorch/fused_attn/test_fused_attn_with_cp.py

+2 -2

@@ -11,7 +11,7 @@
     get_device_compute_capability,
     get_cudnn_version,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import FAUtils
+from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
 from test_fused_attn import ModelConfig

 model_configs_flash_attn = {

@@ -50,7 +50,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
     return args


-@pytest.mark.skipif(not FAUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
+@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())

transformer_engine/pytorch/attention.py

+9 -9

@@ -78,12 +78,12 @@
 import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
 import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
-from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as att_log
+from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb


 # Setup Attention Logging
-att_log.setup_logging()
+attn_log.setup_logging()

 # Global vars for flash attn imports
 flash_attn_cuda_bwd = None

@@ -101,7 +101,7 @@
     and get_device_compute_capability() >= (8, 0)
     and dpa_utils._NVTE_FLASH_ATTN
 ):
-    att_log.fa_logger.debug(
+    attn_log.fa_logger.debug(
         "flash-attn v2 is not installed. To use, please install it by"
         """ "pip3 install flash-attn".""",
     )

@@ -131,7 +131,7 @@
     and get_device_compute_capability() >= (8, 0)
     and dpa_utils._NVTE_FLASH_ATTN
 ):
-    att_log.fa_logger.warning(
+    attn_log.fa_logger.warning(
         "Supported flash-attn versions are %s. Found flash-attn %s.",
         dpa_utils._get_supported_versions(
             (

@@ -155,7 +155,7 @@
     and get_device_compute_capability() >= (9, 0)
     and dpa_utils._NVTE_FLASH_ATTN
 ):
-    att_log.fa_logger.debug(
+    attn_log.fa_logger.debug(
         "flash-attn v3 is not installed. To use, please install it by \n%s",
         fa_utils.v3_installation_steps,
     )

@@ -3926,9 +3926,9 @@ def __init__(
         self.layer_number = 1 if layer_number is None else layer_number
         self.deterministic = deterministic
         self.logger = logging.getLogger("FlashAttention")
-        self.logger.setLevel(att_log._log_level)
+        self.logger.setLevel(attn_log._log_level)
         if not self.logger.hasHandlers():
-            self.logger.addHandler(att_log._stream_handler)
+            self.logger.addHandler(attn_log._stream_handler)

     def forward(
         self,

@@ -5157,9 +5157,9 @@ def __init__(
         super().__init__()

         self.logger = logging.getLogger("DotProductAttention")
-        self.logger.setLevel(att_log._log_level)
+        self.logger.setLevel(attn_log._log_level)
         if not self.logger.hasHandlers():
-            self.logger.addHandler(att_log._stream_handler)
+            self.logger.addHandler(attn_log._stream_handler)
         self.qkv_format = qkv_format
         attn_mask_type = attn_mask_type.replace(",", "_")
         if attn_mask_type == "causal_padding":
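The logging hunks above only rename the import alias, but they show the pattern in use: a shared AttentionLogging helper (imported as attn_log) owns a log level and a stream handler, and each class attaches them to its own named logger exactly once. The snippet below is a minimal sketch of that pattern with a simplified stand-in class; the real AttentionLogging in dot_product_attention/utils.py derives its level and formatting from environment settings and may differ.

import logging


class AttentionLogging:
    """Simplified stand-in for the real AttentionLogging helper (assumed shape)."""

    _log_level = logging.WARNING  # the real helper derives this from environment settings
    _stream_handler = None
    fa_logger = None

    @classmethod
    def setup_logging(cls):
        # Create one shared stream handler and the flash-attn logger used at import time.
        cls._stream_handler = logging.StreamHandler()
        cls._stream_handler.setFormatter(logging.Formatter("[%(name)s %(levelname)s]: %(message)s"))
        cls.fa_logger = logging.getLogger("FlashAttention")
        cls.fa_logger.setLevel(cls._log_level)
        if not cls.fa_logger.hasHandlers():
            cls.fa_logger.addHandler(cls._stream_handler)


attn_log = AttentionLogging
attn_log.setup_logging()

# Per-class loggers reuse the shared level and handler, mirroring the hunks above.
logger = logging.getLogger("DotProductAttention")
logger.setLevel(attn_log._log_level)
if not logger.hasHandlers():
    logger.addHandler(attn_log._stream_handler)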
