Skip to content

Commit 7f1c765

Browse files
committed
fix lint
Signed-off-by: Charlene Yang <[email protected]>
1 parent 496776b commit 7f1c765

File tree

1 file changed

+54
-60
lines changed

1 file changed

+54
-60
lines changed

transformer_engine/pytorch/attention.py

+54-60
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,7 @@ def get_fa_args(
19061906
dk=None,
19071907
dv=None,
19081908
):
1909+
"""Get forward/backward arguments for flash-attn v2 and v3."""
19091910
if use_flash_attn_3:
19101911
if forward:
19111912
if qkv_format == "thd":
@@ -1918,66 +1919,59 @@ def get_fa_args(
19181919
max_seqlen_kv,
19191920
*[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
19201921
]
1921-
else:
1922-
return [
1923-
*[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924-
max_seqlen_q,
1925-
max_seqlen_kv,
1926-
*[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927-
]
1928-
else:
1929-
if qkv_format == "thd":
1930-
return [
1931-
cu_seqlens_q,
1932-
cu_seqlens_kv,
1933-
None, # sequed_q
1934-
None, # sequed_k
1935-
max_seqlen_q,
1936-
max_seqlen_kv,
1937-
dq,
1938-
dk,
1939-
dv,
1940-
]
1941-
else:
1942-
return [
1943-
None, # cu_seqlens_q
1944-
None, # cu_seqlens_kv
1945-
None, # sequed_q
1946-
None, # sequed_k
1947-
max_seqlen_q,
1948-
max_seqlen_kv,
1949-
dq,
1950-
dk,
1951-
dv,
1952-
]
1953-
else:
1954-
if forward:
1955-
if qkv_format == "thd":
1956-
return [
1957-
cu_seqlens_q,
1958-
cu_seqlens_kv,
1959-
max_seqlen_q,
1960-
max_seqlen_kv,
1961-
]
1962-
else:
1963-
return []
1964-
else:
1965-
if qkv_format == "thd":
1966-
return [
1967-
dq,
1968-
dk,
1969-
dv,
1970-
cu_seqlens_q,
1971-
cu_seqlens_kv,
1972-
max_seqlen_q,
1973-
max_seqlen_kv,
1974-
]
1975-
else:
1976-
return [
1977-
dq,
1978-
dk,
1979-
dv,
1980-
]
1922+
return [
1923+
*[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924+
max_seqlen_q,
1925+
max_seqlen_kv,
1926+
*[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927+
]
1928+
if qkv_format == "thd":
1929+
return [
1930+
cu_seqlens_q,
1931+
cu_seqlens_kv,
1932+
None, # sequed_q
1933+
None, # sequed_k
1934+
max_seqlen_q,
1935+
max_seqlen_kv,
1936+
dq,
1937+
dk,
1938+
dv,
1939+
]
1940+
return [
1941+
None, # cu_seqlens_q
1942+
None, # cu_seqlens_kv
1943+
None, # sequed_q
1944+
None, # sequed_k
1945+
max_seqlen_q,
1946+
max_seqlen_kv,
1947+
dq,
1948+
dk,
1949+
dv,
1950+
]
1951+
if forward:
1952+
if qkv_format == "thd":
1953+
return [
1954+
cu_seqlens_q,
1955+
cu_seqlens_kv,
1956+
max_seqlen_q,
1957+
max_seqlen_kv,
1958+
]
1959+
return []
1960+
if qkv_format == "thd":
1961+
return [
1962+
dq,
1963+
dk,
1964+
dv,
1965+
cu_seqlens_q,
1966+
cu_seqlens_kv,
1967+
max_seqlen_q,
1968+
max_seqlen_kv,
1969+
]
1970+
return [
1971+
dq,
1972+
dk,
1973+
dv,
1974+
]
19811975

19821976

19831977
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):

0 commit comments

Comments
 (0)