How to use the flash attention algorithm in practice: a look at how transformers does it [transformers source code reading]
- `flash attention` is a blazingly fast, memory-saving attention algorithm; most readers will have heard of it by now.
- This article covers how to actually use `flash attention` inside a model, i.e. the engineering side of it.
- The paper https://arxiv.org/abs/2205.14135 proposes an IO-aware exact attention algorithm.
- Transformers keep getting larger and deeper, yet they remain slow and memory-hungry on long sequences (self-attention time and memory scale quadratically with sequence length).
- Existing approximate-attention methods try to address this by trading model quality for lower computational complexity.
- They have a clear limitation, though: they often fail to deliver actual wall-clock training speedups.
- The authors argue that attention algorithms should be IO-aware, i.e. account for reads and writes between levels of GPU memory, such as the large but slow HBM (High Bandwidth Memory) and the small but fast SRAM.
- Against this backdrop they propose FlashAttention, built on two techniques: tiling (computing attention incrementally, block by block) and recomputing attention in the backward pass, with all attention operations fused into a single CUDA kernel. A minimal sketch of the tiling idea follows this list.
- They evaluate FlashAttention's impact on Transformer training, covering training time, model accuracy, attention runtime, and memory efficiency; the results are very strong.
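To make the tiling idea concrete, below is a minimal, unoptimized PyTorch sketch of blockwise attention with an online softmax. It is my own illustration of the principle, not the fused CUDA kernel from the paper: it skips masking and dropout, and the tensor shapes and block size are arbitrary assumptions.

```python
import torch


def blockwise_attention(q, k, v, block_size=128):
    # q, k, v: (batch, heads, seq_len, head_dim); no masking / dropout, for brevity.
    scale = q.shape[-1] ** -0.5
    q = q * scale
    batch, heads, seq_len, _ = q.shape

    out = torch.zeros_like(q)                                        # running weighted sum of values
    row_max = q.new_full((batch, heads, seq_len, 1), float("-inf"))  # running max of scores seen so far
    row_sum = q.new_zeros((batch, heads, seq_len, 1))                # running softmax denominator

    # Walk over the keys/values block by block; the full (seq_len x seq_len)
    # attention matrix is never materialized.
    for start in range(0, seq_len, block_size):
        k_blk = k[:, :, start:start + block_size]
        v_blk = v[:, :, start:start + block_size]
        scores = q @ k_blk.transpose(-1, -2)                         # (batch, heads, seq_len, block)

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)                    # rescales the old accumulators
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum


# Sanity check against naive full-matrix attention.
q, k, v = (torch.randn(2, 4, 512, 64) for _ in range(3))
ref = torch.softmax((q @ k.transpose(-1, -2)) * 64 ** -0.5, dim=-1) @ v
print((blockwise_attention(q, k, v) - ref).abs().max())  # tiny, float rounding only
```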
For more on `flash attention` itself, please read the original paper; I will not restate it here. What I want to share is: how to use flash attention in a model.
A few days ago, in the `open_llama` model code inside the `transformers` package, I noticed an implementation based on `memory_efficient_attention`.
```python
try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warn(
        "Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


past_key_value = (key_states, value_states) if use_cache else None

if self.config.use_memorry_efficient_attention and xops is not None and self.training:
    attn_weights = None
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    attn_output = xops.memory_efficient_attention(
        query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
    )
else:
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2)
```
- The code above lives at https://github.com/huggingface/transformers/blob/c2c99dc7ef5edab8f7674a1eb00cf6ac6996fd0f/src/transformers/models/open_llama/modeling_open_llama.py#L234 (click the link to see it in context); it sits inside the `forward` of `OpenLlamaAttention` in `open_llama`.
- It uses the `memory_efficient_attention` function provided by the `xformers` package.
- Note that the `use_memorry_efficient_attention` mode only kicks in during training. It does not change the model architecture, let alone the model weights.
- Also, it only consumes the `q`, `k`, `v` tensors; `attn_mask` is not used (causality comes from `xops.LowerTriangularMask()` instead). A standalone usage sketch follows this list.
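For reference, here is what the call looks like outside the model. This is a sketch, not code from the transformers source: the tensor sizes are made up, and the (batch, seq_len, num_heads, head_dim) layout is simply what the transposes in the snippet above imply; it assumes a CUDA device with `xformers` installed.

```python
import torch
from xformers import ops as xops

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),  # causal masking, as in the open_llama code
    p=0.1,                                 # attention dropout probability
)
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```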
The code in the `transformers` package leans more towards engineering; the code in the `x-transformers` package leans more towards the algorithm. And judging from the goofy GitHub avatar of the `x-transformers` author, you can tell how laid-back he is~

In `x-transformers`, he has also implemented a version of this:
```python
from x_transformers.attend import Attend, Intermediates

# attend class - includes core attention algorithm + talking heads
self.attend = Attend(
    heads = heads,
    causal = causal,
    talking_heads = talking_heads,
    dropout = dropout,
    qk_norm = qk_norm,
    scale = qk_norm_scale if qk_norm else self.scale,
    flash = flash
)

# attention is all we need
out, intermediates = self.attend(
    q, k, v,
    mask = final_attn_mask,
    attn_bias = attn_bias,
    prev_attn = prev_attn
)

# https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
if exists(r):
    out = out * r + out
```
- The full usage is in the `x_transformers.py` file; the full implementation is in the `attend.py` file.
- The code reads clean and soothing (gentle enough even for sensitive skin), but I did not see any part of the algorithm rewritten in CUDA, so I cannot say how fast it actually is. A hedged sketch of calling `Attend` directly follows this list.
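Here is a sketch of using the `Attend` module directly, inferred only from the snippet above. Assumptions: `q`/`k`/`v` are laid out as (batch, heads, seq_len, dim_head), and the constructor arguments not shown can keep their defaults; check `attend.py` to confirm before relying on this.

```python
import torch
from x_transformers.attend import Attend

attend = Attend(
    heads=8,
    causal=True,   # decoder-style lower-triangular masking
    dropout=0.1,
    flash=True,    # route through the fast fused attention path when available
)

q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# Same call shape as `out, intermediates = self.attend(...)` in the snippet above,
# just without the optional mask / attn_bias / prev_attn arguments.
out, intermediates = attend(q, k, v)
print(out.shape)  # expected: torch.Size([2, 8, 1024, 64])
```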
- In the latest chatglm-v2-6b code, and likewise in the falcon-7b code, the speedup is built on the `torch.nn.functional.scaled_dot_product_attention` function.
- `scaled_dot_product_attention` is already built into PyTorch 2.0, so it can be used directly with no extra setup; it is very simple to use (a minimal standalone example follows the two snippets below).
In chatglm2-6b's modeling_chatglm.py (the `CoreAttention.forward` method):

```python
def forward(self, query_layer, key_layer, value_layer, attention_mask):
    pytorch_major_version = int(torch.__version__.split('.')[0])
    if pytorch_major_version >= 2:
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             is_causal=True)
        else:
            if attention_mask is not None:
                attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask)
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
    else:
        # Raw attention scores
        # [b, np, sq, sk]
        ...  # (manual attention path continues in the original file)
```

And in falcon-7b's modelling_RW.py, the `alibi is None` fast path likewise goes through `F.scaled_dot_product_attention`:

```python
if alibi is None:
    query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
    key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
    value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

    attn_output = F.scaled_dot_product_attention(
        query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
    )

    x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
    x = x.permute(0, 2, 1, 3)
    attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

    output_tensor = self.dense(attn_output)

    outputs = (output_tensor, present)
    assert not output_attentions  # not supported.
    return outputs
else:
    attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
    matmul_result = query_layer @ key_layer.transpose(-1, -2)

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
```
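As promised above, a minimal standalone example of calling the built-in directly. The shapes follow the documented (batch, heads, seq_len, head_dim) convention; the concrete sizes are made up. On a suitable GPU with fp16/bf16 inputs, PyTorch dispatches to a fused FlashAttention-style kernel automatically; otherwise it falls back to the ordinary math implementation.

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# Causal (decoder-style) attention: no explicit mask has to be built by hand.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Or pass an explicit boolean mask (True = this position may be attended to), e.g. for padding.
attn_mask = torch.ones(2, 1, 1024, 1024, dtype=torch.bool)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```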
- `xformers` package: https://github.com/facebookresearch/xformers
- GitHub repo for the paper: https://github.com/HazyResearch/flash-attention
- `x-transformers` package: https://github.com/lucidrains/x-transformers
- `chatglm-v2-6b` model code: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py
- `falcon-7b` model code: https://huggingface.co/tiiuae/falcon-7b/blob/main/modelling_RW.py
- PyTorch docs for `scaled_dot_product_attention`: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
I enjoy reading the `transformers` source code. If you are interested in NLP and the internals of the `transformers` package, feel free to follow me~