-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathgpt2_model.py
414 lines (371 loc) · 16.2 KB
/
gpt2_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
# Copyright (c) 2024 EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import math
import torch
import torch.nn as nn
from collections import defaultdict
from functools import partial
from megatron.model.utils import Lambda, SequentialWrapper, recursive_setattr
from megatron.model.norms import get_norm
from megatron.model.init_functions import get_init_methods
from megatron import mpu
from megatron.mpu import ParallelRelativePositionBias
from megatron.model.transformer import (
ParallelTransformerLayerPipe,
NormPipe,
ParallelLinearPipe,
parallel_lm_logits,
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.rwkv.v6 import RWKVResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding
# Pipeline parallelism
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
from typing import Union, List
def gpt2_attention_mask_func(attention_scores, ltor_mask):
    """Overwrite masked attention positions with the dtype's most negative finite value.

    The fill happens in place via ``Tensor.masked_fill_``.

    :param attention_scores: floating-point tensor of attention scores; modified in place.
    :param ltor_mask: boolean mask (broadcastable to ``attention_scores``); positions where
        it is ``True`` are overwritten.
    :return: the same ``attention_scores`` tensor object, after filling.
    """
    scores_dtype = attention_scores.dtype
    scores_device = attention_scores.device
    # The fill value is materialised as a tensor that matches the scores' dtype and device.
    # A plain Python scalar raised `RuntimeError: expected scalar type float but found double`,
    # and a tensor on another device raised a device-mismatch RuntimeError.
    fill_value = torch.tensor(
        torch.finfo(scores_dtype).min, dtype=scores_dtype, device=scores_device
    )
    attention_scores.masked_fill_(ltor_mask, fill_value)
    return attention_scores
def cross_entropy(output, labels, _fp16=False):
    """Masked vocab-parallel cross-entropy, averaged over the unmasked tokens.

    Adapted from ``pretrain_gpt2:forward_step()``.

    :param output: logits, passed to ``mpu.vocab_parallel_cross_entropy``.
    :param labels: a pair ``(labels, loss_mask)``. ``labels`` are the target ids passed
        to ``mpu.vocab_parallel_cross_entropy``. ``loss_mask`` is flattened and used to
        weight each per-token loss.
    :param _fp16: if True, compute the loss directly on the half-precision logits.
        Both ``output`` and ``loss_mask`` must already be ``torch.half``. If False, the
        logits are upcast with ``.float()`` first.
    :return: scalar tensor ``sum(losses * loss_mask) / sum(loss_mask)``. If the mask
        sums to zero this is 0/0, i.e. NaN.
    """
    labels, loss_mask = labels[0], labels[1]
    if _fp16:
        assert output.dtype == torch.half and loss_mask.dtype == torch.half, (
            "fp16 cross entropy expects half-precision logits and loss mask"
        )
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
    else:
        # upcast to fp32 for a numerically safer softmax / log
        losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
    loss_mask = loss_mask.view(-1)
    return torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
def _pre_transformer_block(args):
# data format change for hidden_states to avoid explicit tranposes : [b s h] --> [s b h]
assert len(args) == 2, "Incorrect number of arguments to _pre_transformer_block"
fn = lambda _args: (_args[0].transpose(0, 1).contiguous(), *_args[1:])
return fn(args)
def _post_transformer_block(args):
# from (hidden_states, attention_mask)
# to (hidden_states.T)
assert len(args) == 2, "Incorrect number of arguments to _post_transformer_block"
fn = lambda _args: (_args[0].transpose(0, 1).contiguous())
return fn(args)
class GPT2ModelPipe(PipelineModule, torch.nn.Module):
    """GPT2Model adapted for pipeline parallelism.

    The largest change is flattening the GPTModel class so we can express it as a
    sequence of layers including embedding, transformer layers, and output.

    :param neox_args: NeoX arguments object (configuration)
    :param num_tokentypes: number of token types (TODO: deprecated, remove)
    :param parallel_output: if true, don't gather the output logits, and calculate loss in parallel. Set to true by default in training for efficiency, but set to false for inference.
    :param topology: deepspeed topology object specifying pipe / model parallelism topology.
    :param use_cache: if true, cache key/value pairs for each layer in inference.
    """

    # Class names of layers that may be activation-checkpointed. Both __init__ and
    # insert_layers build the pipeline from this one tuple, so the initial build and any
    # rebuild agree. A fresh list is passed each time.
    _CHECKPOINTABLE_LAYERS = (
        "GMLPBlock",
        "ParallelTransformerLayerPipe",
        "ParallelMambaResidualLayerPipe",
        "RWKVResidualLayerPipe",
    )

    def __init__(
        self,
        neox_args,
        num_tokentypes=0,
        parallel_output=True,
        topology=None,
        use_cache=False,
    ):
        self.neox_args = neox_args

        self.use_cache = use_cache
        self.parallel_output = parallel_output
        self.hidden_size = self.neox_args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method, self.output_layer_init_method = get_init_methods(
            self.neox_args
        )
        # Kept so insert_layers() can re-run PipelineModule.__init__ with the same topology.
        self.__topology__ = topology

        self.specs = []
        self.init_specs()  # initializes the layer specs (basically a fancy nn.Sequential)

        super().__init__(
            layers=self.specs,
            loss_fn=partial(cross_entropy, _fp16=self.neox_args.fp16_lm_cross_entropy),
            topology=topology,
            activation_checkpoint_interval=self.neox_args.checkpoint_num_layers
            if self.neox_args.checkpoint_activations
            else 0,
            partition_method=neox_args.pipe_partition_method,
            checkpointable_layers=list(self._CHECKPOINTABLE_LAYERS),
        )

    def insert_layers(
        self, layers: Union[nn.Module, nn.ModuleList, nn.Sequential, List], idx
    ):
        """Insert ``layers`` into the pipe model at ``idx``, then rebuild the pipeline.

        The argument is handled by type:

        - An ``nn.ModuleList`` or ``nn.Sequential`` is spliced in element by element.
        - Any other ``nn.Module`` is inserted as a single layer.
        - A plain list of callables is spliced in element by element.

        :raises AssertionError: if a list contains a non-callable item.
        :raises ValueError: for any other argument type.
        """
        # ModuleList / Sequential are nn.Module subclasses, so they must be tested before
        # the generic nn.Module case, otherwise this splice branch is never reached.
        if isinstance(layers, (nn.ModuleList, nn.Sequential)):
            self.specs[idx:idx] = layers
        elif isinstance(layers, nn.Module):
            self.specs.insert(idx, layers)
        elif isinstance(layers, list):
            assert all(
                callable(layer) for layer in layers
            ), "all items in `layers` must be Callables"
            self.specs[idx:idx] = layers
        else:
            raise ValueError(
                f"layer passed into {self.__class__.__name__}.insert_layers() should be either an nn.Module, an nn.ModuleList, an nn.Sequential object, or a list of callables not a {type(layers)}"
            )

        # re-initialize parent class so the pipeline is rebuilt from the updated specs
        super().__init__(
            layers=self.specs,
            loss_fn=self.loss_fn,
            topology=self.__topology__,
            activation_checkpoint_interval=self.activation_checkpoint_interval,
            partition_method=self.neox_args.pipe_partition_method,
            checkpointable_layers=list(self._CHECKPOINTABLE_LAYERS),
        )

    def init_specs(self):
        """Rebuild ``self.specs``, the ordered list of layer specs for the full model.

        The order is:

        1. embedding;
        2. ``_pre_transformer_block``;
        3. one block per layer, chosen from ``neox_args.attention_config[i]``;
        4. ``_post_transformer_block``;
        5. final norm;
        6. output head.

        The output head is tied to the input embedding unless ``neox_args.no_weight_tying``
        is set.
        """
        weight_tying = not self.neox_args.no_weight_tying
        self.specs = []

        # Positional args shared by every EmbeddingPipe spec: the input embedding and, when
        # weights are tied, the output head.
        embedding_args = (
            self.neox_args,
            self.hidden_size,
            self.neox_args.padded_vocab_size,
            self.neox_args.max_position_embeddings,
            self.neox_args.hidden_dropout,
            self.init_method,
            self.num_tokentypes,
        )

        # Embedding layer
        # input will be (input_ids, position_ids, attention_mask)
        if weight_tying:
            self.specs.append(
                TiedLayerSpec(
                    "embed",
                    EmbeddingPipe,
                    *embedding_args,
                    tied_weight_attr="word_embeddings_weight",
                )
            )
        else:
            self.specs.append(LayerSpec(EmbeddingPipe, *embedding_args))

        # NB: the attention mask always needs to be the *last* item in the args when being passed from
        # one stage to the next, because deepspeed is hacks on top of hacks.
        #
        # outputs are now (hidden_states, attention_mask)

        self.specs.append(_pre_transformer_block)

        # T5 RPE positional embedding. When enabled, a single module instance is passed to
        # every transformer layer below.
        rpe_emb = None
        if self.neox_args.pos_emb == "rpe":
            hidden_size_per_attention_head = mpu.divide(
                self.neox_args.hidden_size, self.neox_args.num_attention_heads
            )
            rpe_scale = math.sqrt(hidden_size_per_attention_head)
            rpe_emb = ParallelRelativePositionBias(
                neox_args=self.neox_args,
                scale=rpe_scale,
                causal=True,
                num_buckets=self.neox_args.rpe_num_buckets,
                max_distance=self.neox_args.rpe_max_distance,
                heads=self.neox_args.num_attention_heads,
            )

        # Transformer layers
        for i in range(self.neox_args.num_layers):
            layer_type = self.neox_args.attention_config[i]
            if layer_type in ["gmlp", "amlp"]:
                self.specs.append(
                    LayerSpec(
                        GMLPBlock,
                        init_method=self.init_method,
                        layer_number=i,
                        output_layer_init_method=self.output_layer_init_method,
                        neox_args=self.neox_args,
                        mask_fn=gpt2_attention_mask_func,
                    )
                )
            elif layer_type == "rwkv":
                self.specs.append(
                    LayerSpec(
                        RWKVResidualLayerPipe,
                        neox_args=self.neox_args,
                        init_method=self.init_method,
                        layer_number=i,
                    )
                )
            elif layer_type in ["mamba"]:
                self.specs.append(
                    LayerSpec(
                        ParallelMambaResidualLayerPipe,
                        neox_args=self.neox_args,
                        init_method=self.init_method,
                        output_layer_init_method=self.output_layer_init_method,
                        layer_number=i,
                    )
                )
            else:
                self.specs.append(
                    LayerSpec(
                        ParallelTransformerLayerPipe,
                        neox_args=self.neox_args,
                        attention_mask_func=gpt2_attention_mask_func,
                        init_method=self.init_method,
                        output_layer_init_method=self.output_layer_init_method,
                        layer_number=i,
                        rpe=rpe_emb,
                        rotary=self.neox_args.pos_emb == "rotary",
                        use_cache=self.use_cache,
                    )
                )

        # used to drop attention mask + reshape hidden states
        self.specs.append(_post_transformer_block)

        # NormPipe is a (deprecated) helper class that used to be used to pass presents along the pipeline - since presents are now cached to the `TransformerLayer` class this is no longer needed
        norm, eps = get_norm(self.neox_args)
        self.specs.append(
            LayerSpec(NormPipe, norm, self.neox_args.hidden_size, eps=eps)
        )

        # outputs are now a single tensor: hidden_states

        def _logits_helper(embedding, lm_output):
            """Just a wrapper to massage inputs/outputs from pipeline."""
            if self.neox_args.use_mup:
                # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout.
                # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274
                lm_output = (
                    lm_output
                    / self.tied_modules.embed.word_embeddings.weight.infshape.width_mult()
                )

            logits = parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output,
                seq_parallel=self.neox_args.sequence_parallel,
            )
            return logits

        if weight_tying:
            self.specs.append(
                TiedLayerSpec(
                    "embed",
                    EmbeddingPipe,
                    *embedding_args,
                    forward_fn=_logits_helper,
                    tied_weight_attr="word_embeddings_weight",
                )
            )
        else:
            self.specs.append(
                LayerSpec(
                    ParallelLinearPipe,
                    neox_args=self.neox_args,
                    init_method=self.init_method,
                    parallel_output=self.parallel_output,
                    is_last_layer=True,
                )
            )

    def _set_parallel_output(self, value):
        """Set ``parallel_output`` on this stage's final layer.

        This only acts when the final layer is a ``ParallelLinearPipe`` / ``ParallelLinear``.
        Otherwise it is a no-op.
        """
        # NOTE(review): with weight tying the output head is the tied EmbeddingPipe, whose
        # _logits_helper reads self.parallel_output. That attribute is not updated here, so
        # inference_mode()/train_mode() do not toggle logit gathering for tied models.
        # Confirm whether callers gather tied-model logits themselves before changing this.
        final_layer = list(self.forward_funcs)[-1]
        if isinstance(final_layer, (ParallelLinearPipe, ParallelLinear)):
            final_layer.final_linear.set_parallel_output(value)

    def inference_mode(self, use_cache=True):
        """
        Sets up the model for inference. It turns on k/v caching (if specified) and sets
        `parallel output` of the final layer to false, so logits are gathered across model
        parallel ranks.

        :param use_cache: (bool) True if you want to use caching during inference, False otherwise
        """
        # first set caching to true if specified
        recursive_setattr(self.forward_funcs, "use_cache", use_cache, assert_type=bool)
        # then set parallel output of the final layer to false so we don't have to gather the output manually
        self._set_parallel_output(False)
        recursive_setattr(self.forward_funcs, "training", False)

    def train_mode(self):
        """
        Sets up the model for training. It turns off k/v caching and sets `parallel output`
        of the final layer to True, so logits are not gathered across model parallel ranks
        and loss is computed in parallel (more efficient).
        """
        # set caching to false
        recursive_setattr(self.forward_funcs, "use_cache", False)
        # then set parallel output to true (more efficient training)
        self._set_parallel_output(True)
        recursive_setattr(self.forward_funcs, "training", True)

    def clear_cache(self):
        """
        Recursively clears the kv cache on all layers (sets every `layer_past` to None).
        """
        recursive_setattr(self.forward_funcs, "layer_past", None)

    def to_sequential(self):
        """Build a plain, non-pipelined ``SequentialWrapper`` from ``self.specs``.

        Specs are handled as follows:

        - The first ``TiedLayerSpec`` for a key builds the shared ("owner") module.
        - Later specs with the same key reuse that module through their ``forward_fn``.
        - Other ``LayerSpec``s are built.
        - Plain callables are wrapped in ``Lambda``.

        :raises ValueError: if a spec is none of the above.
        :return: the ``SequentialWrapper`` model.
        """
        layers = []
        tied_layers = defaultdict(list)
        for n, spec in enumerate(self.specs):
            if isinstance(spec, TiedLayerSpec):
                if spec.key in tied_layers:
                    # receiver: bind forward_fn and the owner module now. A bare lambda
                    # would read the loop variable `spec` at call time and see the last
                    # spec of the loop, not this one.
                    owner = tied_layers[spec.key][0]
                    if spec.forward_fn is None:
                        # No custom forward: call the shared module directly. This is
                        # believed to match DeepSpeed's own handling of tied specs; verify
                        # against the installed version.
                        layers.append(owner)
                    else:
                        layers.append(Lambda(partial(spec.forward_fn, owner)))
                else:
                    # owner
                    module = spec.build(log=False)
                    layers.append(module)
                    tied_layers[spec.key].append(module)
            elif isinstance(spec, LayerSpec):
                layers.append(spec.build(log=False))
            elif callable(spec):
                # a plain callable function
                layers.append(Lambda(spec))
            else:
                raise ValueError(f"Layer number {n} ({spec}) Not recognized")
        model = SequentialWrapper(
            layers,
            self.activation_checkpoint_interval,
            self.activation_checkpoint_func,
            parent_class_name=self.__class__.__name__,
        )
        return model