-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathdata_utils.py
580 lines (509 loc) · 22.2 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import numpy as np
from typing import List, Tuple
from itertools import zip_longest
from functools import partial
from megatron import mpu, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.samplers import DistributedBatchSampler
def make_data_loader(dataset, neox_args):
    """Wrap ``dataset`` in a torch DataLoader driven by a distributed batch sampler.

    Samples are drawn in sequential order and handed to
    ``DistributedBatchSampler`` with a global batch size of
    ``neox_args.batch_size`` times the data-parallel world size, this
    process's data-parallel rank, and ``drop_last=True``.

    Args:
        dataset: the dataset to load from, or ``None``.
        neox_args: run configuration (reads ``batch_size`` and ``num_workers``).

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` when ``dataset`` is ``None``.
    """
    if dataset is None:
        return None

    # Data-parallel topology of this process.
    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()

    batch_sampler = DistributedBatchSampler(
        sampler=torch.utils.data.SequentialSampler(dataset),
        batch_size=neox_args.batch_size * dp_world_size,
        drop_last=True,
        rank=dp_rank,
        world_size=dp_world_size,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=neox_args.num_workers,
        pin_memory=True,
    )
def build_the_dataset(
    data_prefix,
    name,
    data_impl,
    num_samples,
    seq_length,
    seed,
    skip_warmup,
    build_index_mappings=True,
    label_prefix=None,
    index_mapping_paths=None,
    index_offset=0,
    reshuffle_when_loading=True,
):
    """Build a single GPT2Dataset spanning every document under ``data_prefix``.

    Args:
        data_prefix: path prefix of the indexed token dataset.
        name: dataset name, used for logging and passed to GPT2Dataset.
        data_impl: indexed-dataset implementation passed to ``make_indexed_dataset``.
        num_samples: number of samples the dataset should provide.
        seq_length: sequence length of each sample.
        seed: seed forwarded to GPT2Dataset.
        skip_warmup: forwarded to ``make_indexed_dataset``.
        build_index_mappings: forwarded to GPT2Dataset.
        label_prefix: optional path prefix of a parallel label dataset.
        index_mapping_paths: forwarded to GPT2Dataset (used for replay data).
        index_offset: forwarded to GPT2Dataset.
        reshuffle_when_loading: forwarded to GPT2Dataset.

    Returns:
        The constructed ``GPT2Dataset``.
    """
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    label_dataset = (
        None
        if label_prefix is None
        else make_indexed_dataset(label_prefix, data_impl, skip_warmup)
    )

    num_documents = indexed_dataset.sizes.shape[0]
    print_rank_0(" {}:".format(name))
    print_rank_0(" no. of documents:{}".format(num_documents))

    # Use every document, in order.
    all_documents = np.arange(start=0, stop=num_documents, step=1, dtype=np.int32)
    return GPT2Dataset(
        name,
        data_prefix,
        all_documents,
        indexed_dataset,
        num_samples,
        seq_length,
        seed,
        build_index_mappings=build_index_mappings,
        label_dataset=label_dataset,
        index_mapping_paths=index_mapping_paths,
        index_offset=index_offset,
        reshuffle_when_loading=reshuffle_when_loading,
    )
def build_train_valid_test_datasets(
    data_prefix,
    use_shared_fs,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    seq_length,
    seed,
    skip_warmup,
):
    """Split one indexed dataset into contiguous train / valid / test document ranges.

    The document boundaries come from ``get_train_valid_test_split_``; each
    non-empty range is wrapped in a GPT2Dataset.

    Args:
        data_prefix: path prefix of the indexed dataset.
        use_shared_fs: forwarded to GPT2Dataset.
        data_impl: indexed-dataset implementation.
        splits_string: split weights, e.g. ``"969,30,1"``.
        train_valid_test_num_samples: sample counts for the three splits.
        seq_length: sequence length of each sample.
        seed: seed forwarded to GPT2Dataset.
        skip_warmup: forwarded to ``make_indexed_dataset``.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset)``; an entry is ``None``
        when its document range is empty.
    """
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Report the document range of every split.
    print_rank_0(" > dataset split:")
    for split_name, split_idx in (("train", 0), ("validation", 1), ("test", 2)):
        first_doc, end_doc = splits[split_idx], splits[split_idx + 1]
        print_rank_0(" {}:".format(split_name))
        print_rank_0(
            " document indices in [{}, {}) total of {} "
            "documents".format(first_doc, end_doc, end_doc - first_doc)
        )

    def _build_split(split_idx, dataset_name):
        # Wrap documents [splits[split_idx], splits[split_idx + 1]) or return None if empty.
        first_doc, end_doc = splits[split_idx], splits[split_idx + 1]
        if end_doc <= first_doc:
            return None
        documents = np.arange(start=first_doc, stop=end_doc, step=1, dtype=np.int32)
        return GPT2Dataset(
            dataset_name,
            data_prefix,
            documents,
            indexed_dataset,
            train_valid_test_num_samples[split_idx],
            seq_length,
            seed,
            use_shared_fs=use_shared_fs,
        )

    return (
        _build_split(0, "train"),
        _build_split(1, "valid"),
        _build_split(2, "test"),
    )
