forked from EleutherAI/gpt-neo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinputs.py
255 lines (196 loc) · 10.2 KB
/
inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import numpy as np
import tensorflow.compat.v1 as tf
from functools import partial
from data.encoders import encode
def generic_text(params, eval=False, sample_text_fn=None):
    """Build a weighted mixture of the text datasets listed in ``params["datasets"]``.

    Each entry of ``params["datasets"]`` is unpacked as
    ``(dataset_id, stitch, datatype, weight)``. ``dataset_id`` must be a key of
    ``params['dataset_configs']``, whose ``'path'`` (training) or ``'eval_path'``
    (``eval=True``) glob selects the input files for ``text_dataset``.

    Args:
        params: config dict. Reads ``datasets``, ``dataset_configs``,
            ``train_batch_size`` / ``eval_batch_size``, ``iterations`` and the
            optional ``seed``.
        eval: when True, use ``eval_path`` and ``eval_batch_size``.
        sample_text_fn: optional sampler forwarded to ``text_dataset``.

    Returns:
        A ``tf.data.Dataset`` that draws from the per-dataset pipelines according
        to their weights. It is batched with ``drop_remainder=True`` and prefetched
        ``iterations * 2`` batches ahead.
    """
    print('##############################')
    print(params["datasets"])
    print('##############################')

    path_key = 'eval_path' if eval else 'path'
    weights = []
    datasets = []

    for dataset_id, stitch, datatype, weight in params["datasets"]:
        assert dataset_id in params['dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
        path = params['dataset_configs'][dataset_id][path_key]

        # Batching happens once, after mixing, so the per-dataset pipelines stay unbatched.
        datasets.append(text_dataset(
            tf.io.gfile.glob(path),
            params,
            stitch=stitch,
            datatype=datatype,
            batch=False,
            sample_text_fn=sample_text_fn,
        ))
        weights.append(weight)

    batch_size = params['eval_batch_size' if eval else 'train_batch_size']
    seed = params.get('seed', None)

    dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
    return dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
    """Build a repeating tf.data pipeline of token sequences from TFRecord files.

    Every record's ``"text"`` feature is parsed as a variable-length int64 list.

    For ``datatype`` values that contain ``"documents"``:
      - records are shuffled and grouped ``stitch`` at a time;
      - each group is concatenated, with ``params['eos_id']`` between texts;
      - the result is mapped through a sampler. The sampler is
        ``sample_text_fn`` (called with ``random_documents=``) when given.
        Otherwise it is ``autoregressive_sample_text_random_documents`` for
        ``"documents_random"`` and ``autoregressive_sample_text`` for anything else,
        with ``params`` bound as the first argument.

    For any other ``datatype``, the parsed feature is passed through unchanged.
    NOTE(review): ``tf.VarLenFeature`` yields a SparseTensor here, despite the
    original "not sparse" remark. Confirm what downstream expects.

    When ``params['seed']`` is set, interleaving and mapping use the deterministic
    variants (plain ``interleave``, ``num_parallel_calls=1``).

    Args:
        files: TFRecord file paths.
        params: config dict (``seed``, ``eos_id``, ``train_batch_size``, ``iterations``).
        stitch: number of documents concatenated per sample.
        datatype: dataset kind string, e.g. ``"documents_random"``.
        batch: when True, batch by ``train_batch_size`` (dropping the remainder)
            and prefetch ``iterations * 2`` batches.
        sample_text_fn: optional custom sampler for document datasets.

    Returns:
        The dataset, repeated indefinitely.
    """
    seed = params.get('seed', None)
    deterministic = seed is not None
    parallel = 1 if deterministic else tf.data.experimental.AUTOTUNE
    is_documents = "documents" in datatype

    file_ds = tf.data.Dataset.from_tensor_slices(files)
    if not deterministic:
        dataset = file_ds.apply(
            tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
    else:
        dataset = file_ds.interleave(tf.data.TFRecordDataset, cycle_length=4)

    def _parse_function(example_proto):
        parsed = tf.parse_single_example(example_proto, {"text": tf.VarLenFeature(tf.int64)})
        tokens = parsed["text"]
        if is_documents:
            # Also emit each document's token count so stitching can drop the sparse padding.
            return tokens, tokens.dense_shape[0]
        return tokens

    dataset = dataset.map(_parse_function, num_parallel_calls=1)

    if is_documents:
        # Documents can be shorter than the context window and TPUs need fixed lengths.
        # Concatenating `stitch` documents only yields enough tokens if
        # stitch * min(document_length) covers the sampler's window.
        def _stitch_text(x, y):
            dense = tf.sparse.to_dense(x)

            def _get_x(idx):
                # Keep only the first y[idx] tokens of row idx; the rest is densification padding.
                return tf.gather(dense[idx], tf.range(y[idx]))

            out = _get_x(0)
            eos_id = params['eos_id']
            for idx in range(1, stitch):
                out = tf.concat([out, [eos_id], _get_x(idx)], axis=0)  # text1<|endoftext|>text2
            return out

        grouped = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True)
        dataset = grouped.map(_stitch_text, num_parallel_calls=parallel)

        is_random_documents = datatype == "documents_random"
        if sample_text_fn is None:
            sampler = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
            _sample_text = partial(sampler, params)
        else:
            _sample_text = partial(sample_text_fn, random_documents = is_random_documents)

        dataset = dataset.map(_sample_text, num_parallel_calls=parallel)

    if batch:
        dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)

    return dataset.repeat()
