-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathattention.py
executable file
·198 lines (167 loc) · 9.68 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union
import transformer_engine as te
import torch
from rotary import *
from enums import AttnMaskType
class CausalSelfAttention(nn.Module):
    """Grouped-query self-attention on a doubled-width input, using Transformer Engine kernels.

    ``linear_qkv`` and ``linear_proj`` both take ``2 * config.hidden_size``
    input features, every attention head is ``2 * kv_channels`` wide, and the
    attention output is projected back to ``config.hidden_size``.

    When ``config.use_shared_attention_lora`` is set, ``num_gs`` independent
    low-rank (LoRA) adapter sets are created for the query, key and value
    projections and ``forward`` selects one of them with ``forward_layer_idx``.
    The same index also selects the KV-cache slot during inference, so one
    instance can presumably be reused at several call sites — confirm.

    NOTE(review): ``nn`` is not imported explicitly in this file; it
    presumably arrives via ``from rotary import *`` — confirm.
    """

    def __init__(self, config, layer_number, num_gs, attn_mask_type=AttnMaskType.padding, **kwargs):
        """Build the projections, attention kernels and optional LoRA adapters.

        Args:
            config: Model config. Reads ``hidden_size``, ``num_mem_heads``,
                ``add_bias_linear``, ``kv_channels``, ``num_attention_heads``,
                ``num_query_groups``, ``use_shared_attention_lora`` and, when
                LoRA is enabled, ``lora_rank``.
            layer_number: Forwarded to both ``DotProductAttention`` modules.
            num_gs: Number of LoRA adapter sets to create; ``forward`` indexes
                them with ``forward_layer_idx``.
            attn_mask_type: Accepted for interface compatibility but unused;
                the kernels are fixed to "causal" (full-sequence / prefill)
                and "no_mask" (decoding steps).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # NOTE(review): num_mem_heads is only used for this check and for
        # self.n_head, which nothing in this class reads.
        assert config.hidden_size % config.num_mem_heads == 0
        self.config = config
        self.linear_qkv = nn.Linear(
            2 * config.hidden_size, 6 * config.hidden_size, bias=config.add_bias_linear
        )
        self.linear_proj = nn.Linear(
            2 * config.hidden_size, config.hidden_size, bias=config.add_bias_linear
        )
        # Not read inside this class; kept for external users.
        self.n_head = config.num_mem_heads
        self.n_embd = config.hidden_size * 2
        self.query_projection_size = config.kv_channels * config.num_attention_heads
        self.kv_projection_size = config.kv_channels * config.num_query_groups
        # No tensor parallelism: every "per partition" size equals the global size.
        self.world_size = 1
        self.num_gs = num_gs
        self.hidden_size_per_attention_head = (
            self.query_projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.num_query_groups_per_partition = config.num_query_groups

        # Causal kernel for full-sequence / first inference steps; unmasked
        # kernel for later decoding steps, whose queries attend to every
        # cached key.
        self.dpa = self._make_dot_product_attention(layer_number, "causal")
        self.dpa_generation = self._make_dot_product_attention(layer_number, "no_mask")

        if config.use_shared_attention_lora:
            # nn.ModuleList is the container for submodules (ParameterList is
            # meant for bare nn.Parameter objects). State-dict keys are
            # "<list>.<i>.weight" either way, so checkpoints are unaffected.
            self.linear_q_lora_A_list = nn.ModuleList()
            self.linear_q_lora_B_list = nn.ModuleList()
            self.linear_k_lora_A_list = nn.ModuleList()
            self.linear_k_lora_B_list = nn.ModuleList()
            self.linear_v_lora_A_list = nn.ModuleList()
            self.linear_v_lora_B_list = nn.ModuleList()

            in_features = 2 * config.hidden_size
            adapter_specs = (
                (self.linear_q_lora_A_list, self.linear_q_lora_B_list, 2 * self.query_projection_size),
                (self.linear_k_lora_A_list, self.linear_k_lora_B_list, 2 * self.kv_projection_size),
                (self.linear_v_lora_A_list, self.linear_v_lora_B_list, 2 * self.kv_projection_size),
            )
            for _ in range(num_gs):
                # Creation order (q, k, v; A before B) fixes the RNG draw
                # order, and therefore the seeded initialisation.
                for lora_a_list, lora_b_list, out_features in adapter_specs:
                    lora_a_list.append(nn.Linear(in_features, config.lora_rank, bias=False))
                    lora_b_list.append(nn.Linear(config.lora_rank, out_features, bias=False))

    def _make_dot_product_attention(self, layer_number, attn_mask_type):
        """Create a Transformer Engine ``DotProductAttention`` with this layer's settings."""
        # NOTE(review): kv_channels is config.kv_channels, while the tensors
        # passed in ``forward`` have head dim 2 * hidden_size_per_attention_head
        # (= 2 * kv_channels). TE presumably derives its default softmax scale
        # from kv_channels — confirm this is intended.
        return te.pytorch.DotProductAttention(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.kv_channels,
            attention_dropout=0.0,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
        )

    @staticmethod
    def _lora_delta(lora_a_list, lora_b_list, index, hidden_states, shape):
        """Return ``B(A(hidden_states))`` from adapter set ``index``, viewed as ``shape``."""
        lora_a = lora_a_list[index]
        lora_b = lora_b_list[index]
        return lora_b(lora_a(hidden_states)).view(shape)

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate an uninitialised KV-cache buffer on the current CUDA device.

        The buffer has shape (inference_max_sequence_length, batch_size,
        num_query_groups_per_partition, 2 * hidden_size_per_attention_head),
        matching the per-group key/value tensors produced in ``forward``.
        """
        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head * 2,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb, layer_number):
        """Write this step's keys/values into the KV cache and return the cached history.

        Args:
            inference_params: Object exposing ``key_value_memory_dict``,
                ``max_sequence_length``, ``max_batch_size``,
                ``batch_size_offset`` and ``sequence_len_offset``; ``None``
                outside inference, in which case the inputs are returned as-is.
            key, value: Tensors for the current step, indexed as
                (sequence, batch, ...).
            rotary_pos_emb: ``None`` or a ``(q_pos_emb, k_pos_emb)`` pair,
                sliced along dim 0. Presumably tables covering at least
                ``sequence_end`` positions — confirm against the caller.
            layer_number: Key into ``key_value_memory_dict``; cache buffers
                are allocated the first time a key is seen.

        Returns:
            ``(key, value, rotary_pos_emb)``: key/value cover cached positions
            ``[0, sequence_end)`` for this batch slice. The query embedding
            covers only the positions written this step,
            ``[sequence_start, sequence_end)``. The key embedding covers
            ``[0, sequence_end)``.
        """
        if inference_params is None:
            return key, value, rotary_pos_emb

        memory = inference_params.key_value_memory_dict
        if layer_number not in memory:
            memory[layer_number] = (
                self._allocate_memory(
                    inference_params.max_sequence_length, inference_params.max_batch_size, key.dtype
                ),
                self._allocate_memory(
                    inference_params.max_sequence_length, inference_params.max_batch_size, value.dtype
                ),
            )
        inference_key_memory, inference_value_memory = memory[layer_number]

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)

        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # Queries exist only for the positions written in this step.
            # Taking just the last position would broadcast one embedding over
            # a multi-token step, and [:sequence_end] is only right when
            # sequence_start == 0.
            q_pos_emb = q_pos_emb[sequence_start:sequence_end]
            k_pos_emb = k_pos_emb[:sequence_end]
            rotary_pos_emb = (q_pos_emb, k_pos_emb)
        return key, value, rotary_pos_emb

    def forward(self, hidden_states, attention_mask, key_value_states=None, inference_params=None, rotary_pos_emb=None, forward_layer_idx=None):
        """Run attention on ``hidden_states`` and project back to ``hidden_size``.

        Args:
            hidden_states: Input whose last dim is ``2 * config.hidden_size``.
                Its first two dims are treated as (sequence, batch), which is
                presumably TE's default "sbhd" layout — confirm.
            attention_mask: Unused; the kernel is chosen by step type instead.
            key_value_states: Unused.
            inference_params: KV-cache state, or ``None`` for a plain
                full-sequence forward.
            rotary_pos_emb: ``None``, a single table applied to both q and k,
                or a ``(q_pos_emb, k_pos_emb)`` pair.
            forward_layer_idx: Selects the LoRA adapter set and the KV-cache
                slot. Required when LoRA is enabled.

        Returns:
            ``linear_proj`` applied to the attention output.

        Raises:
            TypeError: LoRA is enabled and ``forward_layer_idx`` is ``None``.
        """
        use_lora = self.config.use_shared_attention_lora
        if use_lora and forward_layer_idx is None:
            raise TypeError("forward_layer_idx is required when use_shared_attention_lora is enabled")

        head_dim = self.hidden_size_per_attention_head * 2
        heads_per_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition

        qkv_out = self.linear_qkv(hidden_states)
        # Regroup the fused projection per query group: each group holds its
        # query heads followed by one key head and one value head.
        new_tensor_shape = qkv_out.size()[:-1] + (
            self.num_query_groups_per_partition,
            (heads_per_group + 2) * head_dim,
        )
        qkv_out = qkv_out.view(*new_tensor_shape)
        query, key, value = torch.split(
            qkv_out, [heads_per_group * head_dim, head_dim, head_dim], dim=3
        )

        if use_lora:
            lora_shape = new_tensor_shape[:-1] + (-1,)
            query = query + self._lora_delta(
                self.linear_q_lora_A_list, self.linear_q_lora_B_list,
                forward_layer_idx, hidden_states, lora_shape,
            )
            key = key + self._lora_delta(
                self.linear_k_lora_A_list, self.linear_k_lora_B_list,
                forward_layer_idx, hidden_states, lora_shape,
            )
            value = value + self._lora_delta(
                self.linear_v_lora_A_list, self.linear_v_lora_B_list,
                forward_layer_idx, hidden_states, lora_shape,
            )

        # Flatten the groups so the query has one slot per attention head.
        query = query.reshape(query.size(0), query.size(1), -1, head_dim)

        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        key, value, rotary_pos_emb = self._adjust_key_value_for_inference(
            inference_params, key, value, rotary_pos_emb, forward_layer_idx
        )

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)

        is_first_pass = inference_params is None or inference_params.sequence_len_offset == 0
        attention = self.dpa if is_first_pass else self.dpa_generation
        y = attention(query, key, value)
        return self.linear_proj(y)