How to use the flash attention algorithm in practice: a look at how transformers does it [transformers source-code reading]


What this post covers

  1. Flash attention is a super-fast, memory-saving attention algorithm; by now most people have probably heard of it.
  2. This post shows how to actually use flash attention inside a model, i.e., the engineering side of it.

A quick look at the principle

  1. The paper https://arxiv.org/abs/2205.14135 proposes an IO-aware exact attention algorithm.
  2. Transformers keep getting bigger and deeper, yet they remain slow and memory-hungry on long sequences (self-attention time and memory scale quadratically with sequence length).
  3. Existing approximate-attention methods try to solve this by sacrificing model quality in exchange for lower computational complexity.
  4. But they have a real limitation: they often fail to speed up actual wall-clock training.
  5. The authors argue that an attention algorithm should be IO-aware, i.e., account for reads and writes between levels of GPU memory, such as the large-but-slow HBM (High Bandwidth Memory) and the small-but-fast SRAM.
  6. Against this backdrop they propose FlashAttention, built on two techniques: computing attention incrementally in blocks (tiling), and recomputing attention in the backward pass, with all attention operations fused into a single CUDA kernel. (A toy sketch of the tiling idea follows this list.)

  7. The authors evaluate FlashAttention's impact on Transformer training: training time, model accuracy, and attention runtime and memory efficiency. The results are, simply put, very good.
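To make the tiling idea concrete, here is a minimal, illustrative PyTorch sketch of blockwise attention with an online softmax. This is my own toy code, not from the paper's repo; the real FlashAttention does this inside fused CUDA kernels, never materializing the full attention matrix, and adds recomputation in the backward pass:

import torch

def tiled_attention(q, k, v, block_size=64):
    """Blockwise attention with an online softmax, for one (seq_len, d) head."""
    seq_len, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)                          # running (unnormalized) output
    row_max = torch.full((seq_len, 1), float('-inf'))  # running row-wise max
    row_sum = torch.zeros(seq_len, 1)                  # running softmax denominator
    for start in range(0, seq_len, block_size):
        kb = k[start:start + block_size]               # one K/V tile at a time,
        vb = v[start:start + block_size]               # as if loaded into SRAM
        scores = (q @ kb.T) * scale                    # (seq_len, block_size)
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)      # rescale what was accumulated so far
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

# Sanity check against the naive implementation:
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5)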

For more on flash attention itself, please read the original paper; I won't repeat it here. What this post shares is: how to use flash attention in a model.

Case 1: the open_llama implementation in the transformers package

A few days ago, I came across a memory_efficient_attention implementation in the open_llama model code inside the transformers package.

Code

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warn(
        "Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )

# ... excerpt from OpenLlamaAttention.forward:

        past_key_value = (key_states, value_states) if use_cache else None

        # Fast path: xformers memory-efficient attention, taken only during
        # training. xops.memory_efficient_attention expects tensors shaped
        # (bsz, seq_len, num_heads, head_dim), hence the transpose(1, 2) calls.
        if self.config.use_memorry_efficient_attention and xops is not None and self.training:
            attn_weights = None
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attn_output = xops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
            )
        else:
            # Slow path: standard attention, materializing the full
            # (q_len, kv_seq_len) attention matrix.
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2)

Commentary

  1. The code above lives at https://github.com/huggingface/transformers/blob/c2c99dc7ef5edab8f7674a1eb00cf6ac6996fd0f/src/transformers/models/open_llama/modeling_open_llama.py#L234 (click the link to see it); it sits inside OpenLlamaAttention.forward in the open_llama model.
  2. It uses the memory_efficient_attention function provided by the xformers package (a standalone usage sketch follows this list).
  3. Note that the use_memorry_efficient_attention path (the typo is in the actual config field name) is only taken during training (self.training is checked); it changes neither the model structure nor the model weights.
  4. Also note that only the q, k, and v tensors are used; the padding attention_mask is not passed in, only the built-in causal mask xops.LowerTriangularMask().
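As a standalone illustration, here is a minimal usage sketch of that same function (my own example, assuming xformers is installed and a CUDA GPU is available):

import torch
from xformers import ops as xops

bsz, seq_len, n_heads, head_dim = 2, 128, 8, 64
# xformers expects (batch, seq_len, num_heads, head_dim), which is why the
# open_llama code above calls transpose(1, 2) before handing the tensors over.
q = torch.randn(bsz, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),  # built-in causal mask
    p=0.1,                                 # attention dropout, for training
)
print(out.shape)  # torch.Size([2, 128, 8, 64])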

Case 2: the attend implementation in the x-transformers package

The code in the transformers package has more of an engineering flavor; the code in the x-transformers package has more of an algorithm flavor. And one look at the x-transformers author's goofy GitHub avatar tells you how free-spirited he is~

In x-transformers, I also found a version he implemented.

Code

from x_transformers.attend import Attend, Intermediates

# ... excerpt from Attention.__init__:

        # attend class - includes core attention algorithm + talking heads

        self.attend = Attend(
            heads = heads,
            causal = causal,
            talking_heads = talking_heads,
            dropout = dropout,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
            flash = flash
        )
        # ... later, in Attention.forward:

        # attention is all we need

        out, intermediates = self.attend(
            q, k, v,
            mask = final_attn_mask,
            attn_bias = attn_bias,
            prev_attn = prev_attn
        )

        # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

        if exists(r):
            out = out * r + out

Commentary

  1. The code above lives at https://github.com/lucidrains/x-transformers/blob/31f0d3657e906c95a1231919a4477b44d5caa916/x_transformers/x_transformers.py#L830

  2. The full call site is in the x_transformers.py file; the full implementation is in the attend.py file.

  3. The code reads clean and soothing (gentle enough for sensitive skin, even), but I didn't see any part of the algorithm rewritten in CUDA, so I can't say how fast it is. (From a look at attend.py, the flash=True path appears to hand off to PyTorch 2.0's F.scaled_dot_product_attention rather than custom kernels; see the usage sketch after this list.)
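For completeness, here is roughly what enabling it looks like from the user side (a sketch based on the x-transformers README; kwargs prefixed with attn_ are routed to the attention module, so attn_flash = True becomes the flash = flash you see above):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True  # forwarded to Attend(flash = True)
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)  # (1, 1024, 20000)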

Case 3: the chatglm-v2-6b and falcon-7b implementations

  1. In the latest chatglm-v2-6b code, as well as in the falcon-7b code, the speedup is built on the torch.nn.functional.scaled_dot_product_attention function.
  2. scaled_dot_product_attention is already integrated into PyTorch 2.0, so it can be used directly with no extra setup; it is very simple to use (see the sketch after this list).
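A minimal sketch of calling it directly (my own example, not from either repo; PyTorch dispatches to a flash-attention kernel, a memory-efficient kernel, or a plain math fallback depending on device, dtype, and arguments):

import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Let PyTorch pick the fastest available backend:
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Or pin it to the flash-attention kernel to check that it is actually usable:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)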

chatglm-v2-6b code

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            # (seq_len, batch, heads, head_dim) -> (batch, heads, seq_len, head_dim)
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    # ChatGLM's mask is True where attention is blocked; SDPA's
                    # boolean mask is True where attention is allowed, so invert it.
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            # ... (manual attention path for PyTorch < 2.0, omitted here)
falcon-7b code

# ... excerpt from the attention module's forward:

        if alibi is None:
            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
            )

            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
            x = x.permute(0, 2, 1, 3)
            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            outputs = (output_tensor, present)
            assert not output_attentions  # not supported.
            return outputs
        else:
            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
            matmul_result = query_layer @ key_layer.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # ... (manual softmax attention for the ALiBi case continues, omitted here)
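Why does Falcon only take the scaled_dot_product_attention path when alibi is None? In PyTorch 2.0 the flash-attention backend does not accept an attn_mask, so an additive ALiBi bias would push SDPA onto a slower backend, and Falcon keeps a hand-written path instead. Purely as a sketch (my own code, illustrative shapes), the bias could be expressed as a float mask for the non-flash backends:

import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 1, 4, 16, 32
q = torch.randn(bsz, heads, q_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Standard ALiBi slopes: 2^(-8i/n) for head i = 1..n (power-of-two head counts).
slopes = 2.0 ** (-8.0 * torch.arange(1, heads + 1) / heads)
pos = torch.arange(q_len)
rel = (pos.view(1, -1) - pos.view(-1, 1)).float()  # entry (i, j) holds j - i
bias = slopes.view(1, heads, 1, 1) * rel.view(1, 1, q_len, q_len)
causal = torch.full((q_len, q_len), float("-inf")).triu(1)

# A float attn_mask is simply added to the attention scores before softmax;
# this keeps SDPA off the flash backend, hence Falcon's manual fallback.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias + causal)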

Reference links:

  1. xformers package: https://github.com/facebookresearch/xformers
  2. GitHub repo for the paper: https://github.com/HazyResearch/flash-attention
  3. x-transformers package: https://github.com/lucidrains/x-transformers
  4. chatglm-v2-6b model code: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py
  5. falcon-7b model code: https://huggingface.co/tiiuae/falcon-7b/blob/main/modelling_RW.py
  6. PyTorch documentation for scaled_dot_product_attention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

More

I love reading the transformers source code. If you're interested in NLP and the internals of the transformers package, feel free to follow me~