[WIP] Prefix lm add BOT token to separate input and target #19

Draft · wants to merge 15 commits into base: main
19 changes: 19 additions & 0 deletions bigscience/gins/nc_dec_c4_prefix_lm_with_bot_before_target.gin
@@ -0,0 +1,19 @@
from __gin__ import dynamic_registration

from t5x import models
import seqio

include "bigscience/gins/nc_dec_xxl.gin"
include "t5x/configs/runs/pretrain.gin"
include "bigscience/gins/pretrainer_base.gin"

TASK_FEATURE_LENGTHS = {
    "decoder_target_tokens": 626,
    "decoder_input_tokens": 626,
    "decoder_segment_ids": 626,
    "decoder_causal_attention": 626,
    "targets": 625,  # take into account the extra BOT token inserted between input and target
}
MIXTURE_OR_TASK_NAME = "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator"

models.DecoderOnlyModel.feature_converter_cls = @seqio.PassThroughFeatureConverter
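
Note on the lengths above: the packed decoder features are 626 tokens, while the tokenized `targets` feature is one shorter because `pack_prefix_lm_decoder_only` (see task.py below) inserts a single BOT separator while packing. A minimal sketch of that bookkeeping (plain Python, not part of the diff; all names here are illustrative):

# Illustrative length bookkeeping, not part of the PR.
PACKED_LENGTH = 626       # decoder_target_tokens / decoder_input_tokens / ...
BOT_TOKENS_ADDED = 1      # one beginning-of-target separator per example

# The packer additionally asserts an even packed length.
assert PACKED_LENGTH % 2 == 0

TARGETS_LENGTH = PACKED_LENGTH - BOT_TOKENS_ADDED
assert TARGETS_LENGTH == 625  # matches the "targets" entry above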
105 changes: 105 additions & 0 deletions bigscience/gins/task.py
@@ -1,6 +1,9 @@
import dataclasses
import functools

import seqio
import t5
import tensorflow as tf
from t5.data import preprocessors, get_default_vocabulary
from t5.data.preprocessors import select_random_chunk, reduce_concat_tokens, split_tokens

@@ -43,6 +46,108 @@ def full_lm(dataset, sequence_length, output_features):
    },
    metric_fns=[])

# We want an additional token between the input and the target.
# Inspired by https://github.com/google-research/text-to-text-transfer-transformer/blob/9844ddb4f760ae8a1d4de410578f6211e487bbf9/t5/data/tasks.py#L445

assert get_default_vocabulary().vocab_size == 32100, "Use T5 tokenizer by default"
BOT_ID = 32000  # FIXME: this is only valid for the T5 tokenizer right now.
@dataclasses.dataclass(frozen=True)
class FancyFeature(seqio.Feature):
    # This token is used to separate input and target. `bot` stands for "beginning of target".
    add_bot: bool = False

def pack_prefix_lm_decoder_only(ds,
                                sequence_length,
                                output_features,
                                loss_on_targets_only=True,
                                pad_id=0):
    """Randomly split the tokens for the prefix LM objective."""
    packed_length = sequence_length["decoder_input_tokens"]
    assert packed_length % 2 == 0
    # "targets" is a special key
    add_bot = output_features["decoder_input_tokens"].add_bot

    assert all(l == packed_length for key, l in sequence_length.items() if (not add_bot) or key != "targets")
    assert all(l.add_bot == add_bot for key, l in output_features.items() if key != "targets")
    if add_bot:
        assert sequence_length["targets"] == packed_length - 1
    else:
        assert sequence_length["targets"] == packed_length

    @seqio.utils.map_over_dataset(num_seeds=1)
    def pack_examples(example, seed):
        split_point = tf.random.stateless_uniform((),
                                                  minval=1,
                                                  # Adding an extra token costs a bit.
                                                  maxval=packed_length if add_bot else packed_length - 1,
                                                  seed=seed,
                                                  dtype=tf.int32)
        if add_bot:
            decoder_target_tokens = tf.concat(
                [
                    example['targets'][:split_point - 1],
                    # BOT will be the same as _<extra_id_99>. Not ideal, but the tokenizer doesn't have a `bos` token right now.
                    [BOT_ID],
                    example['targets'][split_point - 1:],
                ],
                axis=0
            )
            # This has to be specified, otherwise the dataset tensor spec assigns None to the shape.
            decoder_target_tokens = tf.reshape(decoder_target_tokens, (packed_length,))
        else:
            decoder_target_tokens = example['targets']

        decoder_input_tokens = seqio.utils.make_autoregressive_inputs(decoder_target_tokens)

        if loss_on_targets_only:
            decoder_loss_weights = tf.cast(
                tf.range(packed_length) >= split_point, tf.int32)
        else:
            decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32)

        padding_mask = tf.cast(
            tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32)
        decoder_loss_weights *= padding_mask

        decoder_causal_attention = tf.cast(
            tf.range(packed_length) <= split_point, tf.int32)

        return {
            'decoder_target_tokens': decoder_target_tokens,
            'decoder_input_tokens': decoder_input_tokens,
            'decoder_loss_weights': decoder_loss_weights,
            'decoder_causal_attention': decoder_causal_attention,
        }

    return pack_examples(ds)

TaskRegistry.add(
    "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",
    source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        t5.data.preprocessors.targets_for_prefix_lm_objective,
        pack_prefix_lm_decoder_only,
    ],
    output_features={
        "decoder_target_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_input_tokens": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_loss_weights": FancyFeature(vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        "decoder_causal_attention": FancyFeature(
            vocabulary=get_default_vocabulary(), add_eos=False, add_bot=True),
        # All but the last stage of the preprocessing uses "targets" as the key,
        # so this output feature is necessary. It is not marked required because
        # the final preprocessor drops it.
        "targets": seqio.Feature(vocabulary=get_default_vocabulary(), required=False),
    },
    metric_fns=[])

# --- Improve sharding ---

# def fully_sharded_logical_axis_rules() -> LogicalAxisRules:
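As a sanity check on the behaviour of pack_prefix_lm_decoder_only, here is a small standalone sketch of how it could be exercised on a toy dataset. It is not part of the diff: it assumes the repository root is on PYTHONPATH so that bigscience.gins.task is importable (importing it also loads the default T5 vocabulary), and it uses a PassThroughVocabulary and a toy packed length of 8 purely to satisfy the feature definitions.

# Hypothetical smoke test for pack_prefix_lm_decoder_only (illustrative, not in the PR).
import seqio
import tensorflow as tf

from bigscience.gins.task import BOT_ID, FancyFeature, pack_prefix_lm_decoder_only

PACKED_LENGTH = 8  # must be even; stands in for the real 626

vocab = seqio.PassThroughVocabulary(size=32100)
output_features = {
    key: FancyFeature(vocabulary=vocab, add_eos=False, add_bot=True)
    for key in ("decoder_target_tokens", "decoder_input_tokens",
                "decoder_loss_weights", "decoder_causal_attention")
}
output_features["targets"] = seqio.Feature(vocabulary=vocab, required=False)

sequence_length = {
    "decoder_target_tokens": PACKED_LENGTH,
    "decoder_input_tokens": PACKED_LENGTH,
    "decoder_loss_weights": PACKED_LENGTH,
    "decoder_causal_attention": PACKED_LENGTH,
    "targets": PACKED_LENGTH - 1,  # one slot is reserved for the BOT separator
}

# One fake example whose "targets" already has the pre-BOT length of 7 tokens.
ds = tf.data.Dataset.from_tensors(
    {"targets": tf.constant([11, 12, 13, 14, 15, 16, 17], dtype=tf.int32)})

packed = next(iter(pack_prefix_lm_decoder_only(ds, sequence_length, output_features)))

# The packed targets are exactly PACKED_LENGTH long, contain a single BOT token,
# and the loss mask covers only the (randomly chosen) target portion.
assert packed["decoder_target_tokens"].shape == (PACKED_LENGTH,)
assert int(tf.reduce_sum(tf.cast(packed["decoder_target_tokens"] == BOT_ID, tf.int32))) == 1
assert 0 < int(tf.reduce_sum(packed["decoder_loss_weights"])) < PACKED_LENGTH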
1 change: 1 addition & 0 deletions bigscience/scripts/setup_vm.sh
@@ -35,6 +35,7 @@ popd
#rm -rf t5x
git clone https://github.com/bigscience-workshop/t5x.git
pushd t5x
git checkout thomas/prefix_lm_add_token
pip3 install -e .
popd

23 changes: 17 additions & 6 deletions bigscience/scripts/test_seqio_dataset.py
@@ -1,13 +1,21 @@
from t5x import models
import seqio

from t5x import utils
import tensorflow as tf
from ..gins import task

def main():
    ds = utils.get_dataset(
        utils.DatasetConfig(
            "c4_v220_full_lm",
            task_feature_lengths={"targets": 626},
            "c4_prefix_lm_objective_decoder_architecture_with_bot_seperator",

            task_feature_lengths={
                "decoder_target_tokens": 626,
                "decoder_input_tokens": 626,
                "decoder_segment_ids": 626,
                "decoder_causal_attention": 626,
                "targets": 625,  # take into account the extra BOT token inserted between input and target
            },
            split="train",
            batch_size=2048,
            shuffle=False,
@@ -19,7 +27,7 @@ def main():
        ),
        0,
        1,
        models.DecoderOnlyModel.FEATURE_CONVERTER_CLS
        seqio.PassThroughFeatureConverter,
    )
    first_element = next(iter(ds))
    print(first_element)
@@ -34,8 +42,11 @@ def main():
    print(tf.shape(first_element["decoder_target_tokens"]))
    print(tf.shape(first_element["decoder_input_tokens"]))
    print(tf.shape(first_element["decoder_loss_weights"]))
    print(tf.shape(first_element["decoder_segment_ids"]))
    print(tf.shape(first_element["decoder_positions"]))
    # print(tf.shape(first_element["decoder_segment_ids"]))
    # print(tf.shape(first_element["decoder_positions"]))
    print(tf.where(first_element["decoder_target_tokens"] == 32000))
    print(tf.where(first_element["decoder_input_tokens"] == 32000))
    print(ds.element_spec)

if __name__ == "__main__":
    main()
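
The two tf.where prints above are an eyeball check that the BOT id (32000) shows up once per packed row. If one wanted to turn that into a hard check inside main(), something along these lines would work (illustrative only, not part of the PR):

    # Illustrative: assert exactly one BOT token (id 32000) per row of the batch.
    bot_positions = tf.where(first_element["decoder_target_tokens"] == 32000)
    bot_counts = tf.math.bincount(
        tf.cast(bot_positions[:, 0], tf.int32),  # row index of every BOT hit
        minlength=tf.shape(first_element["decoder_target_tokens"])[0])
    tf.debugging.assert_equal(bot_counts, tf.ones_like(bot_counts))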