@@ -145,14 +145,33 @@ def __init__(self, guid, text_a, text_b=None, label=None):
145
145
self .label = label
146
146
147
147
148
class PaddingInputExample(object):
  """Fake example so the num input examples is a multiple of the batch size.

  When running eval/predict on the TPU, we need to pad the number of examples
  to be a multiple of the batch size, because the TPU requires a fixed batch
  size. The alternative is to drop the last batch, which is bad because it means
  the entire output data won't be generated.

  We use this class instead of `None` because treating `None` as padding
  batches could cause silent errors.
  """
148
161
class InputFeatures(object):
  """A single set of features of data.

  Holds the tokenized, padded model inputs for one example. The
  `is_real_example` flag distinguishes genuine examples from padding
  examples added so the example count is a multiple of the TPU batch size.
  """

  def __init__(self,
               input_ids,
               input_mask,
               segment_ids,
               label_id,
               is_real_example=True):
    # Record every constructor argument verbatim as an instance attribute.
    fields = {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        "label_id": label_id,
        "is_real_example": is_real_example,
    }
    for attr_name, attr_value in fields.items():
      setattr(self, attr_name, attr_value)
156
175
157
176
158
177
class DataProcessor (object ):
@@ -358,6 +377,15 @@ def _create_examples(self, lines, set_type):
358
377
def convert_single_example (ex_index , example , label_list , max_seq_length ,
359
378
tokenizer ):
360
379
"""Converts a single `InputExample` into a single `InputFeatures`."""
380
+
381
+ if isinstance (example , PaddingInputExample ):
382
+ return InputFeatures (
383
+ input_ids = [0 ] * max_seq_length ,
384
+ input_mask = [0 ] * max_seq_length ,
385
+ segment_ids = [0 ] * max_seq_length ,
386
+ label_id = 0 ,
387
+ is_real_example = False )
388
+
361
389
label_map = {}
362
390
for (i , label ) in enumerate (label_list ):
363
391
label_map [label ] = i
@@ -393,7 +421,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
393
421
# it easier for the model to learn the concept of sequences.
394
422
#
395
423
# For classification tasks, the first vector (corresponding to [CLS]) is
396
- # used as as the "sentence vector". Note that this only makes sense because
424
+ # used as the "sentence vector". Note that this only makes sense because
397
425
# the entire model is fine-tuned.
398
426
tokens = []
399
427
segment_ids = []
@@ -443,7 +471,8 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
443
471
input_ids = input_ids ,
444
472
input_mask = input_mask ,
445
473
segment_ids = segment_ids ,
446
- label_id = label_id )
474
+ label_id = label_id ,
475
+ is_real_example = True )
447
476
return feature
448
477
449
478
@@ -469,9 +498,12 @@ def create_int_feature(values):
469
498
features ["input_mask" ] = create_int_feature (feature .input_mask )
470
499
features ["segment_ids" ] = create_int_feature (feature .segment_ids )
471
500
features ["label_ids" ] = create_int_feature ([feature .label_id ])
501
+ features ["is_real_example" ] = create_int_feature (
502
+ [int (feature .is_real_example )])
472
503
473
504
tf_example = tf .train .Example (features = tf .train .Features (feature = features ))
474
505
writer .write (tf_example .SerializeToString ())
506
+ writer .close ()
475
507
476
508
477
509
def file_based_input_fn_builder (input_file , seq_length , is_training ,
@@ -483,6 +515,7 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,
483
515
"input_mask" : tf .FixedLenFeature ([seq_length ], tf .int64 ),
484
516
"segment_ids" : tf .FixedLenFeature ([seq_length ], tf .int64 ),
485
517
"label_ids" : tf .FixedLenFeature ([], tf .int64 ),
518
+ "is_real_example" : tf .FixedLenFeature ([], tf .int64 ),
486
519
}
487
520
488
521
def _decode_record (record , name_to_features ):
@@ -599,6 +632,11 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
599
632
input_mask = features ["input_mask" ]
600
633
segment_ids = features ["segment_ids" ]
601
634
label_ids = features ["label_ids" ]
635
+ is_real_example = None
636
+ if "is_real_example" in features :
637
+ is_real_example = tf .cast (features ["is_real_example" ], dtype = tf .float32 )
638
+ else :
639
+ is_real_example = tf .ones (tf .shape (label_ids ), dtype = tf .float32 )
602
640
603
641
is_training = (mode == tf .estimator .ModeKeys .TRAIN )
604
642
@@ -643,24 +681,28 @@ def tpu_scaffold():
643
681
scaffold_fn = scaffold_fn )
644
682
elif mode == tf .estimator .ModeKeys .EVAL :
645
683
646
- def metric_fn (per_example_loss , label_ids , logits ):
684
+ def metric_fn (per_example_loss , label_ids , logits , is_real_example ):
647
685
predictions = tf .argmax (logits , axis = - 1 , output_type = tf .int32 )
648
- accuracy = tf .metrics .accuracy (label_ids , predictions )
649
- loss = tf .metrics .mean (per_example_loss )
686
+ accuracy = tf .metrics .accuracy (
687
+ labels = label_ids , predictions = predictions , weights = is_real_example )
688
+ loss = tf .metrics .mean (values = per_example_loss , weights = is_real_example )
650
689
return {
651
690
"eval_accuracy" : accuracy ,
652
691
"eval_loss" : loss ,
653
692
}
654
693
655
- eval_metrics = (metric_fn , [per_example_loss , label_ids , logits ])
694
+ eval_metrics = (metric_fn ,
695
+ [per_example_loss , label_ids , logits , is_real_example ])
656
696
output_spec = tf .contrib .tpu .TPUEstimatorSpec (
657
697
mode = mode ,
658
698
loss = total_loss ,
659
699
eval_metrics = eval_metrics ,
660
700
scaffold_fn = scaffold_fn )
661
701
else :
662
702
output_spec = tf .contrib .tpu .TPUEstimatorSpec (
663
- mode = mode , predictions = probabilities , scaffold_fn = scaffold_fn )
703
+ mode = mode ,
704
+ predictions = {"probabilities" : probabilities },
705
+ scaffold_fn = scaffold_fn )
664
706
return output_spec
665
707
666
708
return model_fn
@@ -748,6 +790,9 @@ def main(_):
748
790
"xnli" : XnliProcessor ,
749
791
}
750
792
793
+ tokenization .validate_case_matches_checkpoint (FLAGS .do_lower_case ,
794
+ FLAGS .init_checkpoint )
795
+
751
796
if not FLAGS .do_train and not FLAGS .do_eval and not FLAGS .do_predict :
752
797
raise ValueError (
753
798
"At least one of `do_train`, `do_eval` or `do_predict' must be True." )
@@ -836,22 +881,33 @@ def main(_):
836
881
837
882
if FLAGS .do_eval :
838
883
eval_examples = processor .get_dev_examples (FLAGS .data_dir )
884
+ num_actual_eval_examples = len (eval_examples )
885
+ if FLAGS .use_tpu :
886
+ # TPU requires a fixed batch size for all batches, therefore the number
887
+ # of examples must be a multiple of the batch size, or else examples
888
+ # will get dropped. So we pad with fake examples which are ignored
889
+ # later on. These do NOT count towards the metric (all tf.metrics
890
+ # support a per-instance weight, and these get a weight of 0.0).
891
+ while len (eval_examples ) % FLAGS .eval_batch_size != 0 :
892
+ eval_examples .append (PaddingInputExample ())
893
+
839
894
eval_file = os .path .join (FLAGS .output_dir , "eval.tf_record" )
840
895
file_based_convert_examples_to_features (
841
896
eval_examples , label_list , FLAGS .max_seq_length , tokenizer , eval_file )
842
897
843
898
tf .logging .info ("***** Running evaluation *****" )
844
- tf .logging .info (" Num examples = %d" , len (eval_examples ))
899
+ tf .logging .info (" Num examples = %d (%d actual, %d padding)" ,
900
+ len (eval_examples ), num_actual_eval_examples ,
901
+ len (eval_examples ) - num_actual_eval_examples )
845
902
tf .logging .info (" Batch size = %d" , FLAGS .eval_batch_size )
846
903
847
904
# This tells the estimator to run through the entire set.
848
905
eval_steps = None
849
906
# However, if running eval on the TPU, you will need to specify the
850
907
# number of steps.
851
908
if FLAGS .use_tpu :
852
- # Eval will be slightly WRONG on the TPU because it will truncate
853
- # the last batch.
854
- eval_steps = int (len (eval_examples ) / FLAGS .eval_batch_size )
909
+ assert len (eval_examples ) % FLAGS .eval_batch_size == 0
910
+ eval_steps = int (len (eval_examples ) // FLAGS .eval_batch_size )
855
911
856
912
eval_drop_remainder = True if FLAGS .use_tpu else False
857
913
eval_input_fn = file_based_input_fn_builder (
@@ -871,20 +927,26 @@ def main(_):
871
927
872
928
if FLAGS .do_predict :
873
929
predict_examples = processor .get_test_examples (FLAGS .data_dir )
930
+ num_actual_predict_examples = len (predict_examples )
931
+ if FLAGS .use_tpu :
932
+ # TPU requires a fixed batch size for all batches, therefore the number
933
+ # of examples must be a multiple of the batch size, or else examples
934
+ # will get dropped. So we pad with fake examples which are ignored
935
+ # later on.
936
+ while len (predict_examples ) % FLAGS .predict_batch_size != 0 :
937
+ predict_examples .append (PaddingInputExample ())
938
+
874
939
predict_file = os .path .join (FLAGS .output_dir , "predict.tf_record" )
875
940
file_based_convert_examples_to_features (predict_examples , label_list ,
876
941
FLAGS .max_seq_length , tokenizer ,
877
942
predict_file )
878
943
879
944
tf .logging .info ("***** Running prediction*****" )
880
- tf .logging .info (" Num examples = %d" , len (predict_examples ))
945
+ tf .logging .info (" Num examples = %d (%d actual, %d padding)" ,
946
+ len (predict_examples ), num_actual_predict_examples ,
947
+ len (predict_examples ) - num_actual_predict_examples )
881
948
tf .logging .info (" Batch size = %d" , FLAGS .predict_batch_size )
882
949
883
- if FLAGS .use_tpu :
884
- # Warning: According to tpu_estimator.py Prediction on TPU is an
885
- # experimental feature and hence not supported here
886
- raise ValueError ("Prediction in TPU not supported" )
887
-
888
950
predict_drop_remainder = True if FLAGS .use_tpu else False
889
951
predict_input_fn = file_based_input_fn_builder (
890
952
input_file = predict_file ,
@@ -896,11 +958,18 @@ def main(_):
896
958
897
959
output_predict_file = os .path .join (FLAGS .output_dir , "test_results.tsv" )
898
960
with tf .gfile .GFile (output_predict_file , "w" ) as writer :
961
+ num_written_lines = 0
899
962
tf .logging .info ("***** Predict results *****" )
900
- for prediction in result :
963
+ for (i , prediction ) in enumerate (result ):
964
+ probabilities = prediction ["probabilities" ]
965
+ if i >= num_actual_predict_examples :
966
+ break
901
967
output_line = "\t " .join (
902
- str (class_probability ) for class_probability in prediction ) + "\n "
968
+ str (class_probability )
969
+ for class_probability in probabilities ) + "\n "
903
970
writer .write (output_line )
971
+ num_written_lines += 1
972
+ assert num_written_lines == num_actual_predict_examples
904
973
905
974
906
975
if __name__ == "__main__" :
0 commit comments