def autoregressive_sample_text(params, x):
    """Return the next-token (input, target) pair taken from the start of ``x``.

    The input is ``x[0:n_ctx]`` and the target is the same window shifted by one
    (``x[1:n_ctx + 1]``). Both are reshaped to the static shape ``[n_ctx]`` and
    cast to int32.
    """
    n_ctx = params["n_ctx"]

    def _window(start):
        # The reshape pins a static length of n_ctx on the slice.
        piece = tf.reshape(x[start:start + n_ctx], [n_ctx])
        return tf.cast(piece, dtype=tf.int32)

    return _window(0), _window(1)
def autoregressive_sample_text_random_documents(params, x):
    """Return an (input, target) pair from a randomly placed window of ``x``.

    A start index ``r`` is drawn uniformly. The input is ``x[r:r + n_ctx]`` and the
    target is ``x[r + 1:r + 1 + n_ctx]``. Both are int32 with static shape ``[n_ctx]``.
    ``x`` needs at least ``n_ctx + 1`` elements.

    Args:
        params: config dict; reads ``n_ctx`` and the optional ``seed``.
        x: token tensor to sample from (assumed 1-D; TODO confirm with callers).
    """
    seed = params.get('seed', None)
    n_ctx = params["n_ctx"]
    s = tf.size(x)

    # tf.random.uniform's maxval is exclusive. The target window ends at index
    # r + n_ctx, which must be <= s - 1, so the largest valid start is s - n_ctx - 1
    # and maxval is s - n_ctx. (The former s - (n_ctx + 1) never sampled the last
    # window and failed outright on texts of exactly n_ctx + 1 tokens.)
    r = tf.random.uniform([], maxval=s - n_ctx, dtype=tf.dtypes.int32, seed=seed)

    # The reshapes give the index ranges, and therefore the gathered values, a static
    # shape; TPUs require constant-sized inputs.
    r1 = tf.reshape(tf.range(r, r + n_ctx), [n_ctx])
    r2 = tf.reshape(tf.range(r + 1, (r + 1) + n_ctx), [n_ctx])

    vals1 = tf.reshape(tf.gather(x, r1), [n_ctx])
    vals2 = tf.reshape(tf.gather(x, r2), [n_ctx])

    vals1 = tf.cast(vals1, dtype=tf.int32)
    vals2 = tf.cast(vals2, dtype=tf.int32)
    return vals1, vals2
