Add support for causal LM finetune #80
(File 1 of 3 in this diff — apparently the text-generation fine-tuning module; the filename is not shown in this capture.)
```diff
@@ -15,14 +15,7 @@
 # Third Party
 from torch.utils.data import IterableDataset
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
-    Trainer,
-)
+from transformers import AutoConfig, AutoTokenizer, Seq2SeqTrainer, Trainer
 import torch

 # First Party
```
```diff
@@ -35,6 +28,7 @@
 # Local
 from ...data_model import GenerationTrainRecord
+from ...resources.pretrained_model.base import PretrainedModelBase
 from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
 from ...toolkit.data_type_utils import get_torch_dtype
```
```diff
@@ -81,6 +75,7 @@ def train(
         lr: float = 2e-5,
         # Directory where model predictions and checkpoints will be written
         checkpoint_dir: str = "/tmp",
+        **training_arguments,
     ):
         """
         # FIXME: Below is currently configured for Seq2Seq only
```
```diff
@@ -112,8 +107,10 @@ def train(
         log.debug("Bootstrapping base resource [%s]", base_model)
         base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)

+        error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
         ## Generate data loader from stream
         training_dataset: IterableDataset = cls._preprocess_function(
+            base_model=base_model,
             train_stream=train_stream,
             tokenizer=base_model.tokenizer,
             max_source_length=max_source_length,
```
```diff
@@ -144,47 +141,31 @@ def train(
         # by optionally accepting `training_args`
         # as argument to this train function.
         # TODO: Remove all the default used below and make them all configurable
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=checkpoint_dir,
-            per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            num_train_epochs=num_epochs,
+        training_args = {
+            "output_dir": checkpoint_dir,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+            "num_train_epochs": num_epochs,
             # NOTE: We have disabled evaluation for now
-            do_eval=False,
-            # evaluation_strategy = "epoch",
-            learning_rate=lr,
-            weight_decay=0.01,
-            save_total_limit=3,
-            predict_with_generate=True,
-            push_to_hub=False,
-            no_cuda=False,  # Default
-            generation_max_length=max_target_length,
-            remove_unused_columns=False,
-            dataloader_pin_memory=False,
-            gradient_accumulation_steps=accumulate_steps,
-            eval_accumulation_steps=accumulate_steps,
-            logging_strategy="epoch",
-            disable_tqdm=True,
-            # NOTE: Following not possible without save and eval strategy
-            # load_best_model_at_end=True,
-            # eval_steps=1,
+            "do_eval": False,
+            # "evaluation_strategy ": "epoch",
+            "learning_rate": lr,
+            "weight_decay": 0.01,
+            "save_total_limit": 3,
+            "push_to_hub": False,
+            "no_cuda": False,  # Default
+            "remove_unused_columns": False,
+            "dataloader_pin_memory": False,
+            "gradient_accumulation_steps": accumulate_steps,
+            "eval_accumulation_steps": accumulate_steps,
+            **training_arguments,
+            **dtype_based_params,
-            ## TODO: Make below configurable
-            # fsdp="full_shard auto_wrap",
-            # local_rank=0,
-        )
-
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer=base_model.tokenizer, model=base_model.model
-        )
+        }

-        trainer = Seq2SeqTrainer(
-            base_model.model,
-            training_args,
-            train_dataset=training_dataset,
-            data_collator=data_collator,
-            tokenizer=base_model.tokenizer,
-            # compute_metrics=compute_metrics,
+        trainer = base_model.get_trainer(
+            train_dataset=training_dataset, **training_args
         )

         if num_epochs < 1:
```

Review thread on the `**training_arguments` / `**dtype_based_params` expansion:

Collaborator: Might be a nice good first issue in the future to cleanly make sure there aren't collisions in these expanded dicts, but for now we can leave it.

Author (reply): Good idea.
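Since `training_args` is now a plain dict whose later `**` expansions silently override the defaults listed above (the collision the reviewer mentions), a toy illustration of that merge behavior may help. The variable names below are stand-ins for this sketch only, not names from the PR.

```python
# Toy illustration (not project code): in a dict literal, later **expansions
# win over earlier keys, so caller-supplied overrides quietly replace defaults.
defaults = {"learning_rate": 2e-5, "save_total_limit": 3}
caller_overrides = {"learning_rate": 1e-4}   # stands in for **training_arguments
dtype_based = {"fp16": True}                 # stands in for dtype-derived params

merged = {**defaults, **caller_overrides, **dtype_based}
print(merged)  # {'learning_rate': 0.0001, 'save_total_limit': 3, 'fp16': True}
```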
```diff
@@ -247,13 +228,27 @@ def run(
         # and thus the device placement be according to training strategy,
         # its better to let Trainer handle the evaluation / prediction

+        # TODO: Add support for passing extra arguments to prediction_step
+        generate_args = {
+            "prediction_loss_only": False,
+        }
+        if isinstance(self.model, Seq2SeqTrainer):
+            generate_args["max_new_tokens"] = max_new_tokens
+            generate_args["min_new_tokens"] = min_new_tokens
+        else:
+            # NOTE: Currently the default trainer doesn't support easy way to run individual
+            # samples without converting them into Datasets etc. There is a
+            # predict_with_generate flag, but it doesn't do anything.
+            # Applicable for transformers==4.31.0
+            error(
+                "<NLP39984681E>",
+                NotImplementedError(
+                    f"Generation on {type(self.model)} not support \
+                        currently! Please try saving and running this model in TGIS."
+                ),
+            )
+
         _, generated_tokens, _ = self.model.prediction_step(
-            self.model.model,
-            tok_tensors,
-            prediction_loss_only=False,
-            max_new_tokens=max_new_tokens,
-            min_new_tokens=min_new_tokens,
+            self.model.model, tok_tensors, **generate_args
         )

         generated_text = self.tokenizer.batch_decode(
```
```diff
@@ -274,6 +269,7 @@ def run(
     @staticmethod
     def _preprocess_function(
+        base_model: PretrainedModelBase,
         train_stream: DataStream[GenerationTrainRecord],
         tokenizer: AutoTokenizer,
         max_source_length: int,
```
```diff
@@ -282,28 +278,14 @@ def _preprocess_function(
     ):
         """Pre-process each example to get it prepared for training."""

-        # FIXME: Below is currently configured for Seq2Seq only
-
-        def _tokenization_func(
-            example: GenerationTrainRecord,
-        ):
-            model_inputs = tokenizer(
-                example.input,
-                max_length=max_source_length,
-                truncation=True,
-            )
-
-            labels = tokenizer(
-                example.output,
-                max_length=max_target_length,
-                padding="max_length",
-                truncation=True,
-            )
-
-            model_inputs["labels"] = labels["input_ids"]
-
-            return model_inputs
-
-        return SimpleIterableStreamWrapper(
-            train_stream.map(_tokenization_func), shuffle=shuffle
-        )
+        (
+            tokenize_function,
+            requires_unwrapping,
+        ) = base_model.build_task_tokenize_function(
+            tokenizer, max_source_length, max_target_length, verbalizer=""
+        )
+        mapped_stream = train_stream.map(tokenize_function)
+        if requires_unwrapping:
+            mapped_stream = mapped_stream.flatten()
+
+        return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
```
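The `requires_unwrapping` flag signals that the tokenize function may yield several training samples per input record, so the mapped stream has to be flattened before it is wrapped. The snippet below is a toy, stream-free illustration of that map-then-flatten step using plain iterators; the record contents and field names are made up for the example.

```python
# Toy illustration (not project code) of the map-then-flatten pattern: when the
# per-record tokenize function returns a *list* of samples, the mapped stream
# holds lists and must be flattened before iterating sample by sample.
from itertools import chain

records = ["hello world", "one two three"]

def tokenize_to_many(text: str) -> list[dict]:
    # Pretend each whitespace token becomes its own training sample.
    return [{"sample": tok} for tok in text.split()]

mapped = map(tokenize_to_many, records)    # stream of lists of samples
flattened = chain.from_iterable(mapped)    # stream of individual samples
print(list(flattened))                     # 5 samples instead of 2 records
```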
(File 2 of 3 in this diff — apparently the PEFT prompt tuning module; the filename is not shown in this capture.)
```diff
@@ -173,7 +173,6 @@ def __del__(self):
     def run(
         self,
         text: str,
-        device: Optional[Union[str, int]] = _DETECT_DEVICE,
         max_new_tokens=20,
         min_new_tokens=0,
     ) -> GeneratedTextResult:
```
```diff
@@ -182,8 +181,6 @@ def run(
         Args:
             text: str
                 Input string to be used to the generation model.
-            device: Optional[Union[str, int]]
-                Device on which we should run inference; by default, we use the detected device.
             max_new_tokens: int
                 The maximum numbers of tokens to generate.
                 Default: 20
```
```diff
@@ -199,8 +196,8 @@ def run(
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         # Apply the tokenizer to the sample text & move to correct device
         tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt")
-        device = PeftPromptTuning._get_device(device)
-        inputs = {k: v.to(device) for k, v in tok_tensors.items()}
+        inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()}
+
         with torch.no_grad():
             # Run tokenized tensors through the rest of the PEFT model
             outputs = self.model.generate(
```
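Dropping the explicit `device` argument works because a loaded `transformers` model exposes the device its parameters live on via `model.device`, so the inputs can simply follow the model. Below is a minimal standalone sketch of that pattern; the checkpoint name and prompt are placeholders, not values from this PR.

```python
# Minimal sketch (placeholder checkpoint): tensors follow the model's own
# device instead of a separately detected one.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

tok_tensors = tokenizer("hello", return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in tok_tensors.items()}

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```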
```diff
@@ -604,7 +601,12 @@ def save(self, model_path: str, save_base_model: bool = False):
         module_saver.update_config(config_options)

     @classmethod
-    def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning":
+    def load(
+        cls,
+        model_path: str,
+        torch_dtype: str = None,
+        device: str = _DETECT_DEVICE,  # TODO: Union[int, str]
+    ) -> "PeftPromptTuning":
         """Load a PEFT prompt tuning model. This method will currently fail if the original
         model was not saved with the arg value save_base_model=True.
```
```diff
@@ -626,7 +628,7 @@ def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning":
             torch_dtype = str_to_torch_dtype(config.trained_torch_dtype)
         if config.has_base_model:
             # TODO: Implement logic for resource loading
-            device = cls._get_device(cls._DETECT_DEVICE)
+            device = cls._get_device(device)
             model_config = os.path.join(model_path, config.full_model_path)
             peft_config = PeftConfig.from_pretrained(model_config)
             if peft_config.task_type == "CAUSAL_LM":
```
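`load` now threads a caller-supplied `device` through `cls._get_device` instead of always auto-detecting. The helper itself is not part of this diff; the sketch below is only a guess at the kind of fallback such a helper typically performs, with the sentinel name assumed for illustration.

```python
# Hypothetical stand-in for a _get_device-style helper (the real one is not
# shown in this diff): honor an explicit device, otherwise fall back to CUDA
# when available and CPU when not.
from typing import Optional, Union
import torch

_DETECT_DEVICE = "__DETECT__"  # sentinel value, name assumed

def resolve_device(device: Optional[Union[str, int]]) -> Union[str, int]:
    if device == _DETECT_DEVICE or device is None:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device
```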
```diff
@@ -1005,7 +1007,7 @@ def _get_data_loaders_from_stream(
         (
             tokenize_function,
             requires_unwrapping,
         ) = base_model.build_task_tokenize_function(
-            tokenizer, max_source_length, max_target_length, verbalizer
+            tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
         if requires_unwrapping:
```
```diff
@@ -1066,7 +1068,13 @@ def _execute_train_loop(
             num_training_steps=(len(train_dataloader) * num_epochs),
         )
         # Configure accelerator for gradient accumulation
-        accelerator = Accelerator(gradient_accumulation_steps=accumulate_steps)
+        accelerator_args = {
+            "gradient_accumulation_steps": accumulate_steps,
+            "device_placement": True,
+        }
+
+        accelerator = Accelerator(**accelerator_args)
+
         for epoch in range(num_epochs):
             model.train()
             total_loss = 0
```
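With `device_placement=True`, `accelerate` takes over moving the model, optimizer, and batches onto the detected device once they are passed through `accelerator.prepare`, which pairs naturally with the gradient-accumulation setting. Below is a minimal sketch of the usual wiring; the model, optimizer, and dataloader are caller-supplied placeholders, not objects from this PR.

```python
# Minimal sketch of typical accelerate wiring (model/optimizer/dataloader are
# placeholders passed in by the caller).
from accelerate import Accelerator

def train_steps(model, optimizer, train_dataloader, accumulate_steps: int = 4):
    # device_placement=True lets the Accelerator move the prepared objects
    # (and each batch) onto the detected device automatically.
    accelerator = Accelerator(
        gradient_accumulation_steps=accumulate_steps, device_placement=True
    )
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            loss = model(**batch).loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
```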
(File 3 of 3 in this diff — apparently the pretrained model base resource; the filename is not shown in this capture.)
```diff
@@ -14,12 +14,18 @@
 # Standard
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type, Union
 import json
 import os

 # Third Party
-from transformers import AutoTokenizer
+from torch.utils.data import IterableDataset
+from transformers import (
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 import torch
```
```diff
@@ -233,6 +239,57 @@ def save(
         self.tokenizer.save_pretrained(tok_abs_path)
         self.model.save_pretrained(model_abs_path)

+    def get_trainer(
+        self,
+        train_dataset: IterableDataset,
+        eval_dataset: Union[IterableDataset, None] = None,
+        optimizers=(None, None),
+        **kwargs,
+    ):
+        """
+        NOTE: following parameters are not supported currently:
+            1. model_init
+            2. compute_metrics
+            3. callbacks
+            4. preprocess_logits_for_metrics
+        """
+
+        training_args = TrainingArguments(**kwargs)
+
+        data_collator = self._get_data_collator(**kwargs)
+
+        trainer_arguments = {
+            "train_dataset": train_dataset,
+            "data_collator": data_collator,
+            "tokenizer": self._tokenizer,
+            "optimizers": optimizers,
+            "eval_dataset": eval_dataset,
+        }
+
+        return Trainer(self._model, training_args, **trainer_arguments)
+
+    def _get_data_collator(self, **kwargs):
+        """Function to return appropriate data collator based on resource.
+
+        The default implementation of the base resource uses
+        DataCollatorWithPadding which will dynamically pad the inputs received.
+
+        Args:
+            **kwargs:
+                All the keyword arguments passed to this function
+                will get filtered out to appropriate ones that are
+                applicable to implemented data collator.
+        Returns:
+            transformers.DataCollator
+        """
+
+        applicable_args = ["max_length", "pad_to_multiple_of"]
+        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+        return DataCollatorWithPadding(
+            tokenizer=self._tokenizer, padding=True, **collator_kwargs
+        )
+
     # pylint: disable=unused-argument
     @classmethod
     def get_num_transformers_submodules(
```

Review comment on the `get_trainer` docstring (Collaborator): Same questions about documenting the kwargs here in the docstring (at least the non-expanded ones). I assume the other one probably needs it also.
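Since the point of the PR is causal LM fine-tuning, a resource subclass would presumably override `_get_data_collator` so that language-modeling labels are produced instead of relying on plain padding. The sketch below shows one plausible override using `DataCollatorForLanguageModeling`; it is an assumption for illustration, not the code actually added in this PR, and the class and checkpoint names are made up.

```python
# Hypothetical override (not part of this diff): a causal-LM resource could
# swap the default padding collator for a language-modeling collator.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

class CausalLMResourceSketch:
    """Toy stand-in for a causal-LM resource; only the collator hook is shown."""

    def __init__(self, tokenizer_name: str = "gpt2"):  # placeholder checkpoint
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

    def _get_data_collator(self, **kwargs):
        applicable_args = ["pad_to_multiple_of"]
        collator_kwargs = {k: kwargs[k] for k in applicable_args if k in kwargs}
        # mlm=False => causal language modeling labels (the model shifts them)
        return DataCollatorForLanguageModeling(
            tokenizer=self._tokenizer, mlm=False, **collator_kwargs
        )
```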
```diff
@@ -249,6 +306,7 @@ def build_task_tokenize_function(
         max_source_length: int,
         max_target_length: int,
         verbalizer: str,
+        task_ids: Union[None, int] = None,
     ) -> Tuple[Callable, bool]:
         """Builds tokenizer functions which can be mapped over train streams to process
         data which can then be easily passed to a DataLoader for different model types.
```
```diff
@@ -263,6 +321,10 @@ def build_task_tokenize_function(
             verbalizer: str
                 Verbalizer template to be used for formatting data. This template may use brackets
                 to indicate where fields from the data model TrainGenerationRecord must be rendered.
+            task_ids: Union[None, int]
+                Task id corresponding particular task for multi-task prompt tuning.
+                NOTE: Only required for MPT (Multi-task prompt tuning)
+                Default: None

         Returns:
             Tuple(Callable, bool)
```
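The new `task_ids` argument only matters for multi-task prompt tuning, where each tokenized sample has to carry the id of the task it belongs to. The following is a rough, hypothetical sketch of what a tokenize function built with `task_ids` might attach to each sample; the `"task_ids"` field name and overall structure are assumptions for illustration, not taken from this diff.

```python
# Hypothetical sketch: attach a task id to every tokenized sample so a
# multi-task prompt-tuning model can route it.
from typing import Optional
from transformers import AutoTokenizer

def build_tokenize_fn(tokenizer, max_source_length: int, task_ids: Optional[int] = None):
    def tokenize(text: str) -> dict:
        sample = tokenizer(text, max_length=max_source_length, truncation=True)
        if task_ids is not None:
            sample["task_ids"] = task_ids  # field name assumed for this sketch
        return dict(sample)
    return tokenize

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
fn = build_tokenize_fn(tok, max_source_length=64, task_ids=0)
print(fn("hello world").keys())
```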
Review comment (Collaborator): This is better! Can you link the trainer args in the docstring, though?