From b644c18784605fd38be4302e4bae6016e294796e Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 12:25:42 +0100
Subject: [PATCH 01/15] Add beginning-of-target separator between input and target

---
 ...ec_c4_prefix_lm_with_bot_before_target.gin |  19 ++++
 bigscience/gins/task.py                       | 103 ++++++++++++++++++
 bigscience/scripts/setup_vm.sh                |   1 +
 bigscience/scripts/test_seqio_dataset.py      |  16 ++-
 4 files changed, 135 insertions(+), 4 deletions(-)
 create mode 100644 bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin

diff --git a/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
new file mode 100644
index 000000000..11a98d030
--- /dev/null
+++ b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
@@ -0,0 +1,19 @@
+from __gin__ import dynamic_registration
+
+from t5x import models
+import seqio
+
+include "bigscience/gins/nc_dec_xxl.gin"
+include "t5x/configs/runs/pretrain.gin"
+include "bigscience/gins/pretrainer_base.gin"
+
+TASK_FEATURE_LENGTHS = {
+  "decoder_target_tokens": 626,
+  "decoder_input_tokens": 626,
+  "decoder_segment_ids": 626,
+  "decoder_causal_attention": 626,
+  "targets": 625 # we have to take into account an extra token between input and target
+}
+MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_eoi_seperator"
+
+models.DecoderOnlyModel.feature_converter_cls = @seqio.PassThroughFeatureConverter
\ No newline at end of file
diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index 7db13c4ba..5ae4842e7 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -1,6 +1,9 @@
+import dataclasses
 import functools
 
 import seqio
+import t5
+import tensorflow as tf
 from t5.data import preprocessors, get_default_vocabulary
 from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens
 
@@ -43,6 +46,106 @@ def full_lm(dataset, sequence_length, output_features):
         },
         metric_fns=[])
 
+# We want input and target to have an additional token between them.
+# Inspired by https://github.com/google-research/text-to-text-transfer-transformer/blob/9844ddb4f760ae8a1d4de410578f6211e487bbf9/t5/data/tasks.py#L445
+
+assert get_default_vocabulary().vocab_size == 32100, "Use T5 tokenizer by default"
+BOT_ID = 32000 # FIXME: this is only true for the T5 tokenizer right now.
+@dataclasses.dataclass(frozen=True)
+class FancyFeature(seqio.Feature):
+    # This token is used to separate input and target. `bot` stands for "beginning of target".
+    add_bot: bool = False
+
+def pack_prefix_lm_decoder_only(ds,
+                                sequence_length,
+                                output_features,
+                                loss_on_targets_only=True,
+                                pad_id=0):
+    """Randomly split the tokens for the prefix LM objective."""
+    packed_length = sequence_length["decoder_input_tokens"]
+    assert packed_length % 2 == 0
+    # "targets" is a special key
+    add_eoi = output_features["decoder_input_tokens"].add_eoi
+
+    assert all(l == packed_length for key, l in sequence_length.items() if (not add_eoi) or key != "targets")
+    assert all(l.add_eoi == add_eoi for key, l in output_features.items() if key != "targets")
+    if add_eoi:
+        assert sequence_length["targets"] == packed_length - 1
+    else:
+        assert sequence_length["targets"] == packed_length
+
+    @seqio.utils.map_over_dataset(num_seeds=1)
+    def pack_examples(example, seed):
+        split_point = tf.random.stateless_uniform((),
+                                                  minval=1,
+                                                  # Adding an extra token costs a bit.
+ maxval=packed_length if output_features["decoder_input_tokens"].add_eoi else packed_length - 1, + seed=seed, + dtype=tf.int32) + if output_features["decoder_input_tokens"].add_eoi: + decoder_target_tokens = tf.concat( + [ + example['targets'][:split_point - 1], + # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now. + BOT_ID, + example['targets'][split_point - 1:], + ], + axis=-1 + ) + else: + decoder_target_tokens = example['targets'] + + decoder_input_tokens = seqio.utils.make_autoregressive_inputs(decoder_target_tokens) + + if loss_on_targets_only: + decoder_loss_weights = tf.cast( + tf.range(packed_length) >= split_point, tf.int32) + else: + decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32) + + padding_mask = tf.cast( + tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32) + decoder_loss_weights *= padding_mask + + decoder_causal_attention = tf.cast( + tf.range(packed_length) <= split_point, tf.int32) + + return { + 'decoder_target_tokens': decoder_target_tokens, + 'decoder_input_tokens': decoder_input_tokens, + 'decoder_loss_weights': decoder_loss_weights, + 'decoder_causal_attention': decoder_causal_attention, + } + + return pack_examples(ds) + +TaskRegistry.add( + "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator", + source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"), + preprocessors=[ + functools.partial( + preprocessors.rekey, key_map={ + "inputs": None, + "targets": "text" + }), + seqio.preprocessors.tokenize, + seqio.CacheDatasetPlaceholder(), + t5.data.preprocessors.targets_for_prefix_lm_objective, + pack_prefix_lm_decoder_only, + ], + output_features={ + "decoder_target_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True), + "decoder_input_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True), + "decoder_loss_weights": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True), + "decoder_causal_attention": FancyFeature( + vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True), + # All but the last stage of the preprocessing uses "targets" as the key, + # so this output feature is necessary. It is not marked required because + # the final preprocessor drops it. + "targets": seqio.Feature(vocabulary=get_default_vocabulary(), required=False), + }, + metric_fns=[]) + # --- Improve sharding --- # def fully_sharded_logical_axis_rules() -> LogicalAxisRules: diff --git a/bigscience/scripts/setup_vm.sh b/bigscience/scripts/setup_vm.sh index 2b993b1d6..8b56d9a82 100644 --- a/bigscience/scripts/setup_vm.sh +++ b/bigscience/scripts/setup_vm.sh @@ -35,6 +35,7 @@ popd #rm -rf t5x git clone https://github.com/bigscience-workshop/t5x.git pushd t5x +git checkout thomas/prefix_lm_add_token pip3 install -e . 
popd

diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py
index 6e66b6f69..0a306f0c8 100644
--- a/bigscience/scripts/test_seqio_dataset.py
+++ b/bigscience/scripts/test_seqio_dataset.py
@@ -1,4 +1,5 @@
-from t5x import models
+import seqio
+
 from t5x import utils
 import tensorflow as tf
 from ..gins import task
@@ -6,8 +7,15 @@ def main():
     ds = utils.get_dataset(
         utils.DatasetConfig(
-            "c4_v220_full_lm",
-            task_feature_lengths={"targets": 626},
+            "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",
+
+            task_feature_lengths={
+                "decoder_target_tokens": 626,
+                "decoder_input_tokens": 626,
+                "decoder_segment_ids": 626,
+                "decoder_causal_attention": 626,
+                "targets": 625 # we have to take into account an extra token between input and target
+            },
             split="train",
             batch_size=2048,
             shuffle=False,
@@ -19,7 +27,7 @@ def main():
         ),
         0,
         1,
-        models.DecoderOnlyModel.FEATURE_CONVERTER_CLS
+        seqio.PassThroughFeatureConverter,
     )
     first_element = next(iter(ds))
     print(first_element)

From ed76c192eeb75260f4d62300a782a36ba43bd2fa Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 12:35:48 +0100
Subject: [PATCH 02/15] Woops

---
 bigscience/scripts/test_seqio_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py
index 0a306f0c8..02eb8264f 100644
--- a/bigscience/scripts/test_seqio_dataset.py
+++ b/bigscience/scripts/test_seqio_dataset.py
@@ -27,7 +27,7 @@ def main():
         ),
         0,
         1,
-        seqio.PassThroughFeatureConverter,
+        seqio.PassThroughFeatureConverter(),
     )
     first_element = next(iter(ds))
     print(first_element)

From 87a8ebbe64b4330bfe9317b037929533c7310c5c Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 14:02:33 +0100
Subject: [PATCH 03/15] Woops

---
 bigscience/gins/task.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index 5ae4842e7..a12e1bf54 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -65,11 +65,11 @@ def pack_prefix_lm_decoder_only(ds,
     packed_length = sequence_length["decoder_input_tokens"]
     assert packed_length % 2 == 0
     # "targets" is a special key
-    add_eoi = output_features["decoder_input_tokens"].add_eoi
+    add_bot = output_features["decoder_input_tokens"].add_bot
 
-    assert all(l == packed_length for key, l in sequence_length.items() if (not add_eoi) or key != "targets")
-    assert all(l.add_eoi == add_eoi for key, l in output_features.items() if key != "targets")
-    if add_eoi:
+    assert all(l == packed_length for key, l in sequence_length.items() if (not add_bot) or key != "targets")
+    assert all(l.add_bot == add_bot for key, l in output_features.items() if key != "targets")
+    if add_bot:
         assert sequence_length["targets"] == packed_length - 1
     else:
         assert sequence_length["targets"] == packed_length
@@ -79,10 +79,10 @@ def pack_examples(example, seed):
         split_point = tf.random.stateless_uniform((),
                                                   minval=1,
                                                   # Adding an extra token costs a bit.
- maxval=packed_length if output_features["decoder_input_tokens"].add_eoi else packed_length - 1, + maxval=packed_length if output_features["decoder_input_tokens"].add_bot else packed_length - 1, seed=seed, dtype=tf.int32) - if output_features["decoder_input_tokens"].add_eoi: + if output_features["decoder_input_tokens"].add_bot: decoder_target_tokens = tf.concat( [ example['targets'][:split_point - 1], From 6363b974253ea350b3958f5ea6ef02d613701f72 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:03:05 +0100 Subject: [PATCH 04/15] Woops --- bigscience/scripts/test_seqio_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py index 02eb8264f..0a306f0c8 100644 --- a/bigscience/scripts/test_seqio_dataset.py +++ b/bigscience/scripts/test_seqio_dataset.py @@ -27,7 +27,7 @@ def main(): ), 0, 1, - seqio.PassThroughFeatureConverter(), + seqio.PassThroughFeatureConverter, ) first_element = next(iter(ds)) print(first_element) From 360040a8b5a707f0d2272115e9d942bfa9bb75e0 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:10:47 +0100 Subject: [PATCH 05/15] Concat needs to be tensors --- bigscience/gins/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py index a12e1bf54..b5002fc75 100644 --- a/bigscience/gins/task.py +++ b/bigscience/gins/task.py @@ -87,7 +87,7 @@ def pack_examples(example, seed): [ example['targets'][:split_point - 1], # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now. - BOT_ID, + tf.constant([BOT_ID]), example['targets'][split_point - 1:], ], axis=-1 From cdeaab40b6cace5324ccbe36c21708c0bf5232f0 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:51:43 +0100 Subject: [PATCH 06/15] Change test to test what I want --- bigscience/scripts/test_seqio_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py index 0a306f0c8..0bcbd1118 100644 --- a/bigscience/scripts/test_seqio_dataset.py +++ b/bigscience/scripts/test_seqio_dataset.py @@ -42,8 +42,9 @@ def main(): print(tf.shape(first_element["decoder_target_tokens"])) print(tf.shape(first_element["decoder_input_tokens"])) print(tf.shape(first_element["decoder_loss_weights"])) - print(tf.shape(first_element["decoder_segment_ids"])) - print(tf.shape(first_element["decoder_positions"])) - + # print(tf.shape(first_element["decoder_segment_ids"])) + # print(tf.shape(first_element["decoder_positions"])) + print(tf.where(first_element["decoder_target_tokens"] == 32000)) + print(tf.where(first_element["decoder_input_tokens"] == 32000)) if __name__ == "__main__": main() \ No newline at end of file From d94cf75578b7dcf94941ea5adf097ee4986cbf17 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 14:53:42 +0100 Subject: [PATCH 07/15] Woops --- bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin index 11a98d030..fe8e334b8 100644 --- 
a/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
+++ b/bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
@@ -14,6 +14,6 @@ TASK_FEATURE_LENGTHS = {
   "decoder_causal_attention": 626,
   "targets": 625 # we have to take into account an extra token between input and target
 }
-MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_eoi_seperator"
+MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator"
 
 models.DecoderOnlyModel.feature_converter_cls = @seqio.PassThroughFeatureConverter
\ No newline at end of file

From 79379b27e4d3307574ddb350cbacd63a950be996 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 16:24:04 +0100
Subject: [PATCH 08/15] Wtf

---
 bigscience/gins/task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index b5002fc75..fc872516a 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -87,7 +87,7 @@ def pack_examples(example, seed):
                 [
                     example['targets'][:split_point - 1],
                     # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now.
-                    tf.constant([BOT_ID]),
+                    [BOT_ID],
                     example['targets'][split_point - 1:],
                 ],
                 axis=-1

From c2395305ae011329d11d3b7585c1ff16d21779cd Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 16:25:19 +0100
Subject: [PATCH 09/15] Maybe

---
 bigscience/gins/task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index fc872516a..e4a555e13 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -90,7 +90,7 @@ def pack_examples(example, seed):
                     [BOT_ID],
                     example['targets'][split_point - 1:],
                 ],
-                axis=-1
+                axis=0
             )
         else:
             decoder_target_tokens = example['targets']

From d94cf75578b7dcf94941ea5adf097ee4986cbf17 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 20:09:23 +0100
Subject: [PATCH 10/15] Maybe 2

---
 bigscience/gins/task.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index e4a555e13..94f107da7 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -83,15 +83,21 @@ def pack_examples(example, seed):
                                                   seed=seed,
                                                   dtype=tf.int32)
         if output_features["decoder_input_tokens"].add_bot:
-            decoder_target_tokens = tf.concat(
-                [
-                    example['targets'][:split_point - 1],
-                    # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now.
-                    [BOT_ID],
-                    example['targets'][split_point - 1:],
-                ],
-                axis=0
+            decoder_target_tokens = tf.ones(
+                (tf.shape(example['targets'])[0] + 1), example['targets'].dtype
             )
+            decoder_target_tokens[:split_point-1].assign(example['targets'][:split_point - 1])
+            decoder_target_tokens[split_point].assing(BOT_ID)
+            decoder_target_tokens[split_point:].assign(example['targets'][split_point - 1 :])
+            # decoder_target_tokens = tf.concat(
+            #     [
+            #         example['targets'][:split_point - 1],
+            #         # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now.
+ # [BOT_ID], + # example['targets'][split_point - 1:], + # ], + # axis=0 + # ) else: decoder_target_tokens = example['targets'] From 54cfa3d69dc839dd25480c21b66f2639b1145398 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 20:13:21 +0100 Subject: [PATCH 11/15] Maybe 3 --- bigscience/gins/task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py index 94f107da7..61abbc1a3 100644 --- a/bigscience/gins/task.py +++ b/bigscience/gins/task.py @@ -84,11 +84,11 @@ def pack_examples(example, seed): dtype=tf.int32) if output_features["decoder_input_tokens"].add_bot: decoder_target_tokens = tf.ones( - (tf.shape(example['targets'])[0] + 1), example['targets'].dtype + (tf.shape(example['targets'])[0] + 1,), example['targets'].dtype ) - decoder_target_tokens[:split_point-1].assign(example['targets'][:split_point - 1]) - decoder_target_tokens[split_point].assing(BOT_ID) - decoder_target_tokens[split_point:].assign(example['targets'][split_point - 1 :]) + decoder_target_tokens[:split_point-1] = example['targets'][:split_point - 1] + decoder_target_tokens[split_point-1] = BOT_ID + decoder_target_tokens[split_point:] = example['targets'][split_point - 1 :] # decoder_target_tokens = tf.concat( # [ # example['targets'][:split_point - 1], From 1f09775e3c67880299c95da21094d68eec7c922d Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 14 Dec 2021 20:16:35 +0100 Subject: [PATCH 12/15] Revert back --- bigscience/gins/task.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py index 61abbc1a3..e4a555e13 100644 --- a/bigscience/gins/task.py +++ b/bigscience/gins/task.py @@ -83,21 +83,15 @@ def pack_examples(example, seed): seed=seed, dtype=tf.int32) if output_features["decoder_input_tokens"].add_bot: - decoder_target_tokens = tf.ones( - (tf.shape(example['targets'])[0] + 1,), example['targets'].dtype + decoder_target_tokens = tf.concat( + [ + example['targets'][:split_point - 1], + # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now. + [BOT_ID], + example['targets'][split_point - 1:], + ], + axis=0 ) - decoder_target_tokens[:split_point-1] = example['targets'][:split_point - 1] - decoder_target_tokens[split_point-1] = BOT_ID - decoder_target_tokens[split_point:] = example['targets'][split_point - 1 :] - # decoder_target_tokens = tf.concat( - # [ - # example['targets'][:split_point - 1], - # # bot will be the same as _. Not ideal, but the tokenizer doesn't have `bos` right now. 
[BOT_ID],
+                    example['targets'][split_point - 1:],
+                ],
+                axis=0
+            )
         else:
             decoder_target_tokens = example['targets']

From d41ff414249a5d05e8c718d7526381b60c359935 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 20:17:28 +0100
Subject: [PATCH 13/15] Add test

---
 bigscience/scripts/test_seqio_dataset.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/bigscience/scripts/test_seqio_dataset.py b/bigscience/scripts/test_seqio_dataset.py
index 0bcbd1118..87a3ba554 100644
--- a/bigscience/scripts/test_seqio_dataset.py
+++ b/bigscience/scripts/test_seqio_dataset.py
@@ -46,5 +46,7 @@ def main():
     # print(tf.shape(first_element["decoder_positions"]))
     print(tf.where(first_element["decoder_target_tokens"] == 32000))
     print(tf.where(first_element["decoder_input_tokens"] == 32000))
+    print(ds.element_spec)
+
 if __name__ == "__main__":
     main()
\ No newline at end of file

From 1a3f96aa405d3c399d63d0ab577ed7c3e6896489 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 20:19:41 +0100
Subject: [PATCH 14/15] Maybe 4

---
 bigscience/gins/task.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index e4a555e13..865085217 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -79,10 +79,10 @@ def pack_examples(example, seed):
         split_point = tf.random.stateless_uniform((),
                                                   minval=1,
                                                   # Adding an extra token costs a bit.
-                                                  maxval=packed_length if output_features["decoder_input_tokens"].add_bot else packed_length - 1,
+                                                  maxval=packed_length if add_bot else packed_length - 1,
                                                   seed=seed,
                                                   dtype=tf.int32)
-        if output_features["decoder_input_tokens"].add_bot:
+        if add_bot:
             decoder_target_tokens = tf.concat(
                 [
                     example['targets'][:split_point - 1],

From 6d08e682d5cb5e91de07189311d87e834738d84a Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 14 Dec 2021 20:30:23 +0100
Subject: [PATCH 15/15] Maybe 5

---
 bigscience/gins/task.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/bigscience/gins/task.py b/bigscience/gins/task.py
index 865085217..e585f271c 100644
--- a/bigscience/gins/task.py
+++ b/bigscience/gins/task.py
@@ -92,6 +92,8 @@ def pack_examples(example, seed):
                 ],
                 axis=0
             )
+            # This has to be specified, otherwise the dataset's tensor spec infers a None shape.
+            decoder_target_tokens = tf.reshape(decoder_target_tokens, (packed_length,))
         else:
             decoder_target_tokens = example['targets']
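
Editor's note, not part of the patch series: after the back-and-forth of patches 10-12, the series settles on tf.concat to splice the BOT token in, plus a static tf.reshape (patch 15) to pin the shape. The sketch below mirrors that final logic in standalone form. It is illustrative only: `insert_bot` is a hypothetical helper, and it assumes the default T5 vocabulary (vocab_size 32100, so id 32000 is free to act as the beginning-of-target marker) and the packed length of 626 from the gin config.

    import tensorflow as tf

    BOT_ID = 32000       # assumption: unused id in the default T5 vocabulary
    PACKED_LENGTH = 626  # as in TASK_FEATURE_LENGTHS

    def insert_bot(targets, split_point):
        """Splice a beginning-of-target token between the input and target halves.

        `targets` holds PACKED_LENGTH - 1 tokens; the result holds PACKED_LENGTH.
        """
        spliced = tf.concat(
            [
                targets[:split_point - 1],  # prefix ("input") tokens
                [BOT_ID],                   # the separator token
                targets[split_point - 1:],  # suffix ("target") tokens
            ],
            axis=0,
        )
        # Mirror of patch 15: pin the static shape so downstream consumers do
        # not see an unknown-length tensor.
        return tf.reshape(spliced, (PACKED_LENGTH,))

    # Quick check, analogous to the tf.where probes in test_seqio_dataset.py:
    targets = tf.range(1, PACKED_LENGTH, dtype=tf.int32)  # 625 dummy tokens
    packed = insert_bot(targets, split_point=100)
    assert packed.shape == (PACKED_LENGTH,)
    assert int(tf.where(packed == BOT_ID)[0][0]) == 99  # BOT sits at split_point - 1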
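Why patch 15 is needed at all (a hedged reading of its comment): inside a tf.data map, a tf.concat whose pieces are sliced at a tensor-valued split point has a length the tracer cannot prove, so the dataset's element spec comes out as shape (None,). A minimal reproduction with toy sizes, independent of the patch series:

    import tensorflow as tf

    def splice(x):
        # Tensor-valued split point, like `split_point` in pack_examples.
        n = tf.random.uniform((), minval=1, maxval=8, dtype=tf.int32)
        return tf.concat([x[:n], [0], x[n:]], axis=0)  # always length 10, but dynamic to the tracer

    ds = tf.data.Dataset.from_tensors(tf.range(9, dtype=tf.int32)).map(splice)
    print(ds.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int32, ...)

    # Pinning the shape, as patch 15 does with tf.reshape:
    ds = ds.map(lambda y: tf.reshape(y, (10,)))
    print(ds.element_spec)  # TensorSpec(shape=(10,), dtype=tf.int32, ...)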