def mlm_sample_text(params, x, random_documents = False):
    """Build a masked-language-modelling (inputs, labels) pair from a token sequence.

    Steps:
      1. Take a window from ``x``: the prefix, or a uniformly placed window when
         ``random_documents`` is True. The window has ``n_ctx`` tokens, or
         ``n_ctx - 1`` when ``mlm_cls_token_id`` is set, in which case that id
         is prepended.
      2. Treat a token as maskable if it is not 0 and not in ``mlm_mask_ignore_ids``
         (nor the CLS id). Select maskable positions with probability
         ``mlm_mask_prob``.
      3. Replace a selected position with ``mlm_mask_id`` with probability
         ``1 - mlm_same_token_prob``.
      4. Optionally, first swap maskable positions for random vocabulary ids with
         probability ``mlm_random_token_prob``.

    Args:
        params: config dict. Reads ``n_ctx`` and ``mlm_mask_id`` (required), plus the
            optional ``seed``, ``mlm_cls_token_id``, ``n_vocab``,
            ``mlm_mask_ignore_ids``, ``mlm_mask_prob`` (default 0.15),
            ``mlm_same_token_prob`` (default 0.10) and ``mlm_random_token_prob``
            (default 0.).
        x: token tensor (assumed 1-D and at least as long as the window; TODO
            confirm with callers).
        random_documents: sample the window at a random offset instead of the prefix.

    Returns:
        ``(masked_features, labels)``, both int32 tensors of shape ``[n_ctx]``.
        ``labels`` holds the original token at selected positions and 0 elsewhere.
    """
    seed = params.get('seed', None)
    ctx_len = params["n_ctx"]
    assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'

    mask_id = params['mlm_mask_id']
    cls_token_id = params.get('mlm_cls_token_id', None)
    num_tokens = params.get('n_vocab', None)

    mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
    # Register the CLS id only when one is configured. Adding None would make the
    # comparison loops below call tf.not_equal(features, None), which cannot be
    # converted to a tensor.
    if cls_token_id is not None:
        mask_ignore_ids.add(cls_token_id)

    mask_prob = params.get('mlm_mask_prob', 0.15)
    same_token_prob = params.get('mlm_same_token_prob', 0.10)
    random_token_prob = params.get('mlm_random_token_prob', 0.)

    def _op_seed(offset):
        # Give each random op its own seed. Inside a traced graph, stateful random ops
        # created with the same op seed produce the same stream, which would make the
        # masks below perfectly correlated whenever `seed` is set.
        return None if seed is None else seed + offset

    # Leave room for the CLS token prepended below.
    seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)

    if random_documents:
        s = tf.size(x)
        # maxval is exclusive. A window [r, r + seq_len) fits when r <= s - seq_len,
        # so maxval must be s - seq_len + 1 for the last window to be reachable.
        r = tf.random.uniform([], maxval=(s - seq_len + 1), dtype=tf.dtypes.int32, seed=_op_seed(0))
        features = tf.gather(x, tf.range(r, r + seq_len))
    else:
        features = x[:seq_len]
    # Pin a static length. The prefix slice of a variable-length tensor otherwise has
    # an unknown size, and tf.random.uniform(shape) below needs a fully defined shape.
    features = tf.reshape(features, [seq_len])

    # add cls token id if specified by `mlm_cls_token_id`
    if cls_token_id is not None:
        features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)

    features = tf.cast(features, dtype=tf.int32)
    shape = features.shape

    def _uniform(offset):
        return tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=_op_seed(offset))

    # Positions that may be masked or randomly replaced: not 0 and not an ignored id.
    can_mask = tf.not_equal(features, 0)
    for ignore_id in mask_ignore_ids:
        can_mask &= tf.not_equal(features, ignore_id)

    # Positions selected for prediction.
    mask_mask = tf.less(_uniform(1), mask_prob) & can_mask
    # Among selected positions, those actually replaced by the mask id; the rest keep
    # a visible token.
    replace_mask = tf.less(_uniform(2), 1 - same_token_prob)

    # Keep the uncorrupted tokens: they are the prediction targets, even where a
    # random token is substituted into the input below.
    targets = features

    if random_token_prob > 0:
        random_token_mask = tf.less(_uniform(3), random_token_prob)
        random_tokens = tf.random.uniform(shape, minval = 1, maxval = num_tokens, dtype = tf.dtypes.int32, seed = _op_seed(4))
        # Never draw an ignored id, and never overwrite a position that may not be
        # masked (e.g. the prepended CLS token).
        random_can_mask = can_mask & tf.not_equal(random_tokens, 0)
        for ignore_id in mask_ignore_ids:
            random_can_mask &= tf.not_equal(random_tokens, ignore_id)
        features = tf.where(random_token_mask & random_can_mask, random_tokens, features)

    # mask the tokens
    mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
    masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)

    # Labels carry the original token at selected positions and 0 at every
    # non-selected position. (The previous code had the tf.where branches swapped.)
    labels = tf.where(mask_mask, targets, tf.zeros(shape, dtype=tf.int32))

    masked_features = tf.reshape(masked_features, [ctx_len])
    labels = tf.reshape(labels, [ctx_len])
    return masked_features, labels
