-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetala_qv_selfaug.py
286 lines (238 loc) · 10.2 KB
/
metala_qv_selfaug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
from megatron import mpu
from megatron.model.activations import get_activation
from einops import rearrange
import triton
import triton.language as tl
from fla.ops.gla import fused_chunk_gla, chunk_gla, fused_recurrent_gla
from megatron.model.norms import LayerNorm, get_norm
from causal_conv1d import causal_conv1d_fn
import einops
class LLaMAParallelMLP(nn.Module):
    """LLaMA-style gated feed-forward block.

    Computes ``w2(act(w1(x)) * w3(x))``, where ``act`` is obtained from
    ``get_activation(neox_args)``. ``w1`` and ``w3`` are bias-free
    column-parallel projections from ``hidden_size`` to an intermediate
    width; ``w2`` is a bias-free row-parallel projection back to
    ``hidden_size``. No dropout is applied inside this module.

    The intermediate width is ``neox_args.intermediate_size`` when that is
    set; otherwise it is ``int(2 * hidden_size * 4 / 3)`` rounded up to the
    next multiple of ``multiple_of``.
    """

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        parallel_output=False,
        multiple_of=256,
        MOE=False,
        MoE_mp_size=1,
    ):
        super().__init__()
        self.activation_func = get_activation(neox_args)
        self.activation_type = neox_args.activation
        self.multiple_of = multiple_of

        # An explicit intermediate size (e.g. for Mistral) is used verbatim.
        ff_dim = neox_args.intermediate_size
        if ff_dim is None:
            ff_dim = int(2 * neox_args.hidden_size * 4 / 3)
            # Round up to the nearest multiple of ``multiple_of``.
            ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

        # w1 (gate branch) and w3 (linear branch) share an identical configuration.
        up_projection_kwargs = dict(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            bias=False,
            MOE=MOE,
            MoE_mp_size=MoE_mp_size,
        )
        self.w1 = mpu.ColumnParallelLinear(**up_projection_kwargs)
        self.w3 = mpu.ColumnParallelLinear(**up_projection_kwargs)

        # Down projection consumes the still-partitioned intermediate activations.
        self.w2 = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
            bias=False,
            MOE=MOE,
            MoE_mp_size=MoE_mp_size,
        )

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states``.

        Returns whatever ``w2`` returns. ``w2`` is built with
        ``skip_bias_add=True``, so this is presumably an ``(output, bias)``
        pair (TODO confirm against ``mpu.RowParallelLinear``).
        """
        gate_branch, _ = self.w1(hidden_states)
        linear_branch, _ = self.w3(hidden_states)
        gated = self.activation_func(gate_branch) * linear_branch
        return self.w2(gated)
class ParallelMetaLA_Attention_selfaug(nn.Module):
    """MetaLA linear-attention mixer with short convolution and self-augmentation.

    ``forward`` applies, in order:

    1. a depthwise causal convolution with SiLU (``causal_conv1d_fn``);
    2. query / key-gate / value / output-gate projections;
    3. gated linear attention plus a self-augmentation term
       (``meta_linear_attention``);
    4. a SiLU output gate;
    5. the output projection ``out_proj``.

    Queries and key gates are ``hidden_size // 2`` wide; values and the
    output gate are ``hidden_size`` wide. Every column projection uses
    ``gather_output=True`` and ``out_proj`` uses ``input_is_parallel=False``,
    so this module operates on full-width activations.
    """

    def __init__(self, neox_args, init_method, output_layer_init_method,):
        super().__init__()
        self.embed_dim = neox_args.hidden_size
        self.num_heads = neox_args.num_attention_heads
        self.gate_fn = nn.functional.silu

        self.q_proj = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.embed_dim,
            output_size=self.embed_dim // 2,
            bias=False,
            gather_output=True,
            init_method=init_method,
            skip_bias_add=True,
        )
        # Gate logits; keys are derived from this gate inside
        # ``meta_linear_attention`` (there is no separate key projection).
        self.k_gate = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.embed_dim,
            output_size=self.embed_dim // 2,
            bias=False,
            gather_output=True,
            init_method=init_method,
            skip_bias_add=True,
        )
        self.v_proj = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            gather_output=True,
            init_method=init_method,
            skip_bias_add=True,
            bias=False,
        )
        # Output-gate projection; the only projection here that has a bias.
        self.g_proj = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            gather_output=True,
            init_method=init_method,
            skip_bias_add=False,
            bias=True,
        )
        self.out_proj = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            input_is_parallel=False,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            bias=False,
            parallel_output=False,
        )
        self.head_dim = self.embed_dim // self.num_heads
        # Per-head normalisation of the attention output, without affine parameters.
        self.group_norm = LayerNorm(self.head_dim, eps=1e-5, elementwise_affine=False)
        # Per-channel weight of the self-augmentation term, initialised to zero.
        self.aug_balance = nn.Parameter(torch.zeros(self.embed_dim // 2))
        self.d_conv = 4
        # Depthwise conv used only as a parameter holder: ``forward`` passes its
        # weight/bias to ``causal_conv1d_fn`` and never calls the module, so the
        # ``padding`` argument has no effect.
        self.conv1d = nn.Conv1d(
            in_channels=self.embed_dim,
            out_channels=self.embed_dim,
            bias=False,
            kernel_size=self.d_conv,
            groups=self.embed_dim,
            padding=self.d_conv - 1,
        )

    def forward(self, x, hidden_states=None):
        """Mix tokens with MetaLA linear attention.

        Args:
            x: input activations with dims 0 and 1 swapped relative to
                batch-first layout; presumably ``(seq_len, batch, hidden)``
                (TODO confirm). They are transposed to batch-first on entry.
            hidden_states: optional initial recurrent state, forwarded to the
                fla kernel as ``initial_state``.

        Returns:
            ``(output, new_hidden_states)``:
            - ``output`` is transposed back to the input's layout.
            - ``new_hidden_states`` is the kernel's final state.
        """
        x = x.transpose(0, 1).contiguous()

        # Short causal depthwise convolution over the sequence axis.
        x = rearrange(x, 'b l d -> b d l').contiguous()
        conv_bias = self.conv1d.bias
        if conv_bias is not None:
            # Cast the bias to the activation dtype. The module defines no
            # ``precision`` attribute, so it cannot be used for this cast.
            conv_bias = conv_bias.to(x.dtype)
        x = causal_conv1d_fn(
            x=x,
            weight=einops.rearrange(self.conv1d.weight, "d 1 w -> d w"),
            bias=conv_bias,
            activation="silu",
        )
        x = rearrange(x, 'b d l -> b l d').contiguous()

        q, _ = self.q_proj(x)
        k_gate, _ = self.k_gate(x)
        v, _ = self.v_proj(x)
        g, _ = self.g_proj(x)

        # ``k`` is a placeholder: meta_linear_attention derives keys from k_gate.
        output, new_hidden_states = self.meta_linear_attention(
            q, 1, v, k_gate, hidden_states=hidden_states
        )
        output = self.gate_fn(g) * output
        output, _ = self.out_proj(output)
        output = output.transpose(0, 1)
        return output, new_hidden_states

    def meta_linear_attention(self, q, k, v, gk, normalizer=16, hidden_states=None):
        """Gated linear attention with gate-derived keys plus self-augmentation.

        Args:
            q: queries, ``(b, l, h * dk)``.
            k: ignored; keys are recomputed as ``1 - exp(gk')`` (see below).
            v: values, ``(b, l, h * dv)``.
            gk: gate logits, ``(b, l, h * dk)``. They are mapped to the
                log-space gate ``gk' = logsigmoid(gk) / normalizer <= 0``.
            normalizer: divisor for ``logsigmoid(gk)``. Larger values push
                ``exp(gk')`` closer to 1.
            hidden_states: optional ``initial_state`` for the fla kernel.

        Returns:
            ``(o, new_hidden_states)``:
            - ``o`` is ``(b, l, h * dv)`` after per-head normalisation.
            - ``new_hidden_states`` is the kernel's final state.
        """
        gk = F.logsigmoid(gk) / normalizer
        # Keys are tied to the gate instead of coming from a key projection.
        k = 1 - torch.exp(gk)

        q, k, v, gk = (
            rearrange(t, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
            for t in (q, k, v, gk)
        )
        aug_balance = rearrange(self.aug_balance, '(h d) -> h d', h=self.num_heads).contiguous()

        # Chunked kernel while training, recurrent kernel otherwise.
        gla_kernel = fused_chunk_gla if self.training else fused_recurrent_gla
        o, new_hidden_states = gla_kernel(
            q, k, v, gk, initial_state=hidden_states, output_final_state=True
        )

        # Self-augmentation: add sigmoid((q . (aug_balance * k)) * v) at every
        # position and head.
        augk = torch.einsum('bhld,hd->bhld', k, aug_balance)
        aug_w = torch.einsum('bhld,bhld->bhl', q, augk)
        o = o + F.sigmoid(aug_w.unsqueeze(-1) * v)

        o = self.group_norm(o)
        o = rearrange(o, 'b h l d -> b l (h d)')
        return o, new_hidden_states
class ParallelMetaLALayer_selfaug(nn.Module):
    """Pre-norm residual block: MetaLA self-augmented attention, then an MLP.

    ``forward`` computes::

        h   = x + dropout(attention(input_layernorm(x)))
        out = h + mlp(post_attention_layernorm(h))

    Dropout is applied only to the attention branch. ``attention_mask`` and
    ``layer_past`` are accepted for interface compatibility but never read.
    """

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        layer_number,
        use_cache=False
    ):
        super().__init__()
        assert not use_cache, "[MetaLA]: use_cache conflicts with training mode!"
        self.neox_args = neox_args
        self.layer_number = layer_number

        norm_cls, norm_eps = get_norm(neox_args)
        width = neox_args.hidden_size
        self.input_layernorm = norm_cls(width, eps=norm_eps)
        self.post_attention_layernorm = norm_cls(width, eps=norm_eps)
        self.hidden_dropout = neox_args.hidden_dropout

        self.attention = ParallelMetaLA_Attention_selfaug(
            neox_args, init_method, output_layer_init_method
        )
        self.mlp = LLaMAParallelMLP(
            neox_args=neox_args,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
        )

    def forward(self, x, attention_mask, layer_past=None):
        """Run the block on ``x``.

        Returns:
            ``(output, moe_loss)``, where ``moe_loss`` is always a zero scalar
            tensor on ``x``'s device and dtype.
        """
        moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        attn_branch, _ = self.attention(self.input_layernorm(x))
        # NOTE(review): grad mode is forced on for the residual adds; the reason
        # is not visible from this file.
        with torch.enable_grad():
            hidden = F.dropout(
                attn_branch,
                p=self.hidden_dropout,
                training=self.training,
            ) + x

        mlp_branch, _ = self.mlp(self.post_attention_layernorm(hidden))
        with torch.enable_grad():
            output = mlp_branch + hidden
        return output, moe_loss
class ParallelMetaLALayer_selfaugPipe(ParallelMetaLALayer_selfaug):
    """Pipeline-stage wrapper that threads ``attention_mask`` through unchanged.

    Takes and returns a ``(hidden_states, attention_mask)`` pair. The parent's
    auxiliary MoE loss is stored on ``self.last_moe_loss``.
    """

    def forward(self, args):
        """Run the parent layer on ``args = (hidden_states, attention_mask)``."""
        assert (
            len(args) == 2
        ), "ParallelMetaLALayer_selfaugPipe expects 2 arguments - hidden_states and attention_mask"
        hidden_states, attention_mask = args
        output, self.last_moe_loss = super().forward(hidden_states, attention_mask)
        return output, attention_mask