 # Import attention utils
 import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
-import transformer_engine.pytorch.dot_product_attention.inference as dpa_infer
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
@@ -5384,7 +5384,7 @@ def forward(
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
-        inference_params: Optional[dpa_infer.InferenceParams] = None,
+        inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
     ) -> torch.Tensor:
         """
@@ -5545,7 +5545,7 @@ def forward(
             to the attention score of query i and key j.
         fast_zero_fill: bool, default = `True`
             Whether to use the fast path to set output tensors to 0 or not.
-        inference_params: Optional[dpa_infer.InferenceParams], default = `None`
+        inference_params: Optional[InferenceParams], default = `None`
             Optimizes execution performance during inference by caching Keys and Values of the
             current decoding iteration. These cached values are appended to the K and V values
             computed in previous iterations, eliminating the need to recalculate them for the
@@ -6501,7 +6501,7 @@ def forward(
         window_size: Optional[Tuple[int, int]] = None,
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
-        inference_params: Optional[dpa_infer.InferenceParams] = None,
+        inference_params: Optional[InferenceParams] = None,
         rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,