diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ff86d012c0..d393ced56a 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -36,7 +36,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, void *softmax_lse_d, float p_dropout, float softmax_scale, - bool is_causal) { + bool is_causal, + void *glm_mask_d) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -103,6 +104,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.is_causal = is_causal; params.is_seqlens_k_cumulative = true; + params.is_glm_causal = ! (glm_mask_d == nullptr); + params.glm_mask = static_cast(glm_mask_d); } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -134,7 +137,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void *dsoftmax_sum_d, float p_dropout, float softmax_scale, - bool is_causal) { + bool is_causal, + void *glm_mask_d) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -145,7 +149,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, softmax_lse_d, p_dropout, softmax_scale, - is_causal); + is_causal, + glm_mask_d); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); @@ -239,7 +244,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const float softmax_scale, const bool is_causal, const bool return_softmax, - c10::optional gen_) { + c10::optional gen_, + const c10::optional &glm_mask // batch_size + ) { auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -325,6 +332,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); } + // glm_mask support + if (glm_mask.has_value()) { + TORCH_CHECK(is_causal, "is_causal must be true"); + TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32"); + TORCH_CHECK(glm_mask.value().is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous"); + CHECK_SHAPE(glm_mask.value(), batch_size); + } + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -339,7 +355,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + glm_mask.has_value() ? 
glm_mask->data_ptr() : nullptr); // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = is_sm90 || is_sm8x @@ -403,7 +420,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const bool zero_tensors, const bool is_causal, const bool return_softmax, - c10::optional gen_) { + c10::optional gen_, + const c10::optional &glm_mask // batch_size + ) { auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -503,6 +522,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q if (return_softmax) {p.zero_();} } + // glm_mask support + if (glm_mask.has_value()) { + TORCH_CHECK(is_causal, "is_causal must be true"); + TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32"); + TORCH_CHECK(glm_mask.value().is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous"); + CHECK_SHAPE(glm_mask.value(), batch_size); + } + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -517,7 +545,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + glm_mask.has_value() ? glm_mask->data_ptr() : nullptr); // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -584,6 +613,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const float softmax_scale, const bool is_causal, c10::optional gen_, + const c10::optional &glm_mask, // batch_size c10::optional &rng_state) { auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -714,6 +744,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv_expanded = dv; } + // glm_mask support + if (glm_mask.has_value()) { + TORCH_CHECK(is_causal, "is_causal must be true"); + TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32"); + TORCH_CHECK(glm_mask.value().is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous"); + CHECK_SHAPE(glm_mask.value(), batch_size); + } + Flash_bwd_params params; set_params_dgrad(params, @@ -735,7 +774,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + glm_mask.has_value() ? 
glm_mask->data_ptr() : nullptr); auto launch = &run_mha_bwd; // launch(params, stream, /*configure=*/true); @@ -792,6 +832,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const bool zero_tensors, const bool is_causal, c10::optional gen_, + const c10::optional &glm_mask, // batch_size c10::optional &rng_state ) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -934,6 +975,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.zero_(); } + + // glm_mask support + if (glm_mask.has_value()) { + TORCH_CHECK(is_causal, "is_causal must be true"); + TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32"); + TORCH_CHECK(glm_mask.value().is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous"); + CHECK_SHAPE(glm_mask.value(), batch_size); + } + Flash_bwd_params params; set_params_dgrad(params, @@ -953,7 +1004,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + glm_mask.has_value() ? glm_mask->data_ptr() : nullptr); auto launch = &run_mha_bwd; // launch(params, stream, /*configure=*/true); diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h index 18793e9c8e..4ff84c4517 100644 --- a/csrc/flash_attn/src/block_info.h +++ b/csrc/flash_attn/src/block_info.h @@ -19,7 +19,9 @@ struct BlockInfo { // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + // origin logic, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + , break_point(params.is_glm_causal ? params.glm_mask[bidb] : 0) { } @@ -35,10 +37,11 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const int actual_seqlen_q; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 
const int seqlen_k_cache; - const int actual_seqlen_k; + const uint32_t actual_seqlen_q; + const uint32_t actual_seqlen_k; + const int break_point; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index ff49cb8e19..5b312036a6 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -116,6 +116,9 @@ struct Flash_fwd_params : public Qkv_params { bool is_seqlens_k_cumulative; int num_splits; // For split-KV version + // glm mask + bool is_glm_causal; + int * __restrict__ glm_mask; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index fc5724c90d..cb5bd929c5 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -654,14 +654,50 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded; int m_block = m_block_max - 1; - int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); - // We're guaranteed that m_block_min <= m_block: - // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, - // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. - // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. - // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. - // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM. - // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. + int m_block_min = 0; + if (Is_causal) { + m_block_min = std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); + if (params.is_glm_causal) { + m_block_min = binfo.break_point > n_block * kBlockN ? 0 : m_block_min; + } + } + + // We might need to exit early and write 0 to dK and dV. + // Otherwise we get wrong result for the case where we don't enter the for loop. + // And we might read OOB elements from gQ and gdO. + // TODO: what if we're not parallelizing, do we need to compute dot_do_o? 
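+    // Note on the glm (prefix-causal) case handled above: when break_point > n_block * kBlockN,
+    // the key columns to the left of the break point in this block are visible to every query
+    // row, so the full m-range must be traversed (m_block_min = 0) and the early-exit path
+    // below can never be taken for such a key block.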
+ if (Is_causal && m_block < m_block_min) { + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx); + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + clear(tdKrdK); + clear(tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + return; + } if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ tQsQ.data() = tQsQ.data() + size(sQ); @@ -792,7 +828,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), binfo.actual_seqlen_q, // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, - AtomLayoutMS * 16); + AtomLayoutMS * 16, binfo.break_point); } } // if (cute::thread(32, 0)) { print(scores); } @@ -1338,7 +1374,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, binfo.actual_seqlen_q, - AtomLayoutMS * 16); + AtomLayoutMS * 16, binfo.break_point); } // Compute the exponential value. 
flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index c0e3df5559..bca189522d 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -142,8 +142,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + if (params.is_glm_causal) { + n_block_max = std::max(n_block_max, cute::ceil_div(binfo.break_point, kBlockN)); + } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } @@ -426,7 +428,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // m_block * kBlockM + get<0>(idx_row(0)), m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, binfo.actual_seqlen_q, - kNWarps * 16); + kNWarps * 16, binfo.break_point); // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); } diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index f72313add3..2b7315d97b 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -141,9 +141,9 @@ inline __device__ void apply_mask(Tensor &tensor, const int max_ } template -inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, + const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, + const uint32_t warp_row_stride, const uint32_t break_point) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -162,8 +162,8 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const i const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + const uint32_t col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit && col_idx >= break_point) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py new file mode 100644 index 0000000000..13e326c6aa --- /dev/null +++ b/flash_attn/flash_attention.py @@ -0,0 +1,105 @@ +""" +this file is deleted in v2.0 release, recover and update it for api compat +""" + +import math +import torch +import torch.nn as nn + +from einops import rearrange + +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None + + +class FlashMHA(nn.Module): + + def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0, + causal=False, device=None, dtype=None) -> None: + assert batch_first + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.inner_attn = FlashAttention(attention_dropout=attention_dropout) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + def forward(self, x, key_padding_mask=None, need_weights=False): + """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + key_padding_mask: bool tensor of shape (batch, seqlen) + """ + qkv = self.Wqkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask, + need_weights=need_weights, causal=self.causal) + return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights diff --git a/flash_attn/flash_attn_interface.py 
b/flash_attn/flash_attn_interface.py index d3c13054ab..e91b3dff87 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -41,11 +41,11 @@ def _get_block_size(device, head_dim, is_dropout, is_causal): return (128, 64) if is_sm80 else (64, 64) -def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): +def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax, glm_mask): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( - q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None + q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None, glm_mask ) return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state @@ -62,6 +62,7 @@ def _flash_attn_varlen_forward( softmax_scale, causal, return_softmax, + glm_mask ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -80,6 +81,7 @@ def _flash_attn_varlen_forward( causal, return_softmax, None, + glm_mask ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -87,7 +89,7 @@ def _flash_attn_varlen_forward( def _flash_attn_backward( - dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None + dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, glm_mask, rng_state=None ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous @@ -106,6 +108,7 @@ def _flash_attn_backward( softmax_scale, causal, None, + glm_mask, rng_state, ) return dq, dk, dv, softmax_d @@ -128,6 +131,7 @@ def _flash_attn_varlen_backward( dropout_p, softmax_scale, causal, + glm_mask, rng_state=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x @@ -152,6 +156,7 @@ def _flash_attn_varlen_backward( False, causal, None, + glm_mask, rng_state, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): @@ -161,7 +166,7 @@ def _flash_attn_varlen_backward( class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax, glm_mask): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -172,8 +177,9 @@ def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, glm_mask) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -181,7 +187,7 @@ def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, glm_mask = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) _flash_attn_backward( @@ -197,15 +203,16 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + 
glm_mask, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None + return dqkv, None, None, None, None, None class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax, glm_mask): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( @@ -220,8 +227,9 @@ def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state, glm_mask) ctx.dropout_p = dropout_p ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale @@ -230,7 +238,7 @@ def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens, rng_state, glm_mask = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) _flash_attn_varlen_backward( @@ -250,15 +258,16 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + glm_mask, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax, glm_mask): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -269,8 +278,9 @@ def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, glm_mask) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -278,7 +288,7 @@ def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, glm_mask = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) @@ -295,11 +305,12 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + glm_mask, rng_state=rng_state, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., :dout.shape[-1]] + return dq, dkv, None, None, None, None, None class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -316,6 
+327,7 @@ def forward( softmax_scale, causal, return_softmax, + glm_mask, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -331,9 +343,10 @@ def forward( softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, glm_mask ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -344,7 +357,7 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, glm_mask = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) @@ -365,16 +378,17 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + glm_mask, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax, glm_mask): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( @@ -385,8 +399,9 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, glm_mask) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -394,7 +409,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, glm_mask = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, @@ -409,48 +424,28 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + glm_mask, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax, glm_mask): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - return_softmax=return_softmax and dropout_p > 0, 
- ) - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + glm_mask=glm_mask ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state, glm_mask) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -460,7 +455,7 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, glm_mask = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_varlen_backward( dout, @@ -479,17 +474,19 @@ def backward(ctx, dout, *args): ctx.dropout_p, ctx.softmax_scale, ctx.causal, + glm_mask, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None -def flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False -): + + +def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation @@ -515,12 +512,11 @@ def flash_attn_qkvpacked_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs, glm_mask) -def flash_attn_kvpacked_func( - q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False -): +def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation @@ -561,12 +557,11 @@ def flash_attn_kvpacked_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs, glm_mask) -def flash_attn_func( - q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False -): +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
@@ -605,18 +600,11 @@ def flash_attn_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) + return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs, glm_mask) -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - return_attn_probs=False, -): +def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, + causal=False, return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation @@ -646,22 +634,13 @@ def flash_attn_varlen_qkvpacked_func( pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnVarlenQKVPackedFunc.apply( - qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs, glm_mask ) -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - return_attn_probs=False, -): +def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation @@ -719,22 +698,14 @@ def flash_attn_varlen_kvpacked_func( softmax_scale, causal, return_attn_probs, + glm_mask ) -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - return_attn_probs=False, -): + +def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, glm_mask=None): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
@@ -791,6 +762,7 @@ def flash_attn_varlen_func( softmax_scale, causal, return_attn_probs, + glm_mask ) @@ -868,3 +840,8 @@ def flash_attn_with_kvcache( q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits ) return out + +# alias for compat +flash_attn_unpadded_func = flash_attn_varlen_func +flash_attn_unpadded_qkvpacked_func = flash_attn_varlen_qkvpacked_func +flash_attn_unpadded_kvpacked_func = flash_attn_varlen_kvpacked_func diff --git a/patches/d30f2e1c.patch b/patches/d30f2e1c.patch new file mode 100644 index 0000000000..f6501c5000 --- /dev/null +++ b/patches/d30f2e1c.patch @@ -0,0 +1,151 @@ +# git diff d30f2e1cd50185c98ed88c0684b4a603f15bee37 a4e5d1edddd67f9299fba510732b3c67dcab7219 +diff --git a/README.md b/README.md +index 88ba4ed..79d3345 100644 +--- a/README.md ++++ b/README.md +@@ -101,7 +101,7 @@ Return: + flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads +-than Q. Note that the number of heads in KV must be divisible by the number of heads in Q. ++than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + +@@ -131,7 +131,7 @@ These functions have been renamed: + If the inputs have the same sequence lengths in the same batch, it is simpler + and faster to use these functions: + ```python +-flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False) ++flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) + ``` + ```python + flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) +diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h +index 98b2242..7c9638b 100644 +--- a/csrc/flash_attn/src/flash_bwd_kernel.h ++++ b/csrc/flash_attn/src/flash_bwd_kernel.h +@@ -1020,9 +1020,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + +- // If we don't need syncthreads here since we're writing to the same location as sK and sV. +- // Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop. +- if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); } ++ // We need syncthreads here since we're writing to the same location as sK and sV. ++ // Without syncthreads, some thread might modify the location of sK while another thread ++ // is reading it for dQ gemm, leading to a race condition. ++ // If Is_last, there's already a __syncthreads() at the end of the loop. 
++ if (!Is_last) { __syncthreads(); } + + copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK); + copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV); +diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py +index f9a3f96..0572f72 100644 +--- a/flash_attn/__init__.py ++++ b/flash_attn/__init__.py +@@ -1,4 +1,4 @@ +-__version__ = "2.0.3" ++__version__ = "2.0.4" + + from flash_attn.flash_attn_interface import flash_attn_func + from flash_attn.flash_attn_interface import flash_attn_kvpacked_func +diff --git a/setup.py b/setup.py +index 88353f8..1cef260 100644 +--- a/setup.py ++++ b/setup.py +@@ -172,6 +172,7 @@ ext_modules.append( + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", ++ # "--ptxas-options=-O2", + "-lineinfo" + ] + + generator_flag +diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py +index 2260c88..7ee2863 100644 +--- a/tests/test_flash_attn.py ++++ b/tests/test_flash_attn.py +@@ -785,44 +785,49 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_ + + # @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) + @pytest.mark.parametrize('dtype', [torch.float16]) +-@pytest.mark.parametrize('causal', [False, True]) +-# @pytest.mark.parametrize('causal', [True]) ++# @pytest.mark.parametrize('causal', [False, True]) ++@pytest.mark.parametrize('causal', [False]) + # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) + # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +-@pytest.mark.parametrize('d', [64]) ++@pytest.mark.parametrize('d', [128]) + # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) +-@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) +-# @pytest.mark.parametrize('seqlen', [193]) ++# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) ++@pytest.mark.parametrize('seqlen', [128]) + # @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) + @pytest.mark.parametrize('dropout_p', [0.0]) + def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): +- if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: +- pytest.skip() # Reference implementation OOM + device = 'cuda' + # set seed + torch.random.manual_seed(0) +- batch_size = 32 ++ batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 +- qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) ++ qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, ++ requires_grad=True) + out0, lse0, _ = flash_attn_qkvpacked_func( + qkv, dropout_p, return_attn_probs=True, causal=causal + ) + g = torch.randn_like(out0) +- dqkv0, = torch.autograd.grad(out0, qkv, g) ++ if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): ++ dqkv0, = torch.autograd.grad(out0, qkv, g) ++ # Numerical error if we just do any arithmetic on dq ++ dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item() + +- for _ in range(200): ++ for i in range(200): + torch.random.manual_seed(0) + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, dropout_p, return_attn_probs=True, causal=causal + ) + assert torch.equal(out, out0) + assert torch.equal(lse, lse0) +- # sm_lse has some parts that are uninitialized from torch.empty +- # assert torch.equal(sm_lse, sm_lse_0) + +- if not (is_sm75 and d == 128): ++ if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + dqkv, = torch.autograd.grad(out, 
qkv, g) +- assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0]) ++ dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol) ++ if not dq_equal: ++ dq0 = dqkv0[:, :, 0] ++ dq = dqkv[:, :, 0] ++ print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}') ++ assert dq_equal + assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1]) + assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2]) + +diff --git a/training/Dockerfile b/training/Dockerfile +index 94a5768..3bc477a 100644 +--- a/training/Dockerfile ++++ b/training/Dockerfile +@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr + RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 + + # Install FlashAttention +-RUN pip install flash-attn==2.0.3 ++RUN pip install flash-attn==2.0.4 + + # Install CUDA extensions for cross-entropy, fused dense, layer norm + RUN git clone https://github.com/HazyResearch/flash-attention \ +- && cd flash-attention && git checkout v2.0.3 \ ++ && cd flash-attention && git checkout v2.0.4 \ + && cd csrc/fused_softmax && pip install . && cd ../../ \ + && cd csrc/rotary && pip install . && cd ../../ \ + && cd csrc/xentropy && pip install . && cd ../../ \ diff --git a/tests/_glm_mask_test.py b/tests/_glm_mask_test.py new file mode 100644 index 0000000000..2663fbd354 --- /dev/null +++ b/tests/_glm_mask_test.py @@ -0,0 +1,106 @@ +import torch +import math +from torch.testing._internal.common_utils import freeze_rng_state +from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func + + +device = 'cuda' +# set seed +torch.random.manual_seed(0) +#torch.backends.cudnn.deterministic = True +#batch_size = 4 +#nheads = 1 +#seqlen = 20 +#d = 32 +batch_size = 4 +nheads = 16 +seqlen = 512 +# d = 32 +d = 128 +#dtype=torch.bfloat16 +dtype=torch.float16 + +qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, + requires_grad=True) +glm_mask = torch.randint(10, seqlen - 10, (batch_size,), device=device, dtype=torch.int32) +# glm_mask = torch.tensor([200, 200, 200, 200], device=device, dtype=torch.int32) + +with freeze_rng_state(): + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, 0.0, return_attn_probs=True, causal=True, glm_mask=glm_mask + ) + + +# k = 10 +# k = 200 +Q = qkv[:, :, 0, :, :] +K = qkv[:, :, 1, :, :] +V = qkv[:, :, 2, :, :] + + +# conventional transformer +def build_mask_matrix(hidden_states, seq_length, sep, memory_length=0): + m = hidden_states.new_ones((1, seq_length, seq_length)) + m = torch.tril(m) + if False: # is_scalar: + m[0, :, :int(sep)] = 1 + else: + m = m.expand(batch_size, -1, -1) + ids = torch.arange(seq_length, device=sep.device, dtype=sep.dtype).view(1, -1) + mask = ids < sep.view(-1, 1) + m = m.masked_fill(mask.unsqueeze(1).expand_as(m), 1) + if memory_length > 0: + m = m.expand(batch_size, -1, -1) + m = torch.cat((hidden_states.new_ones((batch_size, seq_length, memory_length)), m), dim=2) + m = m.unsqueeze(1) + return m + +def ref_attn_compute(Q, K, V, glm_mask): + attn_mask = build_mask_matrix(Q, seqlen, glm_mask) + # attn_mask = torch.ones(seqlen, seqlen, dtype=torch.float32, device=device).tril(diagonal=0) + # attn_mask[:, :k] = 1. + attn_mask = (1-attn_mask)*(-60000.) 
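+    # (1 - attn_mask) * (-60000.) turns the 0/1 visibility mask into an additive bias:
+    # 0 where attention is allowed, and a large negative value (a finite, fp16-representable
+    # stand-in for -inf) where it is masked out.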
+    # attn_mask = attn_mask.to(dtype)
+    #import pdb;pdb.set_trace()
+    #attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
+    #attn = attn_weight @ V
+    attn_weight = torch.softmax(torch.matmul(Q.permute(0, 2, 1, 3), K.permute(0, 2, 3, 1)) / math.sqrt(Q.size(-1)) + attn_mask, dim=-1)
+    attn = attn_weight @ V.permute(0, 2, 1, 3)
+    attn = attn.permute(0, 2, 1, 3)
+    return attn
+
+
+
+
+
+
+#with freeze_rng_state():
+with freeze_rng_state(), torch.cuda.amp.autocast(dtype=dtype):
+    autocast_attn = ref_attn_compute(Q, K, V, glm_mask)
+fp32_attn = ref_attn_compute(Q.to(torch.float32), K.to(torch.float32), V.to(torch.float32), glm_mask)
+
+out = torch.squeeze(out)
+
+print(f'Output max diff: {(out - fp32_attn).abs().max().item()}')
+print(f'Output mean diff: {(out - fp32_attn).abs().mean().item()}')
+print(f'Pytorch max diff: {(autocast_attn - fp32_attn).abs().max().item()}')
+print(f'Pytorch mean diff: {(autocast_attn - fp32_attn).abs().mean().item()}')
+
+g = torch.randn_like(out)
+dqkv, = torch.autograd.grad(out, qkv, g)
+dqkv_ref, = torch.autograd.grad(fp32_attn, qkv, g)
+dqkv_pt, = torch.autograd.grad(autocast_attn, qkv, g)
+
+print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
+print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
+
+
+#print((out-attn).abs().max())
+import pdb;pdb.set_trace()
+# print(torch.allclose(attn, out, atol=1e-3, rtol=1e-3))
\ No newline at end of file
diff --git a/training/Dockerfile b/training/Dockerfile
index a99da835ee..3bc477aab9 100644
--- a/training/Dockerfile
+++ b/training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.2.2
+RUN pip install flash-attn==2.0.4
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.2 \
+    && cd flash-attention && git checkout v2.0.4 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
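
Usage note (not part of the patch): the glm_mask argument threaded through flash_attn_cuda and the Python wrappers implements GLM-style prefix-causal attention. glm_mask must be a contiguous int32 CUDA tensor of shape (batch_size,) and requires causal=True; for batch element b, key positions j < glm_mask[b] are visible to every query (bidirectional prefix), and all remaining positions follow the usual causal rule (this is the `col_idx >= col_idx_limit && col_idx >= break_point` condition in softmax.h). Below is a minimal sketch of the intended semantics, checked against a dense PyTorch reference; the tensor sizes and the helper name prefix_causal_bias are illustrative, not part of the patch.

import math
import torch
from flash_attn import flash_attn_qkvpacked_func

torch.manual_seed(0)
device, dtype = 'cuda', torch.float16
batch_size, seqlen, nheads, d = 2, 128, 4, 64

qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype)
# One break point per batch element: keys [0, glm_mask[b]) are visible to every query row.
glm_mask = torch.randint(10, seqlen - 10, (batch_size,), device=device, dtype=torch.int32)

out = flash_attn_qkvpacked_func(qkv, 0.0, causal=True, glm_mask=glm_mask)  # (B, S, H, D)

def prefix_causal_bias(sep, seqlen):
    # Allowed when causal (key j <= query i) OR inside the bidirectional prefix (j < sep[b]).
    ids = torch.arange(seqlen, device=sep.device)
    allowed = (ids[None, None, :] <= ids[None, :, None]) | (ids[None, None, :] < sep[:, None, None])
    bias = torch.zeros_like(allowed, dtype=torch.float32)
    return bias.masked_fill(~allowed, float('-inf'))[:, None]  # (B, 1, S, S), broadcast over heads

q, k, v = [qkv[:, :, i].float().transpose(1, 2) for i in range(3)]  # each (B, H, S, D)
ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d) + prefix_causal_bias(glm_mask, seqlen), dim=-1) @ v
ref = ref.transpose(1, 2)  # back to (B, S, H, D)

print(f'max diff vs. dense reference: {(out.float() - ref).abs().max().item()}')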