Commit f39e881
Parent: b8ba348

Padding examples for TPU eval/predictions and checking case match

8 files changed (+164 -34)

README.md (+7 -5)
@@ -5,7 +5,7 @@ Mongolian \*\*\*\*\***
 
 We uploaded a new multilingual model which does *not* perform any normalization
 on the input (no lower casing, accent stripping, or Unicode normalization), and
-additionally inclues Thai and Mongolian.
+additionally includes Thai and Mongolian.
 
 **It is recommended to use this version for developing multilingual models,
 especially on languages with non-Latin alphabets.**
@@ -38,8 +38,9 @@ repository.
 
 We have made two new BERT models available:
 
-*   **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
-    102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+*   **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)
+    (Not recommended, use `Multilingual Cased` instead)**: 102 languages,
+    12-layer, 768-hidden, 12-heads, 110M parameters
 *   **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
     Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
     parameters
@@ -228,8 +229,9 @@ The links to the models are here (right-click, 'Save link as...' on the name):
     24-layer, 1024-hidden, 16-heads, 340M parameters
 *   **[`BERT-Base, Multilingual Cased (New, recommended)`](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)**:
     104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
-*   **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
-    102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+*   **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)
+    (Not recommended, use `Multilingual Cased` instead)**: 102 languages,
+    12-layer, 768-hidden, 12-heads, 110M parameters
 *   **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
     Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
     parameters

create_pretraining_data.py (+1 -2)

@@ -20,9 +20,8 @@
 
 import collections
 import random
-
-import tokenization
 import tensorflow as tf
+import tokenization
 
 flags = tf.flags
 
multilingual.md (+1 -1)

@@ -297,7 +297,7 @@ chosen because they are the top 100 languages with the largest Wikipedias:
 *   Volapük
 *   Waray-Waray
 *   Welsh
-*   West
+*   West Frisian
 *   Western Punjabi
 *   Yoruba
 
optimization.py (+4 -1)

@@ -76,6 +76,9 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
   train_op = optimizer.apply_gradients(
       zip(grads, tvars), global_step=global_step)
 
+  # Normally the global step update is done inside of `apply_gradients`.
+  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
+  # a different optimizer, you should probably take this line out.
   new_global_step = global_step + 1
   train_op = tf.group(train_op, [global_step.assign(new_global_step)])
   return train_op
@@ -137,7 +140,7 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
       # the correct way of using L2 regularization/weight decay with Adam,
       # since that will interact with the m and v parameters in strange ways.
       #
-      # Instead we want ot decay the weights in a manner that doesn't interact
+      # Instead we want to decay the weights in a manner that doesn't interact
       # with the m/v parameters. This is equivalent to adding the square
       # of the weights to the loss with plain (non-momentum) SGD.
       if self._do_use_weight_decay(param_name):

run_classifier.py (+89 -20)

@@ -145,14 +145,33 @@ def __init__(self, guid, text_a, text_b=None, label=None):
     self.label = label
 
 
+class PaddingInputExample(object):
+  """Fake example so the num input examples is a multiple of the batch size.
+
+  When running eval/predict on the TPU, we need to pad the number of examples
+  to be a multiple of the batch size, because the TPU requires a fixed batch
+  size. The alternative is to drop the last batch, which is bad because it means
+  the entire output data won't be generated.
+
+  We use this class instead of `None` because treating `None` as padding
+  batches could cause silent errors.
+  """
+
+
 class InputFeatures(object):
   """A single set of features of data."""
 
-  def __init__(self, input_ids, input_mask, segment_ids, label_id):
+  def __init__(self,
+               input_ids,
+               input_mask,
+               segment_ids,
+               label_id,
+               is_real_example=True):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
+    self.is_real_example = is_real_example
 
 
 class DataProcessor(object):
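
The padding scheme this docstring describes is small enough to sketch standalone. A minimal illustration in plain Python (`pad_to_batch_size` is a hypothetical helper, not part of the commit):

class PaddingInputExample(object):
  pass


def pad_to_batch_size(examples, batch_size):
  """Appends fake examples until len(examples) divides evenly into batches."""
  examples = list(examples)
  while len(examples) % batch_size != 0:
    examples.append(PaddingInputExample())
  return examples


padded = pad_to_batch_size(range(13), batch_size=8)
print(len(padded))  # 16: 13 real examples plus 3 padding examples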
@@ -358,6 +377,15 @@ def _create_examples(self, lines, set_type):
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
+
+  if isinstance(example, PaddingInputExample):
+    return InputFeatures(
+        input_ids=[0] * max_seq_length,
+        input_mask=[0] * max_seq_length,
+        segment_ids=[0] * max_seq_length,
+        label_id=0,
+        is_real_example=False)
+
   label_map = {}
   for (i, label) in enumerate(label_list):
     label_map[label] = i
@@ -393,7 +421,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
   # it easier for the model to learn the concept of sequences.
   #
   # For classification tasks, the first vector (corresponding to [CLS]) is
-  # used as as the "sentence vector". Note that this only makes sense because
+  # used as the "sentence vector". Note that this only makes sense because
   # the entire model is fine-tuned.
   tokens = []
   segment_ids = []
@@ -443,7 +471,8 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
       input_ids=input_ids,
       input_mask=input_mask,
       segment_ids=segment_ids,
-      label_id=label_id)
+      label_id=label_id,
+      is_real_example=True)
   return feature
 
 
@@ -469,9 +498,12 @@ def create_int_feature(values):
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
     features["label_ids"] = create_int_feature([feature.label_id])
+    features["is_real_example"] = create_int_feature(
+        [int(feature.is_real_example)])
 
     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())
+  writer.close()
 
 
 def file_based_input_fn_builder(input_file, seq_length, is_training,
@@ -483,6 +515,7 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,
       "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
       "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
       "label_ids": tf.FixedLenFeature([], tf.int64),
+      "is_real_example": tf.FixedLenFeature([], tf.int64),
   }
 
   def _decode_record(record, name_to_features):
@@ -599,6 +632,11 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
     input_mask = features["input_mask"]
     segment_ids = features["segment_ids"]
     label_ids = features["label_ids"]
+    is_real_example = None
+    if "is_real_example" in features:
+      is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
+    else:
+      is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
 
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
@@ -643,24 +681,28 @@ def tpu_scaffold():
           scaffold_fn=scaffold_fn)
     elif mode == tf.estimator.ModeKeys.EVAL:
 
-      def metric_fn(per_example_loss, label_ids, logits):
+      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
         predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
-        accuracy = tf.metrics.accuracy(label_ids, predictions)
-        loss = tf.metrics.mean(per_example_loss)
+        accuracy = tf.metrics.accuracy(
+            labels=label_ids, predictions=predictions, weights=is_real_example)
+        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
         return {
             "eval_accuracy": accuracy,
            "eval_loss": loss,
         }
 
-      eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
+      eval_metrics = (metric_fn,
+                      [per_example_loss, label_ids, logits, is_real_example])
       output_spec = tf.contrib.tpu.TPUEstimatorSpec(
           mode=mode,
           loss=total_loss,
           eval_metrics=eval_metrics,
           scaffold_fn=scaffold_fn)
     else:
       output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)
+          mode=mode,
+          predictions={"probabilities": probabilities},
+          scaffold_fn=scaffold_fn)
     return output_spec
 
   return model_fn
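
The effect of the `weights=is_real_example` argument above is that padded examples drop out of both metrics. A toy re-creation of the weighted-accuracy arithmetic in plain Python, with invented numbers:

labels      = [1, 0, 1, 0]          # the last example is TPU padding
predictions = [1, 0, 0, 0]
weights     = [1.0, 1.0, 1.0, 0.0]  # is_real_example as 0/1 weights

# Weighted accuracy: weight of correct predictions over total weight, so a
# zero-weight (padding) example affects neither numerator nor denominator.
correct = sum(w for l, p, w in zip(labels, predictions, weights) if l == p)
print(correct / sum(weights))  # 0.666..., scored over the 3 real examples only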
@@ -748,6 +790,9 @@ def main(_):
       "xnli": XnliProcessor,
   }
 
+  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
+                                                FLAGS.init_checkpoint)
+
   if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
     raise ValueError(
         "At least one of `do_train`, `do_eval` or `do_predict` must be True.")
@@ -836,22 +881,33 @@
 
   if FLAGS.do_eval:
     eval_examples = processor.get_dev_examples(FLAGS.data_dir)
+    num_actual_eval_examples = len(eval_examples)
+    if FLAGS.use_tpu:
+      # TPU requires a fixed batch size for all batches, therefore the number
+      # of examples must be a multiple of the batch size, or else examples
+      # will get dropped. So we pad with fake examples which are ignored
+      # later on. These do NOT count towards the metric (all tf.metrics
+      # support a per-instance weight, and these get a weight of 0.0).
+      while len(eval_examples) % FLAGS.eval_batch_size != 0:
+        eval_examples.append(PaddingInputExample())
+
     eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
     file_based_convert_examples_to_features(
         eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
 
     tf.logging.info("***** Running evaluation *****")
-    tf.logging.info("  Num examples = %d", len(eval_examples))
+    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
+                    len(eval_examples), num_actual_eval_examples,
+                    len(eval_examples) - num_actual_eval_examples)
     tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
 
     # This tells the estimator to run through the entire set.
     eval_steps = None
     # However, if running eval on the TPU, you will need to specify the
     # number of steps.
     if FLAGS.use_tpu:
-      # Eval will be slightly WRONG on the TPU because it will truncate
-      # the last batch.
-      eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
+      assert len(eval_examples) % FLAGS.eval_batch_size == 0
+      eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
 
     eval_drop_remainder = True if FLAGS.use_tpu else False
     eval_input_fn = file_based_input_fn_builder(
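
With the padding in place the step count is exact rather than truncated. A worked example with invented sizes (a 1043-example dev set and an `eval_batch_size` of 32):

eval_batch_size = 32
num_actual_eval_examples = 1043  # hypothetical dev-set size

num_examples = num_actual_eval_examples
while num_examples % eval_batch_size != 0:
  num_examples += 1  # one PaddingInputExample per iteration

print(num_examples)                     # 1056, i.e. 13 padding examples added
print(num_examples // eval_batch_size)  # eval_steps = 33; nothing is dropped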
@@ -871,20 +927,26 @@
 
   if FLAGS.do_predict:
     predict_examples = processor.get_test_examples(FLAGS.data_dir)
+    num_actual_predict_examples = len(predict_examples)
+    if FLAGS.use_tpu:
+      # TPU requires a fixed batch size for all batches, therefore the number
+      # of examples must be a multiple of the batch size, or else examples
+      # will get dropped. So we pad with fake examples which are ignored
+      # later on.
+      while len(predict_examples) % FLAGS.predict_batch_size != 0:
+        predict_examples.append(PaddingInputExample())
+
     predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
     file_based_convert_examples_to_features(predict_examples, label_list,
                                             FLAGS.max_seq_length, tokenizer,
                                             predict_file)
 
     tf.logging.info("***** Running prediction *****")
-    tf.logging.info("  Num examples = %d", len(predict_examples))
+    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
+                    len(predict_examples), num_actual_predict_examples,
+                    len(predict_examples) - num_actual_predict_examples)
     tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
 
-    if FLAGS.use_tpu:
-      # Warning: According to tpu_estimator.py Prediction on TPU is an
-      # experimental feature and hence not supported here
-      raise ValueError("Prediction in TPU not supported")
-
     predict_drop_remainder = True if FLAGS.use_tpu else False
     predict_input_fn = file_based_input_fn_builder(
         input_file=predict_file,
@@ -896,11 +958,18 @@
 
   output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
   with tf.gfile.GFile(output_predict_file, "w") as writer:
+    num_written_lines = 0
     tf.logging.info("***** Predict results *****")
-    for prediction in result:
+    for (i, prediction) in enumerate(result):
+      probabilities = prediction["probabilities"]
+      if i >= num_actual_predict_examples:
+        break
       output_line = "\t".join(
-          str(class_probability) for class_probability in prediction) + "\n"
+          str(class_probability)
+          for class_probability in probabilities) + "\n"
       writer.write(output_line)
+      num_written_lines += 1
+  assert num_written_lines == num_actual_predict_examples
 
 
 if __name__ == "__main__":
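
A toy run of that writer loop (invented numbers; `result` stands in for the estimator's `predict()` output) shows how the `break` keeps padded predictions out of `test_results.tsv`:

num_actual_predict_examples = 3
result = [{"probabilities": [0.9, 0.1]}] * 5  # 3 real outputs + 2 padding

lines = []
for (i, prediction) in enumerate(result):
  if i >= num_actual_predict_examples:
    break  # everything past the real examples is TPU padding
  lines.append("\t".join(str(p) for p in prediction["probabilities"]))

assert len(lines) == num_actual_predict_examples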

run_squad.py (+3 -0)

@@ -1096,6 +1096,9 @@ def close(self):
 
 def validate_flags_or_throw(bert_config):
   """Validate the input FLAGS or throw an exception."""
+  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
+                                                FLAGS.init_checkpoint)
+
   if not FLAGS.do_train and not FLAGS.do_predict:
     raise ValueError("At least one of `do_train` or `do_predict` must be True.")
 
tokenization.py (+55 -1)

@@ -19,11 +19,62 @@
 from __future__ import print_function
 
 import collections
+import re
 import unicodedata
 import six
 import tensorflow as tf
 
 
+def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
+  """Checks whether the casing config is consistent with the checkpoint name."""
+
+  # The casing has to be passed in by the user and there is no explicit check
+  # as to whether it matches the checkpoint. The casing information probably
+  # should have been stored in the bert_config.json file, but it's not, so
+  # we have to heuristically detect it to validate.
+
+  if not init_checkpoint:
+    return
+
+  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
+  if m is None:
+    return
+
+  model_name = m.group(1)
+
+  lower_models = [
+      "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
+      "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
+  ]
+
+  cased_models = [
+      "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
+      "multi_cased_L-12_H-768_A-12"
+  ]
+
+  is_bad_config = False
+  if model_name in lower_models and not do_lower_case:
+    is_bad_config = True
+    actual_flag = "False"
+    case_name = "lowercased"
+    opposite_flag = "True"
+
+  if model_name in cased_models and do_lower_case:
+    is_bad_config = True
+    actual_flag = "True"
+    case_name = "cased"
+    opposite_flag = "False"
+
+  if is_bad_config:
+    raise ValueError(
+        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
+        "However, `%s` seems to be a %s model, so you "
+        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
+        "how the model was pre-trained. If this error is wrong, please "
+        "just comment out this check." % (actual_flag, init_checkpoint,
+                                          model_name, case_name, opposite_flag))
+
+
 def convert_to_unicode(text):
   """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
   if six.PY3:
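
A quick sanity check of the checkpoint-name heuristic above (the paths are hypothetical; the directory names are the released model names):

import re

for ckpt in ("/models/uncased_L-12_H-768_A-12/bert_model.ckpt",
             "/models/multi_cased_L-12_H-768_A-12/bert_model.ckpt"):
  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", ckpt)
  print(m.group(1))
# -> uncased_L-12_H-768_A-12      (must run with --do_lower_case=True)
# -> multi_cased_L-12_H-768_A-12  (must run with --do_lower_case=False)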
@@ -84,7 +135,10 @@ def load_vocab(vocab_file):
 
 def convert_by_vocab(vocab, items):
   """Converts a sequence of [tokens|ids] using the vocab."""
-  return [vocab[item] for item in items]
+  output = []
+  for item in items:
+    output.append(vocab[item])
+  return output
 
 
 def convert_tokens_to_ids(vocab, tokens):