def get_train_valid_test_split_(splits_string, size):
    """Get dataset split boundaries from a comma- or '/'-separated weight string.

    ``splits_string`` holds up to three relative weights, e.g. ``"969,30,1"``
    or ``"98/1/1"``. Missing entries count as 0 and extra entries are ignored.
    The weights are normalized and converted into document boundaries over
    ``range(size)``.

    Args:
        splits_string: comma- or '/'-separated split weights.
        size: total number of documents.

    Returns:
        A list ``[0, train_end, valid_end, size]`` of four non-decreasing
        indices; split ``k`` covers documents ``[splits[k], splits[k + 1])``.
    """
    if "," in splits_string:
        splits = [float(s) for s in splits_string.split(",")]
    elif "/" in splits_string:
        splits = [float(s) for s in splits_string.split("/")]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.0)
    splits = splits[:3]
    assert all(split >= 0.0 for split in splits), "split weights must be non-negative"
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]

    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] + int(round(split * float(size))))

    # Rounding each split independently can overshoot or undershoot `size`,
    # so shift the boundaries until the last one lands exactly on `size`.
    # Clamp every boundary to its predecessor. Without the clamp, a small
    # first split could be pushed below 0 (e.g. "0,1,1" over 3 documents gave
    # [0, -1, 1, 3]). That produced a negative document index and overlapping
    # valid/test ranges.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] = max(splits_index[index] - diff, splits_index[index - 1])

    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
def get_normalized_weights_and_num_samples(
    weights: List[float], num_samples: int
) -> Tuple[List[float], List[int]]:
    """Normalize ``weights`` to sum to 1 and derive a sample count for each entry.

    Each count is ``ceil(num_samples * weight * 1.005)``. The extra 0.5%
    leaves spare samples in case the blending dataset does not distribute
    the requested samples exactly in proportion to the weights.

    Args:
        weights: non-normalized, positive-sum weights, one per dataset.
        num_samples: total number of samples requested across all datasets.

    Returns:
        ``(normalized_weights, per_dataset_num_samples)``.
    """
    total_weight = sum(weights)
    assert total_weight > 0.0
    normalized = [w / total_weight for w in weights]
    sample_counts = [int(math.ceil(num_samples * w * 1.005)) for w in normalized]
    return normalized, sample_counts
def get_normalized_weights_and_num_samples_with_replay(
    weights: List[float], replay_weights: List[float], replay_fraction, num_samples: int
) -> Tuple[List[float], List[int]]:
    """Merge training and replay weights into one normalized list with sample counts.

    The training weights are normalized and scaled to share
    ``1 - replay_fraction`` of the total. The replay weights are normalized
    and scaled to share ``replay_fraction``. The result is the training
    weights followed by the replay weights, in that order.

    Each sample count is ``ceil(num_samples * weight * 1.005)``; the extra
    0.5% leaves spare samples in case the blending dataset does not
    distribute samples exactly in proportion to the weights.

    Args:
        weights: non-normalized weights of the training datasets.
        replay_weights: non-normalized weights of the replay datasets.
        replay_fraction: share of the total weight given to replay data.
        num_samples: total number of samples requested.

    Returns:
        ``(merged_weights, per_dataset_num_samples)``.
    """
    train_total = sum(weights)
    assert train_total > 0.0
    train_share = 1 - replay_fraction
    train_part = [(w / train_total) * train_share for w in weights]

    replay_total = sum(replay_weights)
    assert replay_total > 0.0
    replay_part = [(w / replay_total) * replay_fraction for w in replay_weights]

    merged = train_part + replay_part
    sample_counts = [int(math.ceil(num_samples * w * 1.005)) for w in merged]
    return merged, sample_counts
