Add support causalm finetune #80
Changes from all commits: 70dfa5d, e9d21ff, e067c6d, 7bd7b1a, f539380, e1c8f38, 1724b60, 9a8f877, 0c2df95, 9230f4e, f468973, 4eda03b, ed7bbe6, 4890761, d3d962c, f84a357, 664a3d5
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from typing import Optional

# Third Party
from torch.utils.data import IterableDataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
)
from transformers import AutoConfig, AutoTokenizer
import torch

# First Party
@@ -35,6 +30,11 @@
# Local
from ...data_model import GenerationTrainRecord
from ...resources.pretrained_model import (
    HFAutoCausalLM,
    HFAutoSeq2SeqLM,
    PretrainedModelBase,
)
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
from ...toolkit.data_type_utils import get_torch_dtype
@@ -55,17 +55,26 @@
class FineTuning(ModuleBase):
    """Module to provide fine-tuning support for text generation task"""

    def __init__(self, tokenizer, model):
    RANDOM_SEED = 73
    supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]

    def __init__(
        self,
        tokenizer,
        model,
        bos_token: Optional[str] = None,
        sep_token: Optional[str] = None,
        eos_token: Optional[str] = None,
        pad_token: Optional[str] = None,
    ):
        super().__init__()

        self.tokenizer = tokenizer
        # NOTE: self.model here can also be an HF Trainer. This is because,
        # if we have just trained the model, its weights might be spread
        # across different devices (and configurations), depending on how it
        # was trained. For now (July 10, 2023), we are not trying to extract
        # the model out of the trainer itself, since that would require us to
        # essentially save it or reconstruct it to do normal inference.
        self.model = model
        self._bos_token = bos_token
        self._sep_token = sep_token
        self._eos_token = eos_token
        self._pad_token = pad_token

    @classmethod
    def train(
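The new constructor keywords exist because special tokens differ between causal-LM and seq2seq checkpoints, and some are simply unset; passing them in explicitly lets `run` reason about EOS and padding later. A quick, hedged illustration with two public checkpoints (not part of this PR):

```python
# Special tokens vary by model family; unset ones come back as None.
from transformers import AutoTokenizer

t5_tok = AutoTokenizer.from_pretrained("t5-small")   # seq2seq
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")     # causal LM

print(t5_tok.bos_token, t5_tok.eos_token, t5_tok.pad_token)        # None </s> <pad>
print(gpt2_tok.bos_token, gpt2_tok.eos_token, gpt2_tok.pad_token)  # <|endoftext|> <|endoftext|> None
```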
@@ -78,12 +87,49 @@ def train(
        batch_size: int = 8,
        num_epochs: int = 5,
        accumulate_steps: int = 32,
        random_seed: int = RANDOM_SEED,
        lr: float = 2e-5,
        # Directory where model predictions and checkpoints will be written
        checkpoint_dir: str = "/tmp",
        **training_arguments,
    ):
        """
        # FIXME: Below is currently configured for Seq2Seq only
        Fine-tune a CausalLM or Seq2seq text generation model.

        Args:
            base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
                Base resource model used for underlying generation.
            train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
                Data to be used for fine-tuning the generation model.
            torch_dtype: str
                TODO: Optional[Union[torch.dtype, str]]
                Data type to use for training/inference of the underlying text generation model.
                If no value is provided, we pull from torch_dtype in config. If an in-memory
                resource is provided which does not match the specified data type, the model
                underpinning the resource will be converted in place to the correct torch dtype.
            max_source_length: int
                Max length of input sequences being considered. Default: 256.
            max_target_length: int
                Max length of target sequences being predicted. Default: 128.
            batch_size: int
                Batch size to be used for training / evaluation data. Default: 8.
            num_epochs: int
                Number of epochs to tune the model. Default: 5.
            accumulate_steps: int
                Number of steps to use for gradient accumulation. Default: 32.
            lr: float
                Learning rate to be used while tuning the model. Default: 2e-5.
            checkpoint_dir: str
                Directory where model predictions and checkpoints will be written.
            **training_arguments:
                Arguments supported by HF Training Arguments.
                TrainingArguments:
                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
                Seq2SeqTrainingArguments:
                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
        Returns:
            FineTuning
                Instance of this class with fine-tuned models.
        """

        torch_dtype = get_torch_dtype(torch_dtype)
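For reviewers who want to exercise the new signature end to end, a rough usage sketch; the import paths and the `DataStream.from_iterable` helper are assumptions based on the surrounding caikit APIs, not something this diff adds:

```python
# Hedged usage sketch; module paths and stream construction are assumptions.
from caikit.core.data_model import DataStream
from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.modules.text_generation import FineTuning  # assumed import path

train_stream = DataStream.from_iterable(
    [
        GenerationTrainRecord(input="@foo what a day", output="complaint"),
        GenerationTrainRecord(input="@bar thanks a lot", output="compliment"),
    ]
)

finetuned = FineTuning.train(
    base_model="t5-small",          # any string id resolvable by AutoConfig, or a resource object
    train_stream=train_stream,
    torch_dtype="float32",
    num_epochs=1,
    batch_size=2,
    checkpoint_dir="/tmp/finetune-demo",
)
print(finetuned.run("@foo what a day"))
```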
@@ -92,11 +138,12 @@ def train(
        # text_generation module. In future, we would want to consolidate this into
        # a base class or a toolkit function
        # pylint: disable=duplicate-code
        resource_type = None

        ## Load base model
        if isinstance(base_model, str):
            model_config = AutoConfig.from_pretrained(base_model)

            resource_type = None
            for resource in cls.supported_resources:
                if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                    resource_type = resource
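The resource selection above hinges on `AutoConfig.model_type`; a small standalone sketch of the same lookup, where the supported-type lists are illustrative stand-ins for the real `SUPPORTED_MODEL_TYPES` attributes on `HFAutoCausalLM` / `HFAutoSeq2SeqLM`:

```python
# Standalone sketch of the model_type lookup; the lists below are illustrative stand-ins,
# not the real SUPPORTED_MODEL_TYPES values from the resource classes.
from transformers import AutoConfig

ILLUSTRATIVE_SUPPORTED_TYPES = {
    "causal-lm resource": ["gpt2", "bloom", "gpt_neox"],
    "seq2seq resource": ["t5", "mt5"],
}

def pick_resource(model_name: str) -> str:
    model_config = AutoConfig.from_pretrained(model_name)
    for resource_name, supported_types in ILLUSTRATIVE_SUPPORTED_TYPES.items():
        if model_config.model_type in supported_types:
            return resource_name
    raise ValueError(f"Unsupported model type: {model_config.model_type}")

print(pick_resource("t5-small"))  # seq2seq resource
print(pick_resource("gpt2"))      # causal-lm resource
```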
@@ -112,8 +159,14 @@
            log.debug("Bootstrapping base resource [%s]", base_model)
            base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)

        else:
            # base_model is actually a resource object
            resource_type = type(base_model)

        error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
        ## Generate data loader from stream
        training_dataset: IterableDataset = cls._preprocess_function(
            base_model=base_model,
            train_stream=train_stream,
            tokenizer=base_model.tokenizer,
            max_source_length=max_source_length,
@@ -144,47 +197,33 @@
        # by optionally accepting `training_args`
        # as argument to this train function.
        # TODO: Remove all the defaults used below and make them all configurable
        training_args = Seq2SeqTrainingArguments(
            output_dir=checkpoint_dir,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,

        training_args = {
            "output_dir": checkpoint_dir,
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
            "num_train_epochs": num_epochs,
            "seed": random_seed,
            # NOTE: We have disabled evaluation for now
            do_eval=False,
            # evaluation_strategy="epoch",
            learning_rate=lr,
            weight_decay=0.01,
            save_total_limit=3,
            predict_with_generate=True,
            push_to_hub=False,
            no_cuda=False,  # Default
            generation_max_length=max_target_length,
            remove_unused_columns=False,
            dataloader_pin_memory=False,
            gradient_accumulation_steps=accumulate_steps,
            eval_accumulation_steps=accumulate_steps,
            logging_strategy="epoch",
            disable_tqdm=True,
            # NOTE: Following not possible without save and eval strategy
            # load_best_model_at_end=True,
            "do_eval": False,
            # "evaluation_strategy": "epoch",
            "learning_rate": lr,
            "weight_decay": 0.01,
            "save_total_limit": 3,
            "push_to_hub": False,
            "no_cuda": False,  # Default
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "gradient_accumulation_steps": accumulate_steps,
            "eval_accumulation_steps": accumulate_steps,
            # eval_steps=1,
            # load_best_model_at_end
            **training_arguments,
            **dtype_based_params,

Review comment on `**dtype_based_params`:
Collaborator: Might be a nice good first issue in the future to cleanly make sure there aren't collisions in these expanded dicts, but for now we can leave it.
Author: good idea

            ## TODO: Make below configurable
            # fsdp="full_shard auto_wrap",
            # local_rank=0,
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=base_model.tokenizer, model=base_model.model
        )
        }

        trainer = Seq2SeqTrainer(
            base_model.model,
            training_args,
            train_dataset=training_dataset,
            data_collator=data_collator,
            tokenizer=base_model.tokenizer,
            # compute_metrics=compute_metrics,
        trainer = base_model.get_trainer(
            train_dataset=training_dataset, **training_args
        )

        if num_epochs < 1:
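On the review note above about collisions between the literal keys and the expanded dicts: in a Python dict literal the right-most occurrence of a key silently wins, so an accidental override never raises. A minimal demonstration:

```python
# Later ** expansions silently override earlier keys in a dict literal; no error, no warning.
defaults = {"learning_rate": 2e-5, "do_eval": False}
training_arguments = {"learning_rate": 3e-4}   # caller-supplied override colliding with a default

merged = {**defaults, **training_arguments}
print(merged)  # {'learning_rate': 0.0003, 'do_eval': False}
```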
@@ -201,17 +240,25 @@ def train(

        # Start training via Trainer.train function
        trainer.train()
        # NOTE: By default the model would be available in different ways
        # depending on where and how it was trained. So we need to fetch the model
        # from the trainer depending on the training method, like fsdp, ddp etc.
        # For simplicity, currently we will use trainer as the model since it anyways
        # enable the `predict` function on it and has all the layers of the model
        # distributed already, so it will be most optimized to use trainer to
        # perform prediction at this stage.

        # Save the model temporarily and reload it. This is done because
        # otherwise the model might be distributed across different devices,
        # in which case it would be better to use the trainer's `prediction_step`
        # functions; but those don't always give an API similar to `generate`
        # and thus cause incompatibilities in the `run` function.
        trainer.save_model(checkpoint_dir)

        model = resource_type.bootstrap(
            checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype
        )

        return cls(
            tokenizer=base_model.tokenizer,
            model=trainer,
            tokenizer=model.tokenizer,
            model=model,
            bos_token=model.tokenizer.bos_token or None,
            sep_token=model.tokenizer.sep_token or None,
            eos_token=model.tokenizer.eos_token or None,
            pad_token=model.tokenizer.pad_token or None,
        )

    # pylint: disable=unused-argument
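The save-then-reload step follows the stock Hugging Face round trip for getting a fresh, single-device copy of possibly distributed weights; a self-contained sketch of that round trip with plain transformers classes (for a plain, non-sharded model, `Trainer.save_model` ultimately delegates to `save_pretrained`):

```python
# Sketch of the save -> reload round trip that train() now performs via trainer.save_model
# and resource_type.bootstrap; shown here with plain transformers classes.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint_dir = "/tmp/finetune-demo"
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

model.save_pretrained(checkpoint_dir)
tokenizer.save_pretrained(checkpoint_dir)

reloaded = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)  # fresh, single-device copy
reloaded_tok = AutoTokenizer.from_pretrained(checkpoint_dir)
```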
@@ -236,44 +283,41 @@ def run(
            GeneratedTextResult
                Generated text result
        """
        if isinstance(self.model, Trainer):
            # Apply the tokenizer to the sample text & move to correct device
            tok_tensors = self.tokenizer(text, return_tensors="pt")
            # NOTE: below function is prediction on trainer, for which we need to supply
            # the actual underlying model as well
            # NOTE: We are using prediction_step instead of calling `self.model.generate`
            # because this way HF Trainer automatically handles device placement of the
            # data and model. Since the model is with Trainer at this point
            # and thus the device placement be according to training strategy,
            # its better to let Trainer handle the evaluation / prediction

            # TODO: Add support for passing extra arguments to prediction_step
            _, generated_tokens, _ = self.model.prediction_step(
                self.model.model,
                tok_tensors,
                prediction_loss_only=False,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
            )

            generated_text = self.tokenizer.batch_decode(
                generated_tokens.detach().cpu().numpy(), skip_special_tokens=True
            )[0]
        inputs = self.model.tokenizer(text, return_tensors="pt")
        generate_ids = self.model.model.generate(
            input_ids=inputs["input_ids"],
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            use_cache=True,
        )

        else:
            error(
                "<NLP38929392E>",
                NotImplementedError(
                    "model prediction on pre-finetuned model currently not supported"
                ),
        token_count = generate_ids.size(1) - 1
        preds = [
            self.model.tokenizer.decode(
                g, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for g in generate_ids
        ]
        if generate_ids[0][-1].item() == self._eos_token:
            finish_reason = "EOS_TOKEN"
        elif generate_ids.size(1) - 1 == max_new_tokens:
            finish_reason = "MAX_TOKENS"
        else:
            finish_reason = "OTHER"

        return GeneratedTextResult(generated_text=generated_text)
        return GeneratedTextResult(
            generated_tokens=token_count,
            generated_text=preds[0],
            finish_reason=finish_reason,
            producer_id=self.PRODUCER_ID,
        )
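For reference, a standalone version of the generate / decode / finish-reason flow with plain transformers (the checkpoint name and token budget are placeholders; note that `tokenizer.eos_token_id` is the integer counterpart of the `eos_token` string the module stores):

```python
# Standalone generate -> decode -> finish-reason sketch; "gpt2" is only a placeholder checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

max_new_tokens = 20
inputs = tokenizer("@foo what a day", return_tensors="pt")
generate_ids = model.generate(
    input_ids=inputs["input_ids"], max_new_tokens=max_new_tokens, use_cache=True
)

text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
# eos_token_id is an int, so this is an id-to-id comparison.
if generate_ids[0][-1].item() == tokenizer.eos_token_id:
    finish_reason = "EOS_TOKEN"
elif generate_ids.size(1) - inputs["input_ids"].size(1) == max_new_tokens:
    finish_reason = "MAX_TOKENS"
else:
    finish_reason = "OTHER"
print(finish_reason, text)
```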
    ################################## Private Functions ###########################################

    @staticmethod
    def _preprocess_function(
        base_model: PretrainedModelBase,
        train_stream: DataStream[GenerationTrainRecord],
        tokenizer: AutoTokenizer,
        max_source_length: int,
@@ -282,28 +326,17 @@ def _preprocess_function(
    ):
        """Pre-process each example to get it prepared for training."""

        # FIXME: Below is currently configured for Seq2Seq only

        def _tokenization_func(
            example: GenerationTrainRecord,
        ):
            model_inputs = tokenizer(
                example.input,
                max_length=max_source_length,
                truncation=True,
            )

            labels = tokenizer(
                example.output,
                max_length=max_target_length,
                padding="max_length",
                truncation=True,
            )

            model_inputs["labels"] = labels["input_ids"]

            return model_inputs

        return SimpleIterableStreamWrapper(
            train_stream.map(_tokenization_func), shuffle=shuffle
        # TODO: We are using a default verbalizer which is strictly tied to
        # source training record currently. We need to figure out a better
        # way to make verbalizer optional for build_task_tokenize_function
        (
            tokenize_function,
            requires_unwrapping,
        ) = base_model.build_task_tokenize_function(
            tokenizer, max_source_length, max_target_length, verbalizer="{{input}}"
        )
        mapped_stream = train_stream.map(tokenize_function)

        if requires_unwrapping:
            mapped_stream = mapped_stream.flatten()

        return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
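The preprocessing now delegates tokenization to the resource class and flattens when one record expands into several samples. A rough standalone sketch of the map-then-flatten flow, with plain generators standing in for the caikit `DataStream` API and a toy tokenize function in place of `build_task_tokenize_function`:

```python
# Toy stand-in for the map/flatten flow; not the real build_task_tokenize_function.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
records = [{"input": "@foo what a day", "output": " complaint"}]

def toy_tokenize_function(rec):
    # A causal-LM style tokenize function may emit more than one sample per record,
    # which is what the `requires_unwrapping` flag signals.
    ids = tokenizer(rec["input"] + rec["output"])["input_ids"]
    return [{"input_ids": ids, "labels": ids}]           # list => needs flattening downstream

requires_unwrapping = True
mapped_stream = (toy_tokenize_function(r) for r in records)        # ~ train_stream.map(...)
if requires_unwrapping:
    mapped_stream = (s for batch in mapped_stream for s in batch)  # ~ mapped_stream.flatten()

print(next(iter(mapped_stream)).keys())  # dict_keys(['input_ids', 'labels'])
```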
Review comment:
This is better! Can you link the trainer args in the docstring through?