265 changes: 149 additions & 116 deletions caikit_nlp/modules/text_generation/fine_tuning.py
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from typing import Optional

# Third Party
from torch.utils.data import IterableDataset
from transformers import (
AutoConfig,
AutoTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
Trainer,
)
from transformers import AutoConfig, AutoTokenizer
import torch

# First Party
@@ -35,6 +30,11 @@

# Local
from ...data_model import GenerationTrainRecord
from ...resources.pretrained_model import (
HFAutoCausalLM,
HFAutoSeq2SeqLM,
PretrainedModelBase,
)
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
from ...toolkit.data_type_utils import get_torch_dtype

@@ -55,17 +55,26 @@
class FineTuning(ModuleBase):
"""Module to provide fine-tuning support for text generation task"""

def __init__(self, tokenizer, model):
RANDOM_SEED = 73
supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]

def __init__(
self,
tokenizer,
model,
bos_token: Optional[str] = None,
sep_token: Optional[str] = None,
eos_token: Optional[str] = None,
pad_token: Optional[str] = None,
):
super().__init__()

self.tokenizer = tokenizer
# NOTE: self.model here can also be HF trainer. This is because
# if we have just trained the model then the models weights might be
# available in different devices (and configuration), depending on
# how it was trained. For now (July 10, 2023), we are not trying to
# extract the model out from trainer itself, since that would require
# us to essentially save it or reconstruct it to do normal inferring.
self.model = model
self._bos_token = bos_token
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token

@classmethod
def train(
@@ -78,12 +87,49 @@ def train(
batch_size: int = 8,
num_epochs: int = 5,
accumulate_steps: int = 32,
random_seed: int = RANDOM_SEED,
lr: float = 2e-5,
# Directory where model predictions and checkpoints will be written
checkpoint_dir: str = "/tmp",
**training_arguments,
Collaborator: This is better! Can you link the trainer args in the docstring through?
):
"""
# FIXME: Below is currently configured for Seq2Seq only
Fine-tune a CausalLM or Seq2seq text generation model.

Args:
base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
Base resource model used for underlying generation.
train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
Data to be used for fine-tuning the generation model.
torch_dtype: str
TODO: Optional[Union[torch.dtype, str]]
Data type to use for training/inference of the underlying text generation model.
If no value is provided, we pull from torch_dtype in config. If an in memory
resource is provided which does not match the specified data type, the model
underpinning the resource will be converted in place to the correct torch dtype.
max_source_length: int
Max length of input sequences being considered. Default: 256.
max_target_length: int
Max length of target sequences being predicted. Default: 128.
batch_size: int
Batch size to be used for training / evaluation data. Default: 8.
num_epochs: int
Number of epochs to tune the model. Default: 5.
accumulate_steps: int
Number of steps to use for gradient accumulation. Default: 32.
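random_seed: int
Random seed to use for reproducible training runs. Default: RANDOM_SEED (73).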
lr: float
Learning rate to be used while tuning model. Default: 2e-5.
checkpoint_dir: str
Directory where model predictions and checkpoints will be written
**training_arguments:
Arguments supported by HF Training Arguments.
TrainingArguments:
https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
Seq2SeqTrainingArguments:
https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
Returns:
FineTuning
Instance of this class with fine-tuned models.
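Example:
    # Illustrative sketch only (not part of this change): the model name is
    # a placeholder, and DataStream.from_iterable is assumed to be the
    # available caikit stream constructor.
    train_stream = DataStream.from_iterable(
        [GenerationTrainRecord(input="2 + 2 = ", output="4")]
    )
    finetuned = FineTuning.train(
        base_model="google/flan-t5-small",
        train_stream=train_stream,
        num_epochs=1,
        batch_size=4,
    )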
"""

torch_dtype = get_torch_dtype(torch_dtype)
@@ -92,11 +138,12 @@ def train(
# text_generation module. In future, we would want to consolidate this into
# a base class or a toolkit function
# pylint: disable=duplicate-code
resource_type = None

## Load base model
if isinstance(base_model, str):
model_config = AutoConfig.from_pretrained(base_model)

resource_type = None
for resource in cls.supported_resources:
if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
resource_type = resource
@@ -112,8 +159,14 @@ def train(
log.debug("Bootstrapping base resource [%s]", base_model)
base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)

else:
# base_model is actually a resource object
resource_type = type(base_model)

error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
## Generate data loader from stream
training_dataset: IterableDataset = cls._preprocess_function(
base_model=base_model,
train_stream=train_stream,
tokenizer=base_model.tokenizer,
max_source_length=max_source_length,
@@ -144,47 +197,33 @@
# by optionally accepting `training_args`
# as argument to this train function.
# TODO: Remove all the default used below and make them all configurable
training_args = Seq2SeqTrainingArguments(
output_dir=checkpoint_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=num_epochs,

training_args = {
"output_dir": checkpoint_dir,
"per_device_train_batch_size": batch_size,
"per_device_eval_batch_size": batch_size,
"num_train_epochs": num_epochs,
"seed": random_seed,
# NOTE: We have disabled evaluation for now
do_eval=False,
# evaluation_strategy = "epoch",
learning_rate=lr,
weight_decay=0.01,
save_total_limit=3,
predict_with_generate=True,
push_to_hub=False,
no_cuda=False, # Default
generation_max_length=max_target_length,
remove_unused_columns=False,
dataloader_pin_memory=False,
gradient_accumulation_steps=accumulate_steps,
eval_accumulation_steps=accumulate_steps,
logging_strategy="epoch",
disable_tqdm=True,
# NOTE: Following not possible without save and eval strategy
# load_best_model_at_end=True,
"do_eval": False,
# "evaluation_strategy ": "epoch",
"learning_rate": lr,
"weight_decay": 0.01,
"save_total_limit": 3,
"push_to_hub": False,
"no_cuda": False, # Default
"remove_unused_columns": False,
"dataloader_pin_memory": False,
"gradient_accumulation_steps": accumulate_steps,
"eval_accumulation_steps": accumulate_steps,
# eval_steps=1,
# load_best_model_at_end
**training_arguments,
**dtype_based_params,
Collaborator: Might be a nice good first issue in the future to cleanly make sure there aren't collisions in these expanded dicts, but for now we can leave it

Collaborator (Author): good idea
## TODO: Make below configurable
# fsdp="full_shard auto_wrap",
# local_rank=0,
)

data_collator = DataCollatorForSeq2Seq(
tokenizer=base_model.tokenizer, model=base_model.model
)
}
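A minimal sketch of the duplicate-key check suggested in the review comment above, assuming plain dicts; the helper name and its placement are illustrative and not part of this change:

def _check_kwarg_collisions(defaults: dict, *overrides: dict) -> None:
    # Illustrative helper: surface keys that appear both in the hard-coded
    # defaults above and in the expanded **training_arguments /
    # **dtype_based_params dicts, instead of silently overriding them.
    seen = set(defaults)
    for extra in overrides:
        clashes = seen & set(extra)
        if clashes:
            raise ValueError(f"Conflicting training arguments: {sorted(clashes)}")
        seen |= set(extra)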

trainer = Seq2SeqTrainer(
base_model.model,
training_args,
train_dataset=training_dataset,
data_collator=data_collator,
tokenizer=base_model.tokenizer,
# compute_metrics=compute_metrics,
trainer = base_model.get_trainer(
train_dataset=training_dataset, **training_args
)

if num_epochs < 1:
@@ -201,17 +240,25 @@ def train(

# Start training via Trainer.train function
trainer.train()
# NOTE: By default the model would be available in different ways
# depending on where and how it was trained. So we need to fetch the model
# from the trainer depending on the training method, like fsdp, ddp etc.
# For simplicity, currently we will use trainer as the model since it anyways
# enable the `predict` function on it and has all the layers of the model
# distributed already, so it will be most optimized to use trainer to
# perform prediction at this stage.

# Save the model temporarily and reload it. This is done because otherwise
# the model might be distributed across different devices, in which case it
# is better to use the trainer's `prediction_step` functions; however, those
# don't always expose an API similar to `generate`, which causes
# incompatibilities in the `run` function.
trainer.save_model(checkpoint_dir)

model = resource_type.bootstrap(
checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype
)

return cls(
tokenizer=base_model.tokenizer,
model=trainer,
tokenizer=model.tokenizer,
model=model,
bos_token=model.tokenizer.bos_token or None,
sep_token=model.tokenizer.sep_token or None,
eos_token=model.tokenizer.eos_token or None,
pad_token=model.tokenizer.pad_token or None,
)

# pylint: disable=unused-argument
@@ -236,44 +283,41 @@ def run(
GeneratedTextResult
Generated text result
"""
if isinstance(self.model, Trainer):
# Apply the tokenizer to the sample text & move to correct device
tok_tensors = self.tokenizer(text, return_tensors="pt")
# NOTE: below function is prediction on trainer, for which we need to supply
# the actual underlying model as well
# NOTE: We are using prediction_step instead of calling `self.model.generate`
# because this way HF Trainer automatically handles device placement of the
# data and model. Since the model is with Trainer at this point
# and thus the device placement be according to training strategy,
# its better to let Trainer handle the evaluation / prediction

# TODO: Add support for passing extra arguments to prediction_step
_, generated_tokens, _ = self.model.prediction_step(
self.model.model,
tok_tensors,
prediction_loss_only=False,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
)

generated_text = self.tokenizer.batch_decode(
generated_tokens.detach().cpu().numpy(), skip_special_tokens=True
)[0]
inputs = self.model.tokenizer(text, return_tensors="pt")
generate_ids = self.model.model.generate(
input_ids=inputs["input_ids"],
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
use_cache=True,
)

else:
error(
"<NLP38929392E>",
NotImplementedError(
"model prediction on pre-finetuned model currently not supported"
),
token_count = generate_ids.size(1) - 1
preds = [
self.model.tokenizer.decode(
g, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for g in generate_ids
]
if self.model.tokenizer.decode(generate_ids[0][-1].item()) == self._eos_token:
finish_reason = "EOS_TOKEN"
elif generate_ids.size(1) - 1 == max_new_tokens:
finish_reason = "MAX_TOKENS"
else:
finish_reason = "OTHER"

return GeneratedTextResult(generated_text=generated_text)
return GeneratedTextResult(
generated_tokens=token_count,
generated_text=preds[0],
finish_reason=finish_reason,
producer_id=self.PRODUCER_ID,
)
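For reference, a rough usage sketch of the reworked generation path, assuming a trained instance (here called finetuned) and that the keyword names used in the body above are exposed as run parameters; the prompt and limits are placeholders:

result = finetuned.run("Hello, how are", max_new_tokens=20, min_new_tokens=0)
print(result.generated_text, result.generated_tokens, result.finish_reason)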

################################## Private Functions ###########################################

@staticmethod
def _preprocess_function(
base_model: PretrainedModelBase,
train_stream: DataStream[GenerationTrainRecord],
tokenizer: AutoTokenizer,
max_source_length: int,
@@ -282,28 +326,17 @@ def _preprocess_function(
):
"""Pre-process each example to get it prepared for training."""

# FIXME: Below is currently configured for Seq2Seq only

def _tokenization_func(
example: GenerationTrainRecord,
):
model_inputs = tokenizer(
example.input,
max_length=max_source_length,
truncation=True,
)

labels = tokenizer(
example.output,
max_length=max_target_length,
padding="max_length",
truncation=True,
)

model_inputs["labels"] = labels["input_ids"]

return model_inputs

return SimpleIterableStreamWrapper(
train_stream.map(_tokenization_func), shuffle=shuffle
# TODO: We are using a default verbalizer which is currently tied strictly
# to the source training record. We need to figure out a better way to make
# the verbalizer optional for build_task_tokenize_function
(
tokenize_function,
requires_unwrapping,
) = base_model.build_task_tokenize_function(
tokenizer, max_source_length, max_target_length, verbalizer="{{input}}"
)
mapped_stream = train_stream.map(tokenize_function)
if requires_unwrapping:
mapped_stream = mapped_stream.flatten()

return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
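To make the requires_unwrapping branch above concrete, here is a plain-Python analogue of map-then-flatten; the toy tokenize function is hypothetical and does not mirror the real build_task_tokenize_function:

from itertools import chain

def toy_tokenize(record: str) -> list:
    # Hypothetical stand-in for a tokenize function that expands one training
    # record into several samples (the case where requires_unwrapping is True).
    return [{"text": record, "shift": i} for i in range(3)]

records = ["hello world", "lorem ipsum"]
mapped = [toy_tokenize(r) for r in records]       # analogue of train_stream.map(...)
flattened = list(chain.from_iterable(mapped))     # analogue of mapped_stream.flatten()
assert len(flattened) == 6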