def build_weighted_datasets(
    neox_args,
    train_num_samples,
    valid_num_samples,
    test_num_samples,
    train_weights,
    valid_weights,
    test_weights,
    build_index_mappings=True,
    concatenate_train_replay_paths=False,
):
    """Build one dataset per configured train / label / valid / test data path.

    The i-th train, label, valid and test paths are paired positionally with
    ``zip_longest``. Each non-empty path becomes a dataset named ``train_{i}``,
    ``valid_{i}`` or ``test_{i}``, sized by the i-th entry of the matching
    ``*_num_samples`` list.

    When replay is enabled and ``concatenate_train_replay_paths`` is True,
    ``neox_args.replay_data_paths`` are appended in place to
    ``neox_args.train_data_paths``. Those trailing entries become
    ``replay_{j}`` datasets. Instead of building new index mappings, they use
    the paths in ``neox_args.replay_data_to_idx_paths``, with
    ``neox_args.replay_seed`` and ``neox_args.replay_idx_offsets[j]``.

    Args:
        neox_args: run configuration; ``train_data_paths`` may be extended
            in place (see above).
        train_num_samples, valid_num_samples, test_num_samples: per-path
            sample counts, indexed by path position.
        train_weights, valid_weights, test_weights: accepted for call-site
            compatibility; not used here.
        build_index_mappings: forwarded to the non-replay datasets.
        concatenate_train_replay_paths: whether to append the replay paths to
            the train paths. Pass False on any repeated call with the same
            ``neox_args`` so the replay paths are not appended twice.

    Returns:
        ``(train_datasets, valid_datasets, test_datasets)`` lists.
    """
    # Check before touching neox_args, so a failed assert leaves it unmodified.
    assert not (neox_args.label_data_paths and neox_args.is_replay_enabled), (
        "Simultaneous use of label data and replay is untested. "
        "Remove assert at your own risk. You might want to add a "
        "replay_label_data_paths arg too if relevant."
    )

    if neox_args.is_replay_enabled and concatenate_train_replay_paths:
        # Merge the replay paths into the train paths. The replay paths
        # become the last `num_replay_data_paths` entries, which is how they
        # are recognised in the loop below.
        num_replay_data_paths = len(neox_args.replay_data_paths)
        neox_args.train_data_paths += neox_args.replay_data_paths
    else:
        num_replay_data_paths = 0

    # Train paths at positions >= this index are replay paths.
    num_primary_train_paths = len(neox_args.train_data_paths) - num_replay_data_paths
    skip_warmup = not neox_args.mmap_warmup

    train_datasets, valid_datasets, test_datasets = [], [], []
    for i, (train_path, label_path, valid_path, test_path) in enumerate(
        zip_longest(
            neox_args.train_data_paths,
            neox_args.label_data_paths if neox_args.label_data_paths else [],
            neox_args.valid_data_paths,
            neox_args.test_data_paths,
        )
    ):
        if train_path:
            if i < num_primary_train_paths:
                train_datasets.append(
                    build_the_dataset(
                        data_prefix=train_path,
                        name=f"train_{i}",
                        data_impl=neox_args.data_impl,
                        num_samples=train_num_samples[i],
                        seq_length=neox_args.seq_length,
                        seed=neox_args.seed,
                        skip_warmup=skip_warmup,
                        build_index_mappings=build_index_mappings,
                        label_prefix=label_path,
                    )
                )
            else:
                # Replay datasets load their precomputed idx files rather
                # than building new index mappings.
                i_replay = i - num_primary_train_paths
                train_datasets.append(
                    build_the_dataset(
                        data_prefix=train_path,
                        name=f"replay_{i_replay}",
                        data_impl=neox_args.data_impl,
                        num_samples=train_num_samples[i],
                        seq_length=neox_args.seq_length,
                        seed=neox_args.replay_seed,
                        skip_warmup=skip_warmup,
                        build_index_mappings=False,
                        index_mapping_paths=neox_args.replay_data_to_idx_paths[train_path],
                        index_offset=neox_args.replay_idx_offsets[i_replay],
                        reshuffle_when_loading=neox_args.replay_reshuffle_idx,
                    )
                )

        if valid_path:
            valid_datasets.append(
                build_the_dataset(
                    data_prefix=valid_path,
                    name=f"valid_{i}",
                    data_impl=neox_args.data_impl,
                    num_samples=valid_num_samples[i],
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=skip_warmup,
                    build_index_mappings=build_index_mappings,
                )
            )

        if test_path:
            test_datasets.append(
                build_the_dataset(
                    data_prefix=test_path,
                    name=f"test_{i}",
                    data_impl=neox_args.data_impl,
                    num_samples=test_num_samples[i],
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=skip_warmup,
                    build_index_mappings=build_index_mappings,
                )
            )

    return train_datasets, valid_datasets, test_datasets
