From 8961efefcd68f0b41e4c699768ce19c7b64eec97 Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Wed, 6 Dec 2023 19:35:10 -0700
Subject: [PATCH 1/5] add DataCollatorForCompletionOnlyLM

Signed-off-by: Sukriti-Sharma4
---
 .../text_generation/peft_prompt_tuning.py      |  7 ++++++
 caikit_nlp/resources/pretrained_model/base.py  | 25 ++++++++++++-------
 .../toolkit/text_generation/training_utils.py  |  4 ++-
 3 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 435e439d..e7414d4c 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -347,6 +347,11 @@ def train(
                 Silences TQDM progress bars at train time. Default: True.
             seed: int
                 Integer to be used as random seed for training.
+            train_on_completion: bool
+                True will train the model on the completion (response) only. Default: False.
+            response_template: Optional[str] = None
+                Only used if train_on_completion is set to True; a response template
+                that will be used to locate the response.
         Returns:
             PeftPromptTuning
                 Instance of this class with tuned prompt vectors.
@@ -505,6 +510,8 @@ def train(
             training_args,
             checkpoint_dir,
             base_model,
+            train_on_completion,
+            response_template,
         )

         # Wrap up the trained model in a class instance
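For reviewers unfamiliar with trl's DataCollatorForCompletionOnlyLM: it builds ordinary causal-LM batches but sets the label of every token up to and including the response template to -100, so the loss is computed on the completion alone. A minimal sketch of that behavior (tokenizer, prompt, and template here are illustrative, not taken from this series):

    # Illustration of completion-only masking with trl; not part of the patch.
    from transformers import AutoTokenizer
    from trl import DataCollatorForCompletionOnlyLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    collator = DataCollatorForCompletionOnlyLM("### Answer:", tokenizer=tokenizer)

    example = tokenizer("### Question: what is 2+2?\n### Answer: 4")
    batch = collator([example])
    # Labels up to and including "### Answer:" are -100; only " 4" is learned.
    print(batch["labels"])

One caveat to keep in mind for the template chosen later in the series: the template must tokenize to the same ids inside a full example as it does on its own, or trl will fail to find it in the batch.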
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index 6adad903..55acb934 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -21,6 +21,7 @@

 # Third Party
 from torch.utils.data import IterableDataset
+from trl import DataCollatorForCompletionOnlyLM
 from transformers import (
     AutoTokenizer,
     DataCollatorWithPadding,
@@ -280,6 +281,8 @@ def get_trainer(
         eval_dataset: Union[IterableDataset, None] = None,
         optimizers=(None, None),
         model=None,
+        train_on_completion=False,
+        response_template=None,
         **kwargs,
     ):
         """
@@ -296,7 +299,7 @@ def get_trainer(

         training_args = TrainingArguments(**kwargs)

-        data_collator = self._get_data_collator(**kwargs)
+        data_collator = self._get_data_collator(train_on_completion, response_template, **kwargs)

         trainer_arguments = {
             "train_dataset": train_dataset,
@@ -311,13 +314,15 @@ def get_trainer(

         return LoggingTrainer(self._model, training_args, **trainer_arguments)

-    def _get_data_collator(self, **kwargs):
+    def _get_data_collator(self, train_on_completion, response_template, **kwargs):
         """Function to return appropriate data collator based on resource.

         The default implementation of the base resource uses
         DataCollatorWithPadding which will dynamically pad the inputs received.

         Args:
+            train_on_completion: bool,
+            response_template: str,
             **kwargs:
                 All the keyword arguments passed to this function
                 will get filtered out to appropriate ones that are
@@ -325,13 +330,15 @@ def _get_data_collator(self, **kwargs):
         Returns:
             transformers.DataCollator
         """
-
-        applicable_args = ["max_length", "pad_to_multiple_of"]
-        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
-
-        return DataCollatorWithPadding(
-            tokenizer=self._tokenizer, padding=True, **collator_kwargs
-        )
+        if train_on_completion:
+            return DataCollatorForCompletionOnlyLM(response_template, tokenizer=self._tokenizer)
+        else:
+            applicable_args = ["max_length", "pad_to_multiple_of"]
+            collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+            return DataCollatorWithPadding(
+                tokenizer=self._tokenizer, padding=True, **collator_kwargs
+            )

     # pylint: disable=unused-argument
     @classmethod

diff --git a/caikit_nlp/toolkit/text_generation/training_utils.py b/caikit_nlp/toolkit/text_generation/training_utils.py
index ac905fad..7657f145 100644
--- a/caikit_nlp/toolkit/text_generation/training_utils.py
+++ b/caikit_nlp/toolkit/text_generation/training_utils.py
@@ -174,12 +174,14 @@ def launch_training(
     checkpoint_dir,
     caikit_resource=None,
     tokenizer=None,
+    train_on_completion=False,
+    response_template=None
 ) -> None:
     """Utility function to wrap trainer and execute training"""
     # If we have a caikit resource, grab the trainer through it
     if caikit_resource is not None:
         trainer = caikit_resource.get_trainer(
-            train_dataset=training_dataset, model=base_model, **training_args
+            train_dataset=training_dataset, model=base_model, train_on_completion=train_on_completion, response_template=response_template, **training_args
         )
     else:
         # If trainer is not provided fetch it from base_model
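With patch 1 applied, a caller threads the two new flags through launch_training. A hypothetical invocation (argument names are inferred from the call sites above; values are placeholders, and base_model, training_dataset, and caikit_resource are assumed to be prepared as elsewhere in this toolkit):

    # Hypothetical caller of the extended launch_training; values illustrative.
    launch_training(
        base_model=base_model,
        training_dataset=training_dataset,
        training_args={"output_dir": "/tmp/ckpt", "num_train_epochs": 1},
        checkpoint_dir="/tmp/ckpt",
        caikit_resource=caikit_resource,
        train_on_completion=True,
        response_template="### Answer:",
    )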
response_template, **kwargs): transformers.DataCollator """ if train_on_completion: + if response_template==None: + error("", "Response Template needs to be set to use completion only") return DataCollatorForCompletionOnlyLM(response_template, tokenizer=self._tokenizer) else: applicable_args = ["max_length", "pad_to_multiple_of"] diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py index 305bd779..1668a218 100644 --- a/examples/run_peft_tuning.py +++ b/examples/run_peft_tuning.py @@ -241,6 +241,24 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> Non default="float32", choices=["float16", "bfloat16", "float32"], ) + subparser.add_argument( + "--train_on_completion", + help="Train on completion True or False", + default=False, + choices=[True,False], + ) + subparser.add_argument( + "--train_on_completion", + help="Train on completion True or False", + default=False, + choices=[True,False], + ) + subparser.add_argument( + "--response_template", + help="Response template to identify response", + default=None + ) + def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser): @@ -414,6 +432,8 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None: silence_progress_bars=not args.verbose, accumulate_steps=args.accumulate_steps, torch_dtype=args.torch_dtype, + train_on_completion=args.train_on_completion, + response_template=args.response_template ) model.save(args.output_dir, save_base_model=not args.prompt_only) print_colored("[Training Complete]") diff --git a/pyproject.toml b/pyproject.toml index 7df014ce..038b71b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "torch>=2.0.1", "tqdm>=4.65.0", "transformers>=4.32.0", + "trl>=0.7.2", # GK-AUG-25-2023 NOTE: mpt branch on Mayank's fork was merged to peft main on Aug 24 and it got deleted # which broke caikit-nlp build. peft hasn't released newer version yet, so to get # the build fix, we pulling peft from main branch commit. 
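Note that this patch registers --train_on_completion twice, which argparse rejects at parser-construction time; that is what patch 3 below fixes. A minimal reproduction of the failure:

    # Why patch 3 exists: argparse refuses duplicate option strings.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_on_completion", default=False)
    parser.add_argument("--train_on_completion", default=False)
    # argparse.ArgumentError: argument --train_on_completion:
    #     conflicting option string: --train_on_completion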
 # the build fix, we pulling peft from main branch commit. In future, we will pull PEFT from

From 318e9cb21add69a063125348144d934030184cf5 Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Wed, 6 Dec 2023 21:03:01 -0700
Subject: [PATCH 3/5] fix conflicting argument

Signed-off-by: Sukriti-Sharma4
---
 examples/run_peft_tuning.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py
index 1668a218..b3905825 100644
--- a/examples/run_peft_tuning.py
+++ b/examples/run_peft_tuning.py
@@ -247,12 +247,6 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> Non
         default=False,
         choices=[True,False],
     )
-    subparser.add_argument(
-        "--train_on_completion",
-        help="Train on completion True or False",
-        default=False,
-        choices=[True,False],
-    )
     subparser.add_argument(
         "--response_template",
         help="Response template to identify response",

From ea6a7dca33c1a35972ec434bef16293541ba0847 Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Wed, 6 Dec 2023 22:53:51 -0700
Subject: [PATCH 4/5] update tests for data collator

Signed-off-by: Sukriti-Sharma4
---
 .../text_generation/peft_prompt_tuning.py      |  7 ++--
 caikit_nlp/resources/pretrained_model/base.py  | 36 ++++++++++---------
 .../toolkit/text_generation/training_utils.py  |  8 +++--
 .../test_peft_prompt_tuning.py                 | 31 ++++++++++++++++
 4 files changed, 61 insertions(+), 21 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 866b2b0a..abcec02a 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -365,7 +365,8 @@ def train(
         if train_on_completion:
             if not response_template:
                 error.value_check(
-                    "", "Response template is needed for train on completion"
+                    "",
+                    "Response template is needed for train on completion",
                 )

         # Configure random seed
@@ -518,8 +519,8 @@ def train(
             training_args,
             checkpoint_dir,
             base_model,
-            train_on_completion,
-            response_template,
+            train_on_completion=train_on_completion,
+            response_template=response_template,
         )

         # Wrap up the trained model in a class instance
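An aside on the guard reformatted above: elsewhere in train() the handler is invoked as error.value_check(log_code, condition, msg) (see the train_stream check in patch 2), while this call passes only a message, so the truthy message stands in for the condition and the check never fires. A condition-bearing sketch of the intended guard, with the log code left elided as it is throughout the series:

    # Sketch: value_check with an explicit condition, mirroring the
    # train_stream guard earlier in train().
    if train_on_completion:
        error.value_check(
            "",  # log code elided
            response_template is not None,
            "Response template is needed for train on completion",
        )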
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index badede0e..48046ad8 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -21,7 +21,6 @@

 # Third Party
 from torch.utils.data import IterableDataset
-from trl import DataCollatorForCompletionOnlyLM
 from transformers import (
     AutoTokenizer,
     DataCollatorWithPadding,
@@ -30,6 +29,7 @@
     TrainingArguments,
 )
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from trl import DataCollatorForCompletionOnlyLM
 import torch

 # First Party
@@ -299,7 +299,17 @@ def get_trainer(

         training_args = TrainingArguments(**kwargs)

-        data_collator = self._get_data_collator(train_on_completion, response_template, **kwargs)
+        if train_on_completion:
+            if response_template is None:
+                error(
+                    "",
+                    "Response Template needs to be set to use completion only",
+                )
+            data_collator = DataCollatorForCompletionOnlyLM(
+                response_template, tokenizer=self._tokenizer
+            )
+        else:
+            data_collator = self._get_data_collator(**kwargs)

         trainer_arguments = {
             "train_dataset": train_dataset,
@@ -314,15 +324,13 @@ def get_trainer(

         return LoggingTrainer(self._model, training_args, **trainer_arguments)

-    def _get_data_collator(self, train_on_completion, response_template, **kwargs):
+    def _get_data_collator(self, **kwargs):
         """Function to return appropriate data collator based on resource.

         The default implementation of the base resource uses
         DataCollatorWithPadding which will dynamically pad the inputs received.

         Args:
-            train_on_completion: bool,
-            response_template: str,
             **kwargs:
                 All the keyword arguments passed to this function
                 will get filtered out to appropriate ones that are
@@ -330,17 +338,13 @@ def _get_data_collator(self, **kwargs):
         Returns:
             transformers.DataCollator
         """
-        if train_on_completion:
-            if response_template==None:
-                error("", "Response Template needs to be set to use completion only")
-            return DataCollatorForCompletionOnlyLM(response_template, tokenizer=self._tokenizer)
-        else:
-            applicable_args = ["max_length", "pad_to_multiple_of"]
-            collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
-
-            return DataCollatorWithPadding(
-                tokenizer=self._tokenizer, padding=True, **collator_kwargs
-            )
+
+        applicable_args = ["max_length", "pad_to_multiple_of"]
+        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+        return DataCollatorWithPadding(
+            tokenizer=self._tokenizer, padding=True, **collator_kwargs
+        )

     # pylint: disable=unused-argument
     @classmethod

diff --git a/caikit_nlp/toolkit/text_generation/training_utils.py b/caikit_nlp/toolkit/text_generation/training_utils.py
index 7657f145..c69892a0 100644
--- a/caikit_nlp/toolkit/text_generation/training_utils.py
+++ b/caikit_nlp/toolkit/text_generation/training_utils.py
@@ -175,13 +175,17 @@ def launch_training(
     caikit_resource=None,
     tokenizer=None,
     train_on_completion=False,
-    response_template=None
+    response_template=None,
 ) -> None:
     """Utility function to wrap trainer and execute training"""
     # If we have a caikit resource, grab the trainer through it
     if caikit_resource is not None:
         trainer = caikit_resource.get_trainer(
-            train_dataset=training_dataset, model=base_model, train_on_completion=train_on_completion, response_template=response_template, **training_args
+            train_dataset=training_dataset,
+            model=base_model,
+            train_on_completion=train_on_completion,
+            response_template=response_template,
+            **training_args
         )
     else:
         # If trainer is not provided fetch it from base_model

diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py
index 5cf82439..1ae58638 100644
--- a/tests/modules/text_generation/test_peft_prompt_tuning.py
+++ b/tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -162,6 +162,37 @@ def test_train_model(causal_lm_train_kwargs, set_cpu_device):
     assert isinstance(pred, GeneratedTextResult)


+def test_train_model_on_completion(causal_lm_train_kwargs, set_cpu_device):
+    """Ensure that we can train a model on completion only on toy data for 1+ steps & run inference."""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@foo what a cute dog!", output="no complaint"
+                ),
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@bar this is the worst idea ever.", output="complaint"
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+        "train_on_completion": True,
+        "response_template": "#answer:",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+    model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(
+        **causal_lm_train_kwargs
+    )
+    # Test fallback to float32 behavior if this machine doesn't support bfloat16
+    assert model.model.dtype is torch.float32
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
 def test_gen_trained_mpt(causal_lm_train_kwargs, set_cpu_device):
     """Ensure that we are able to do generation on causal-lm model trained
     using MPT."""
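The new test can be exercised on its own with a standard pytest node id (invocation assumed; not part of the series):

    pytest tests/modules/text_generation/test_peft_prompt_tuning.py::test_train_model_on_completion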
From 5f77a410a2d4e8c23c942e4b23c7c9679544ce7d Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Thu, 7 Dec 2023 12:41:24 -0700
Subject: [PATCH 5/5] enable data collator completion LM

Signed-off-by: Sukriti-Sharma4
---
 .../modules/text_generation/peft_prompt_tuning.py    |  4 +++-
 caikit_nlp/resources/pretrained_model/base.py         | 10 ++++------
 .../resources/pretrained_model/hf_auto_causal_lm.py   | 13 +++++++++++++
 examples/run_peft_tuning.py                           |  8 ++++----
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index abcec02a..712d3945 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -350,7 +350,9 @@ def train(
         seed: int
             Integer to be used as random seed for training.
         train_on_completion: bool
-            True will train the model on the completion (response) only. Default: False.
+            True will train the model on the completion (response) only.
+            Only applicable to Causal LMs.
+            Default: False.
         response_template: Optional[str] = None
             Only used if train_on_completion is set to True; a response template
             that will be used to locate the response.

diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index 48046ad8..56b76d32 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -29,7 +29,6 @@
     TrainingArguments,
 )
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
-from trl import DataCollatorForCompletionOnlyLM
 import torch

 # First Party
@@ -305,11 +304,10 @@ def get_trainer(
                 "",
                 "Response Template needs to be set to use completion only",
             )
-            data_collator = DataCollatorForCompletionOnlyLM(
-                response_template, tokenizer=self._tokenizer
-            )
-        else:
-            data_collator = self._get_data_collator(**kwargs)
+        kwargs["train_on_completion"] = train_on_completion
+        kwargs["response_template"] = response_template
+
+        data_collator = self._get_data_collator(**kwargs)

         trainer_arguments = {
             "train_dataset": train_dataset,
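The hunk that follows moves collator choice into the causal-LM resource. Condensed, the resource now picks between the two collators roughly like this (a paraphrase of the diff below, not literal code from it):

    # Paraphrase of the selection added to the causal-LM _get_data_collator.
    if kwargs.get("train_on_completion"):
        collator = DataCollatorForCompletionOnlyLM(
            response_template=kwargs["response_template"],
            tokenizer=self._tokenizer,
            mlm=False,            # causal LM, not masked LM
            return_tensors="pt",
        )
    else:
        # pre-existing default path for causal language modeling
        collator = DataCollatorForLanguageModeling(
            tokenizer=self._tokenizer, mlm=False, return_tensors="pt"
        )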
""" + if "train_on_completion" in kwargs and kwargs["train_on_completion"]: + applicable_args = ["mlm", "response_template", "instruction_template"] + collator_kwargs = { + key: kwargs[key] for key in applicable_args if key in kwargs + } + + if "mlm" not in collator_kwargs: + collator_kwargs["mlm"] = False + + return DataCollatorForCompletionOnlyLM( + tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs + ) applicable_args = ["mlm", "pad_to_multiple_of"] collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py index b3905825..9d012134 100644 --- a/examples/run_peft_tuning.py +++ b/examples/run_peft_tuning.py @@ -245,14 +245,14 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> Non "--train_on_completion", help="Train on completion True or False", default=False, - choices=[True,False], + type=bool, + choices=[True, False], ) subparser.add_argument( "--response_template", help="Response template to identify response", - default=None + default=None, ) - def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser): @@ -427,7 +427,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None: accumulate_steps=args.accumulate_steps, torch_dtype=args.torch_dtype, train_on_completion=args.train_on_completion, - response_template=args.response_template + response_template=args.response_template, ) model.save(args.output_dir, save_base_model=not args.prompt_only) print_colored("[Training Complete]")