def pred_input(params, logger, enc=None,
               path_to_prompt=""):
    """Build a one-element prediction dataset from a text prompt.

    The prompt is read from ``path_to_prompt``, or a built-in sample text is used
    when the path is empty. It is tokenized with ``encode(enc, text)``, cut to its
    last ``n_ctx`` tokens when too long, right-padded with ``params["padding_id"]``
    when too short, and broadcast to ``[batch_size, n_ctx]``.

    Args:
        params: config dict; reads ``n_ctx``, ``padding_id`` and ``batch_size``.
        logger: logger used to report truncation.
        enc: encoder passed through to ``encode``.
        path_to_prompt: prompt file path; "" selects the built-in text.

    Returns:
        A ``tf.data.Dataset`` with a single element ``(tokens, tokens)``; the second
        copy is a dummy label.
    """
    unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
               "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
               "researchers was the fact that the unicorns spoke perfect English."

    if path_to_prompt == "":
        text = unicorns
    else:
        # Use a context manager so the prompt file handle is closed instead of leaked.
        with open(path_to_prompt, "r") as prompt_file:
            text = prompt_file.read()

    tokens = encode(enc, text)
    n_ctx = params["n_ctx"]

    if len(tokens) > n_ctx:
        logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
        # Keep the most recent n_ctx tokens.
        tokens = tokens[len(tokens) - n_ctx:]

    if len(tokens) < n_ctx:
        tokens = tf.pad(tokens, [[0, n_ctx - len(tokens)]], constant_values=params["padding_id"])

    t = tf.broadcast_to(tokens, [params["batch_size"], n_ctx])
    dataset = tf.data.Dataset.from_tensors(t)

    def _dummy_labels(x):
        return x, x

    return dataset.map(_dummy_labels)
def handle_pred_output(predictions, logger, enc, params, out_name="test"):
    """Decode model predictions and write them to ``<out_name>.txt`` and the logger.

    Each prediction's ``"outputs"`` token array is cut at the first
    ``params['eos_id']`` and then at the first ``params['padding_id']``, decoded with
    ``enc.decode``, and emitted between ``SAMPLE <i>`` banner lines.

    Args:
        predictions: iterable of dicts with an ``"outputs"`` token array.
        logger: logger that receives a copy of every sample.
        enc: decoder exposing ``decode``.
        params: config dict; reads ``eos_id`` and ``padding_id``.
        out_name: output file stem.
    """

    def _cut_at(tokens, token_id):
        # np.argmax on a boolean array returns the first True index, or 0 when there is
        # none. A match at position 0 therefore looks like "absent" and is left uncut.
        first = np.argmax(tokens == token_id)
        return tokens[:first] if first > 0 else tokens

    with tf.gfile.Open(f"{out_name}.txt", "w") as out_file:
        for sample_idx, prediction in enumerate(predictions):
            tokens = prediction["outputs"]
            tokens = _cut_at(tokens, params['eos_id'])
            tokens = _cut_at(tokens, params['padding_id'])
            text = enc.decode(tokens)

            header = "=" * 40 + " SAMPLE " + str(sample_idx) + " " + "=" * 40 + "\n"
            footer = "\n" + "=" * 80 + "\n"

            for chunk in (header, text, footer):
                out_file.write(chunk)
            for chunk in (header, text, footer):
                logger.info(chunk)