@@ -1906,6 +1906,7 @@ def get_fa_args(
1906
1906
dk = None ,
1907
1907
dv = None ,
1908
1908
):
1909
+ """Get forward/backward arguments for flash-attn v2 and v3."""
1909
1910
if use_flash_attn_3 :
1910
1911
if forward :
1911
1912
if qkv_format == "thd" :
@@ -1918,66 +1919,59 @@ def get_fa_args(
1918
1919
max_seqlen_kv ,
1919
1920
* [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1920
1921
]
1921
- else :
1922
- return [
1923
- * [None ] * 9 , # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924
- max_seqlen_q ,
1925
- max_seqlen_kv ,
1926
- * [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927
- ]
1928
- else :
1929
- if qkv_format == "thd" :
1930
- return [
1931
- cu_seqlens_q ,
1932
- cu_seqlens_kv ,
1933
- None , # sequed_q
1934
- None , # sequed_k
1935
- max_seqlen_q ,
1936
- max_seqlen_kv ,
1937
- dq ,
1938
- dk ,
1939
- dv ,
1940
- ]
1941
- else :
1942
- return [
1943
- None , # cu_seqlens_q
1944
- None , # cu_seqlens_kv
1945
- None , # sequed_q
1946
- None , # sequed_k
1947
- max_seqlen_q ,
1948
- max_seqlen_kv ,
1949
- dq ,
1950
- dk ,
1951
- dv ,
1952
- ]
1953
- else :
1954
- if forward :
1955
- if qkv_format == "thd" :
1956
- return [
1957
- cu_seqlens_q ,
1958
- cu_seqlens_kv ,
1959
- max_seqlen_q ,
1960
- max_seqlen_kv ,
1961
- ]
1962
- else :
1963
- return []
1964
- else :
1965
- if qkv_format == "thd" :
1966
- return [
1967
- dq ,
1968
- dk ,
1969
- dv ,
1970
- cu_seqlens_q ,
1971
- cu_seqlens_kv ,
1972
- max_seqlen_q ,
1973
- max_seqlen_kv ,
1974
- ]
1975
- else :
1976
- return [
1977
- dq ,
1978
- dk ,
1979
- dv ,
1980
- ]
1922
+ return [
1923
+ * [None ] * 9 , # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924
+ max_seqlen_q ,
1925
+ max_seqlen_kv ,
1926
+ * [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927
+ ]
1928
+ if qkv_format == "thd" :
1929
+ return [
1930
+ cu_seqlens_q ,
1931
+ cu_seqlens_kv ,
1932
+ None , # seqused_q
1933
+ None , # seqused_k
1934
+ max_seqlen_q ,
1935
+ max_seqlen_kv ,
1936
+ dq ,
1937
+ dk ,
1938
+ dv ,
1939
+ ]
1940
+ return [
1941
+ None , # cu_seqlens_q
1942
+ None , # cu_seqlens_kv
1943
+ None , # seqused_q
1944
+ None , # seqused_k
1945
+ max_seqlen_q ,
1946
+ max_seqlen_kv ,
1947
+ dq ,
1948
+ dk ,
1949
+ dv ,
1950
+ ]
1951
+ if forward :
1952
+ if qkv_format == "thd" :
1953
+ return [
1954
+ cu_seqlens_q ,
1955
+ cu_seqlens_kv ,
1956
+ max_seqlen_q ,
1957
+ max_seqlen_kv ,
1958
+ ]
1959
+ return []
1960
+ if qkv_format == "thd" :
1961
+ return [
1962
+ dq ,
1963
+ dk ,
1964
+ dv ,
1965
+ cu_seqlens_q ,
1966
+ cu_seqlens_kv ,
1967
+ max_seqlen_q ,
1968
+ max_seqlen_kv ,
1969
+ ]
1970
+ return [
1971
+ dq ,
1972
+ dk ,
1973
+ dv ,
1974
+ ]
1981
1975
1982
1976
1983
1977
class AttnFuncWithCPAndKVP2P (torch .autograd .Function ):
0 commit comments