def weights_by_num_docs(l: list, alpha=0.3):
    """Build sampling weights for groups of data from their document counts.

    Each group's raw share is ``n_i / sum(n)``. A group is then sampled
    with probability ``p_i ∝ share_i ** alpha``. An ``alpha`` below 1
    flattens the distribution, boosting low-resource groups (see
    https://arxiv.org/abs/1911.02116). Each ``p_i`` is then multiplied by
    ``1 - share_i``, which further down-weights the largest groups, and the
    result is renormalized to sum to 1.

    Args:
        l: number of documents in each group.
        alpha: exponent applied to each group's share.

    Returns:
        A list of weights, one per group. A single group always gets ``[1.0]``.
    """
    if len(l) == 1:
        return [1.0]

    def _normalize(values):
        # Scale values so they sum to 1.
        total = sum(values)
        return [v / total for v in values]

    total_n_docs = sum(l)
    doc_shares = [n / total_n_docs for n in l]
    # Exponent-smoothed sampling probabilities.
    probs = _normalize([share**alpha for share in doc_shares])
    # Damp each probability by (1 - the group's raw document share).
    return _normalize([p * (1 - share) for p, share in zip(probs, doc_shares)])
def build_train_valid_test_data_iterators(neox_args):
    """Build the train/valid/test data iterators and set ``neox_args.do_*`` flags.

    Datasets and dataloaders are only built on model-parallel rank 0, and,
    under pipeline parallelism, only on the first and last pipeline stages.
    Everywhere else the iterators are ``None``. Flags saying whether
    training, validation and testing should run are broadcast from the
    building ranks and stored as ``neox_args.do_train``,
    ``neox_args.do_valid`` and ``neox_args.do_test``.

    The datasets come from one of two sources:
      * explicit ``train_data_paths`` / ``valid_data_paths`` /
        ``test_data_paths``, blended according to per-path weights
        (optionally including replay data, or re-weighted by document
        count when ``weight_by_num_documents`` is set), or
      * a single ``data_path`` split according to ``neox_args.split``.

    Args:
        neox_args: run configuration; its ``do_train``, ``do_valid`` and
            ``do_test`` attributes are set by this function.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``;
        an entry is ``None`` when the matching dataloader was not built.
    """
    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0("> building train, validation, and test datasets ...")

    # Ensure only the first/last pipeline stages have data loaders
    if neox_args.is_pipe_parallel:
        is_first_stage = mpu.get_pipe_parallel_rank() == 0
        is_last_stage = (
            mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1
        )
        pipe_load = is_first_stage or is_last_stage
    else:
        pipe_load = True

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples. Validation needs
        # (train_iters // eval_interval + 1) rounds of eval_iters batches each.
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
        test_iters = neox_args.eval_iters
        train_val_test_num_samples = [
            train_iters * neox_args.train_batch_size,
            eval_iters * neox_args.train_batch_size,
            test_iters * neox_args.train_batch_size,
        ]

        # A split without any data paths keeps a None dataset, for which
        # make_data_loader returns None. Previously the name was left unbound,
        # so e.g. an empty valid_data_paths raised UnboundLocalError below.
        train_ds, valid_ds, test_ds = None, None, None

        if neox_args.train_data_paths:
            # when individual train / valid / test data paths are provided
            # normalize weight values and get num samples for each dataset
            if neox_args.is_replay_enabled:
                (
                    train_weights,
                    train_num_samples,
                ) = get_normalized_weights_and_num_samples_with_replay(
                    neox_args.train_data_weights,
                    neox_args.replay_data_weights,
                    neox_args.replay_fraction,
                    train_val_test_num_samples[0],
                )
            else:
                train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                    neox_args.train_data_weights, train_val_test_num_samples[0]
                )
            valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                neox_args.valid_data_weights, train_val_test_num_samples[1]
            )
            test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                neox_args.test_data_weights, train_val_test_num_samples[2]
            )

            # build individual datasets
            train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                neox_args,
                train_num_samples,
                valid_num_samples,
                test_num_samples,
                train_weights,
                valid_weights,
                test_weights,
                build_index_mappings=not neox_args.weight_by_num_documents,
                concatenate_train_replay_paths=True,
            )

            if neox_args.weight_by_num_documents:
                assert not neox_args.is_replay_enabled, (
                    "Replay not tested in the case of autoweighting, remove assert at "
                    "your own risk. I suspect that something might break with the "
                    "concatenation of the train and replay happening twice due to a "
                    "second call of build_weighted_datasets, so setting it to False "
                    "with concatenate_train_replay_paths=False."
                )

                def get_num_docs_list(datasets):
                    # Number of documents in each dataset's indexed data.
                    return [dataset.indexed_dataset.sizes.shape[0] for dataset in datasets]

                train_num_docs, valid_num_docs, test_num_docs = (
                    get_num_docs_list(train_datasets),
                    get_num_docs_list(valid_datasets),
                    get_num_docs_list(test_datasets),
                )

                # builds weights according to alpha + the number of docs
                fn = partial(
                    weights_by_num_docs, alpha=neox_args.weighted_sampler_alpha
                )
                train_weights, valid_weights, test_weights = (
                    fn(train_num_docs),
                    fn(valid_num_docs),
                    fn(test_num_docs),
                )
                (
                    train_weights,
                    train_num_samples,
                ) = get_normalized_weights_and_num_samples(
                    train_weights, train_val_test_num_samples[0]
                )
                (
                    valid_weights,
                    valid_num_samples,
                ) = get_normalized_weights_and_num_samples(
                    valid_weights, train_val_test_num_samples[1]
                )
                test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                    test_weights, train_val_test_num_samples[2]
                )

                # rebuild datasets weighted according to new weights
                train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                    neox_args,
                    train_num_samples,
                    valid_num_samples,
                    test_num_samples,
                    train_weights,
                    valid_weights,
                    test_weights,
                    concatenate_train_replay_paths=False,
                )

            if train_datasets:
                train_ds = BlendableDataset(train_datasets, train_weights)
            if valid_datasets:
                valid_ds = BlendableDataset(valid_datasets, valid_weights)
            if test_datasets:
                test_ds = BlendableDataset(test_datasets, test_weights)
        else:
            assert not neox_args.is_replay_enabled, "Replay not implemented in the case of autosplitting into train/val/test datasets."
            # when just data_path is provided
            # split dataset into train, valid and test from data_path
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
                data_prefix=neox_args.data_path,
                use_shared_fs=neox_args.use_shared_fs,
                data_impl=neox_args.data_impl,
                splits_string=neox_args.split,
                train_valid_test_num_samples=train_val_test_num_samples,
                seq_length=neox_args.seq_length,
                seed=neox_args.seed,
                skip_warmup=(not neox_args.mmap_warmup),
            )

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
        test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and neox_args.train_iters > 0
        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
        do_test = test_dataloader is not None and neox_args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast to other ranks.
        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train / do_valid / do_test flags.
    if neox_args.is_pipe_parallel:
        # Only first/last pipeline stages have data loaders, so pipeline parallelism should
        # broadcast globally instead of just the model parallel group.
        torch.distributed.broadcast(flags, src=0)
    else:
        torch.distributed.broadcast(
            flags,
            mpu.get_model_parallel_src_rank(),
            group=mpu.get_model_parallel_group(),
        )
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()

    # Shift the start iterations so resumed runs continue from the batch
    # corresponding to neox_args.iteration (wrapped by the dataloader length).
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = (
            neox_args.iteration * neox_args.gradient_accumulation_steps
        ) % len(train_dataloader)
        print_rank_0(
            "setting training data start iteration to {}".format(
                train_dataloader.batch_sampler.start_iter
            )
        )
    if valid_dataloader is not None:
        start_iter_val = (
            (neox_args.iteration * neox_args.gradient_accumulation_steps)
            // neox_args.eval_interval
        ) * neox_args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
            valid_dataloader
        )
        print_rank_0(
            "setting validation data start iteration to {}".format(
                valid_dataloader.batch_sampler.start_iter
            )
        )

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
def compile_helper():
    """Build the C++ dataset helpers by running ``make`` in this module's directory.

    Exits the process with status 1 if ``make`` fails. Call this from a
    single process only.
    """
    import os
    import subprocess
    import sys

    helpers_dir = os.path.abspath(os.path.dirname(__file__))
    result = subprocess.run(["make", "-C", helpers_dir])
    if result.returncode == 0:
        return
    print("Making C++ dataset helpers module failed, exiting.")
    sys.exit(1)