diff --git a/caikit_nlp/modules/text_generation/fine_tuning.py b/caikit_nlp/modules/text_generation/fine_tuning.py
index cf2ab102..f3933516 100644
--- a/caikit_nlp/modules/text_generation/fine_tuning.py
+++ b/caikit_nlp/modules/text_generation/fine_tuning.py
@@ -12,17 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Standard
+from typing import Optional
+
 # Third Party
 from torch.utils.data import IterableDataset
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
-    Trainer,
-)
+from transformers import AutoConfig, AutoTokenizer
 import torch

 # First Party
@@ -35,6 +30,11 @@
 # Local
 from ...data_model import GenerationTrainRecord
+from ...resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
 from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
 from ...toolkit.data_type_utils import get_torch_dtype
@@ -55,17 +55,26 @@ class FineTuning(ModuleBase):
     """Module to provide fine-tuning support for text generation task"""

-    def __init__(self, tokenizer, model):
+    RANDOM_SEED = 73
+    supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]
+
+    def __init__(
+        self,
+        tokenizer,
+        model,
+        bos_token: Optional[str] = None,
+        sep_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+    ):
         super().__init__()
         self.tokenizer = tokenizer
-        # NOTE: self.model here can also be HF trainer. This is because
-        # if we have just trained the model then the models weights might be
-        # available in different devices (and configuration), depending on
-        # how it was trained. For now (July 10, 2023), we are not trying to
-        # extract the model out from trainer itself, since that would require
-        # us to essentially save it or reconstruct it to do normal inferring.
         self.model = model
+        self._bos_token = bos_token
+        self._sep_token = sep_token
+        self._eos_token = eos_token
+        self._pad_token = pad_token

     @classmethod
     def train(
@@ -78,12 +87,49 @@ def train(
         batch_size: int = 8,
         num_epochs: int = 5,
         accumulate_steps: int = 32,
+        random_seed: int = RANDOM_SEED,
         lr: float = 2e-5,
         # Directory where model predictions and checkpoints will be written
         checkpoint_dir: str = "/tmp",
+        **training_arguments,
     ):
         """
-        # FIXME: Below is currently configured for Seq2Seq only
+        Fine-tune a CausalLM or Seq2seq text generation model.
+
+        Args:
+            base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
+                Base resource model used for underlying generation.
+            train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
+                Data to be used for fine-tuning the generation model.
+            torch_dtype: str
+                TODO: Optional[Union[torch.dtype, str]]
+                Data type to use for training/inference of the underlying text generation model.
+                If no value is provided, we pull from torch_dtype in config. If an in memory
+                resource is provided which does not match the specified data type, the model
+                underpinning the resource will be converted in place to the correct torch dtype.
+            max_source_length: int
+                Max length of input sequences being considered. Default: 256.
+            max_target_length: int
+                Max length of target sequences being predicted. Default: 128.
+            batch_size: int
+                Batch size to be used for training / evaluation data. Default: 8.
+            num_epochs: int
+                Number of epochs to tune the model. Default: 5.
+            accumulate_steps: int
+                Number of steps to use for gradient accumulation. Default: 32.
+            lr: float
+                Learning rate to be used while tuning model. Default: 2e-5.
+            checkpoint_dir: str
+                Directory where model predictions and checkpoints will be written.
+            **training_arguments:
+                Arguments supported by HF Training Arguments.
+                TrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
+                Seq2SeqTrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+        Returns:
+            FineTuning
+                Instance of this class with the fine-tuned model.
         """

         torch_dtype = get_torch_dtype(torch_dtype)
@@ -92,11 +138,12 @@ def train(
         # text_generation module. In future, we would want to consolidate this into
         # a base class or a toolkit function
         # pylint: disable=duplicate-code
+        resource_type = None
+
         ## Load base model
         if isinstance(base_model, str):
             model_config = AutoConfig.from_pretrained(base_model)

-            resource_type = None
             for resource in cls.supported_resources:
                 if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                     resource_type = resource
@@ -112,8 +159,14 @@ def train(
             log.debug("Bootstrapping base resource [%s]", base_model)
             base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
+        else:
+            # base_model is actually a resource object
+            resource_type = type(base_model)
+
+        error.type_check("", PretrainedModelBase, base_model=base_model)

         ## Generate data loader from stream
         training_dataset: IterableDataset = cls._preprocess_function(
+            base_model=base_model,
             train_stream=train_stream,
             tokenizer=base_model.tokenizer,
             max_source_length=max_source_length,
@@ -144,47 +197,33 @@ def train(
         # by optionally accepting `training_args`
         # as argument to this train function.
         # TODO: Remove all the default used below and make them all configurable
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=checkpoint_dir,
-            per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            num_train_epochs=num_epochs,
+
+        training_args = {
+            "output_dir": checkpoint_dir,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+            "num_train_epochs": num_epochs,
+            "seed": random_seed,
             # NOTE: We have disabled evaluation for now
-            do_eval=False,
-            # evaluation_strategy = "epoch",
-            learning_rate=lr,
-            weight_decay=0.01,
-            save_total_limit=3,
-            predict_with_generate=True,
-            push_to_hub=False,
-            no_cuda=False,  # Default
-            generation_max_length=max_target_length,
-            remove_unused_columns=False,
-            dataloader_pin_memory=False,
-            gradient_accumulation_steps=accumulate_steps,
-            eval_accumulation_steps=accumulate_steps,
-            logging_strategy="epoch",
-            disable_tqdm=True,
-            # NOTE: Following not possible without save and eval strategy
-            # load_best_model_at_end=True,
+            "do_eval": False,
+            # "evaluation_strategy": "epoch",
+            "learning_rate": lr,
+            "weight_decay": 0.01,
+            "save_total_limit": 3,
+            "push_to_hub": False,
+            "no_cuda": False,  # Default
+            "remove_unused_columns": False,
+            "dataloader_pin_memory": False,
+            "gradient_accumulation_steps": accumulate_steps,
+            "eval_accumulation_steps": accumulate_steps,
             # eval_steps=1,
+            # load_best_model_at_end
+            **training_arguments,
             **dtype_based_params,
-            ## TODO: Make below configurable
-            # fsdp="full_shard auto_wrap",
-            # local_rank=0,
-        )
-
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer=base_model.tokenizer, model=base_model.model
-        )
+        }

-        trainer = Seq2SeqTrainer(
-            base_model.model,
-            training_args,
-            train_dataset=training_dataset,
-            data_collator=data_collator,
-            tokenizer=base_model.tokenizer,
-            # compute_metrics=compute_metrics,
+        trainer = base_model.get_trainer(
+            train_dataset=training_dataset, **training_args
         )

         if num_epochs < 1:
@@ -201,17 +240,25 @@ def train(

         # Start training via Trainer.train function
         trainer.train()
-        # NOTE: By default the model would be available in different ways
-        # depending on where and how it was trained. So we need to fetch the model
-        # from the trainer depending on the training method, like fsdp, ddp etc.
-        # For simplicity, currently we will use trainer as the model since it anyways
-        # enable the `predict` function on it and has all the layers of the model
-        # distributed already, so it will be most optimized to use trainer to
-        # perform prediction at this stage.
+
+        # Save the model temporarily and reload it. This is done because
+        # otherwise the model might be distributed across different devices,
+        # in which case it is better to use the trainer's `prediction_step`
+        # function, but that does not always provide an API similar to
+        # `generate` and thus causes incompatibilities in the `run` function.
+        trainer.save_model(checkpoint_dir)
+
+        model = resource_type.bootstrap(
+            checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype
+        )

         return cls(
-            tokenizer=base_model.tokenizer,
-            model=trainer,
+            tokenizer=model.tokenizer,
+            model=model,
+            bos_token=model.tokenizer.bos_token or None,
+            sep_token=model.tokenizer.sep_token or None,
+            eos_token=model.tokenizer.eos_token or None,
+            pad_token=model.tokenizer.pad_token or None,
         )

     # pylint: disable=unused-argument
@@ -236,44 +283,41 @@ def run(
             GeneratedTextResult
                 Generated text result
         """
-        if isinstance(self.model, Trainer):
-            # Apply the tokenizer to the sample text & move to correct device
-            tok_tensors = self.tokenizer(text, return_tensors="pt")
-            # NOTE: below function is prediction on trainer, for which we need to supply
-            # the actual underlying model as well
-            # NOTE: We are using prediction_step instead of calling `self.model.generate`
-            # because this way HF Trainer automatically handles device placement of the
-            # data and model. Since the model is with Trainer at this point
-            # and thus the device placement be according to training strategy,
-            # its better to let Trainer handle the evaluation / prediction
-
-            # TODO: Add support for passing extra arguments to prediction_step
-            _, generated_tokens, _ = self.model.prediction_step(
-                self.model.model,
-                tok_tensors,
-                prediction_loss_only=False,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=min_new_tokens,
-            )
-            generated_text = self.tokenizer.batch_decode(
-                generated_tokens.detach().cpu().numpy(), skip_special_tokens=True
-            )[0]
+        inputs = self.model.tokenizer(text, return_tensors="pt")
+        generate_ids = self.model.model.generate(
+            input_ids=inputs["input_ids"],
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            use_cache=True,
+        )

-        else:
-            error(
-                "",
-                NotImplementedError(
-                    "model prediction on pre-finetuned model currently not supported"
-                ),
+        token_count = generate_ids.size(1) - 1
+        preds = [
+            self.model.tokenizer.decode(
+                g, skip_special_tokens=True, clean_up_tokenization_spaces=True
             )
+            for g in generate_ids
+        ]
+        if generate_ids[0][-1].item() == self._eos_token:
+            finish_reason = "EOS_TOKEN"
+        elif generate_ids.size(1) - 1 == max_new_tokens:
+            finish_reason = "MAX_TOKENS"
+        else:
+            finish_reason = "OTHER"

-        return GeneratedTextResult(generated_text=generated_text)
+        return GeneratedTextResult(
+            generated_tokens=token_count,
+            generated_text=preds[0],
+            finish_reason=finish_reason,
+            producer_id=self.PRODUCER_ID,
+        )

     ################################## Private Functions ###########################################

     @staticmethod
     def _preprocess_function(
+        base_model: PretrainedModelBase,
         train_stream: DataStream[GenerationTrainRecord],
         tokenizer: AutoTokenizer,
         max_source_length: int,
@@ -282,28 +326,17 @@ def _preprocess_function(
     ):
         """Pre-process each example to get it prepared for training."""

-        # FIXME: Below is currently configured for Seq2Seq only
-
-        def _tokenization_func(
-            example: GenerationTrainRecord,
-        ):
-            model_inputs = tokenizer(
-                example.input,
-                max_length=max_source_length,
-                truncation=True,
-            )
-
-            labels = tokenizer(
-                example.output,
-                max_length=max_target_length,
-                padding="max_length",
-                truncation=True,
-            )
-
-            model_inputs["labels"] = labels["input_ids"]
-
-            return model_inputs
-
-        return SimpleIterableStreamWrapper(
-            train_stream.map(_tokenization_func), shuffle=shuffle
+        # TODO: We are using a default verbalizer which is strictly tied to
+        # source training record currently. We need to figure out a better
+        # way to make verbalizer optional for build_task_tokenize_function
+        (
+            tokenize_function,
+            requires_unwrapping,
+        ) = base_model.build_task_tokenize_function(
+            tokenizer, max_source_length, max_target_length, verbalizer="{{input}}"
         )
+        mapped_stream = train_stream.map(tokenize_function)
+        if requires_unwrapping:
+            mapped_stream = mapped_stream.flatten()
+
+        return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 6b4e7e7f..2f4049c8 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -199,6 +199,7 @@ def run(
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
         # Apply the tokenizer to the sample text & move to correct device
         tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt")
+        device = PeftPromptTuning._get_device(device)
         inputs = {k: v.to(device) for k, v in tok_tensors.items()}
         with torch.no_grad():
@@ -604,7 +605,12 @@ def save(self, model_path: str, save_base_model: bool = False):
         module_saver.update_config(config_options)

     @classmethod
-    def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning":
+    def load(
+        cls,
+        model_path: str,
+        torch_dtype: str = None,
+        device: str = _DETECT_DEVICE,  # TODO: Union[int, str]
+    ) -> "PeftPromptTuning":
         """Load a PEFT prompt tuning model. This method will currently fail if the original
         model was not saved with the arg value save_base_model=True.
@@ -626,7 +632,7 @@ def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning":
             torch_dtype = str_to_torch_dtype(config.trained_torch_dtype)
         if config.has_base_model:
             # TODO: Implement logic for resource loading
-            device = cls._get_device(cls._DETECT_DEVICE)
+            device = cls._get_device(device)
             model_config = os.path.join(model_path, config.full_model_path)
             peft_config = PeftConfig.from_pretrained(model_config)
             if peft_config.task_type == "CAUSAL_LM":
@@ -1005,7 +1011,7 @@ def _get_data_loaders_from_stream(
             tokenize_function,
             requires_unwrapping,
         ) = base_model.build_task_tokenize_function(
-            tokenizer, max_source_length, max_target_length, verbalizer
+            tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
         if requires_unwrapping:
@@ -1065,8 +1071,11 @@ def _execute_train_loop(
             num_warmup_steps=0,
             num_training_steps=(len(train_dataloader) * num_epochs),
         )
-        # Configure accelerator for gradient accumulation
-        accelerator = Accelerator(gradient_accumulation_steps=accumulate_steps)
+
+        accelerator = Accelerator(
+            gradient_accumulation_steps=accumulate_steps, device_placement=True
+        )
+
         for epoch in range(num_epochs):
             model.train()
             total_loss = 0
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index c2232a6c..59bb4d45 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -14,12 +14,18 @@
 # Standard
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type, Union
 import json
 import os

 # Third Party
-from transformers import AutoTokenizer
+from torch.utils.data import IterableDataset
+from transformers import (
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 import torch
@@ -233,6 +239,61 @@ def save(
         self.tokenizer.save_pretrained(tok_abs_path)
         self.model.save_pretrained(model_abs_path)

+    def get_trainer(
+        self,
+        train_dataset: IterableDataset,
+        eval_dataset: Union[IterableDataset, None] = None,
+        optimizers=(None, None),
+        **kwargs,
+    ):
+        """
+        Args:
+            **kwargs: arguments supported by HF TrainingArguments:
+                https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
+
+        NOTE: The following parameters are not currently supported:
+            1. model_init
+            2. compute_metrics
+            3. callbacks
+            4. preprocess_logits_for_metrics
+        """
+
+        training_args = TrainingArguments(**kwargs)
+
+        data_collator = self._get_data_collator(**kwargs)
+
+        trainer_arguments = {
+            "train_dataset": train_dataset,
+            "data_collator": data_collator,
+            "tokenizer": self._tokenizer,
+            "optimizers": optimizers,
+            "eval_dataset": eval_dataset,
+        }
+
+        return Trainer(self._model, training_args, **trainer_arguments)
+
+    def _get_data_collator(self, **kwargs):
+        """Function to return appropriate data collator based on resource.
+
+        The default implementation of the base resource uses
+        DataCollatorWithPadding, which will dynamically pad the inputs received.
+
+        Args:
+            **kwargs:
+                All the keyword arguments passed to this function
+                will be filtered down to the ones that are
+                applicable to the implemented data collator.
+        Returns:
+            transformers.DataCollator
+        """
+
+        applicable_args = ["max_length", "pad_to_multiple_of"]
+        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+        return DataCollatorWithPadding(
+            tokenizer=self._tokenizer, padding=True, **collator_kwargs
+        )
+
     # pylint: disable=unused-argument
     @classmethod
     def get_num_transformers_submodules(
@@ -249,6 +310,7 @@ def build_task_tokenize_function(
         max_source_length: int,
         max_target_length: int,
         verbalizer: str,
+        task_ids: Union[None, int] = None,
     ) -> Tuple[Callable, bool]:
         """Builds tokenizer functions which can be mapped over train streams to process
         data which can then be easily passed to a DataLoader for different model types.
@@ -263,6 +325,10 @@ def build_task_tokenize_function(
             verbalizer: str
                 Verbalizer template to be used for formatting data. This template may use
                 brackets to indicate where fields from the data model TrainGenerationRecord must be rendered.
+            task_ids: Union[None, int]
+                Task id corresponding to a particular task for multi-task prompt tuning.
+                NOTE: Only required for MPT (Multi-task prompt tuning)
+                Default: None

         Returns:
             Tuple(Callable, bool)
diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
index b98a2983..30c0be20 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -16,10 +16,10 @@
 """
 # Standard
 from copy import deepcopy
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Union

 # Third Party
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling
 from transformers.models.auto import modeling_auto

 # First Party
@@ -52,6 +52,7 @@ def build_task_tokenize_function(
         max_source_length: int,
         max_target_length: int,
         verbalizer: str,
+        task_ids: Union[None, int] = None,
     ) -> Tuple[Callable, bool]:
         """Builds tokenizer functions which can be mapped over train streams to process
         data which can then be easily passed to a DataLoader for CausalLM models.
@@ -66,6 +67,10 @@ def build_task_tokenize_function(
             verbalizer: str
                 Verbalizer template to be used for formatting data. This template may use
                 brackets to indicate where fields from the data model TrainGenerationRecord must be rendered.
+            task_ids: Union[None, int]
+                Task id corresponding to a particular task for multi-task prompt tuning.
+                NOTE: Only required for MPT (Multi-task prompt tuning)
+                Default: None

         Returns:
             Tuple(Callable, bool)
@@ -104,7 +109,9 @@ def tokenize_function_language_model(
             # Here, we need to yield and manipulate the attention mask to attend
             # to the input seq + the tokens we have seen so far...
             num_target_samples = len(target_ids.input_ids)
-            source_ids["task_ids"] = 0
+
+            if task_ids is not None:
+                source_ids["task_ids"] = task_ids

             def generator_func():
                 for idx in range(num_target_samples):
@@ -122,3 +129,32 @@ def generator_func():
             return DataStream(generator_func)

         return (tokenize_function_language_model, True)
+
+    def _get_data_collator(self, **kwargs):
+        """Function to return appropriate data collator based on resource.
+
+        DataCollatorForLanguageModeling is used here, which will dynamically
+        pad the inputs to the maximum length of a batch if they are not all
+        of the same length.
+
+        NOTE: If mlm (masked language modeling) is not passed in kwargs,
+        this function will automatically set it to `False`.
+
+        Args:
+            **kwargs:
+                All the keyword arguments passed to this function
+                will be filtered down to the ones that are
+                applicable to the implemented data collator.
+        Returns:
+            transformers.DataCollator
+        """
+
+        applicable_args = ["mlm", "pad_to_multiple_of"]
+        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+        if "mlm" not in collator_kwargs:
+            collator_kwargs["mlm"] = False
+
+        return DataCollatorForLanguageModeling(
+            tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
+        )
diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
index a8b41d45..bdd69aa1 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
@@ -15,10 +15,16 @@
 Huggingface auto causal LM resource type
 """
 # Standard
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union

 # Third Party
-from transformers import AutoModelForSeq2SeqLM
+from torch.utils.data import IterableDataset
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    DataCollatorForSeq2Seq,
+    Seq2SeqTrainer,
+    Seq2SeqTrainingArguments,
+)
 from transformers.models.auto import modeling_auto

 # First Party
@@ -68,12 +74,72 @@ def get_num_transformers_submodules(
             )
         return num_transformer_submodules

+    def get_trainer(
+        self,
+        train_dataset: IterableDataset,
+        eval_dataset: Union[IterableDataset, None] = None,
+        optimizers=(None, None),
+        **kwargs
+    ):
+        """
+        Args:
+            **kwargs: arguments supported by HF Seq2SeqTrainingArguments:
+                https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+
+        NOTE: The following parameters are not currently supported:
+            1. model_init
+            2. compute_metrics
+            3. callbacks
+            4. preprocess_logits_for_metrics
+        """
+
+        # NOTE: predict_with_generate is incompatible with fsdp
+        training_args = Seq2SeqTrainingArguments(**kwargs)
+
+        # pylint: disable=duplicate-code
+        # TODO: Fetch the DataCollator either from a property of this
+        # class or accept it as an argument.
+        data_collator = self._get_data_collator(**kwargs)
+
+        trainer_arguments = {
+            "train_dataset": train_dataset,
+            "data_collator": data_collator,
+            "tokenizer": self._tokenizer,
+            "optimizers": optimizers,
+            "eval_dataset": eval_dataset,
+            # "generation_max_length": max_target_length,
+        }
+
+        return Seq2SeqTrainer(self._model, training_args, **trainer_arguments)
+
+    def _get_data_collator(self, **kwargs):
+        """Function to return appropriate data collator based on resource.
+
+        This implementation uses DataCollatorForSeq2Seq.
+
+        Args:
+            **kwargs:
+                All the keyword arguments passed to this function
+                will be filtered down to the ones that are
+                applicable to the implemented data collator.
+        Returns:
+            transformers.DataCollator
+        """
+
+        applicable_args = ["max_length", "pad_to_multiple_of"]
+        collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
+
+        return DataCollatorForSeq2Seq(
+            tokenizer=self._tokenizer, model=self._model, **collator_kwargs
+        )
+
     @staticmethod
     def build_task_tokenize_function(
         tokenizer: "AutoTokenizer",
         max_source_length: int,
         max_target_length: int,
         verbalizer: str,
+        task_ids: Union[None, int] = None,
     ) -> Tuple[Callable, bool]:
         """Builds tokenizer functions which can be mapped over train streams to process
         data which can then be easily passed to a DataLoader for seq2seq models.
@@ -88,6 +154,10 @@ def build_task_tokenize_function(
             verbalizer: str
                 Verbalizer template to be used for formatting data. This template may use
                 brackets to indicate where fields from the data model TrainGenerationRecord must be rendered.
+            task_ids: Union[None, int]
+                Task id corresponding to a particular task for multi-task prompt tuning.
+                NOTE: Only required for MPT (Multi-task prompt tuning)
+                Default: None

         Returns:
             Tuple(Callable, bool)
@@ -134,7 +204,9 @@ def tokenize_function_seq2seq(
                 map(lambda x: IGNORE_ID if x == tokenizer.pad_token_id else x, labels)
             )
             model_inputs["labels"] = labels
-            model_inputs["task_ids"] = 0
+            if task_ids is not None:
+                model_inputs["task_ids"] = task_ids
+
             return model_inputs

         return (tokenize_function_seq2seq, False)
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 2cc4f0b7..324cdee3 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -32,6 +32,20 @@
 SEQ2SEQ_LM_MODEL = os.path.join(TINY_MODELS_DIR, "T5ForConditionalGeneration")


+@pytest.fixture()
+def set_cpu_device(request):
+    """Fixture to force the default device to CPU.
+    This fixture is particularly useful for running the unit tests where
+    cuda devices are available, in which case some transformers functions
+    may try to use cuda and raise device mismatch errors.
+    """
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    with mock.patch.object(torch.cuda, "is_available", return_value=False):
+        yield
+    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
+
+
 @pytest.fixture
 def disable_wip(request):
     """Fixture to temporarily disable wip decorator"""
diff --git a/tests/modules/text_generation/test_fine_tuning.py b/tests/modules/text_generation/test_fine_tuning.py
index 17a611b1..a17f5ffa 100644
--- a/tests/modules/text_generation/test_fine_tuning.py
+++ b/tests/modules/text_generation/test_fine_tuning.py
@@ -10,12 +10,18 @@
 # Local
 from caikit_nlp.data_model import GenerationTrainRecord
 from caikit_nlp.modules.text_generation import FineTuning
-from caikit_nlp.resources.pretrained_model import HFAutoSeq2SeqLM
-from tests.fixtures import SEQ2SEQ_LM_MODEL, disable_wip
+from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM
+from tests.fixtures import (
+    CAUSAL_LM_MODEL,
+    SEQ2SEQ_LM_MODEL,
+    disable_wip,
+    set_cpu_device,
+)


-def test_train_model(disable_wip):
-    """Ensure that we can train a model on some toy data for 1+ steps & run inference."""
+def test_train_model_seq2seq(disable_wip, set_cpu_device):
+    """Ensure that we can finetune a seq2seq model on some toy data for 1+
+    steps & run inference."""
     train_kwargs = {
         "base_model": HFAutoSeq2SeqLM.bootstrap(
             model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
         ),
         "num_epochs": 1,
         "train_stream": caikit.core.data_model.DataStream.from_iterable(
             [
                 GenerationTrainRecord(
                     input="@foo what a cute dog!", output="no complaint"
                 ),
             ]
         ),
         "torch_dtype": torch.float32,
     }
     model = FineTuning.train(**train_kwargs)
@@ -34,7 +40,32 @@
-    assert isinstance(model.model, Trainer)
+    assert isinstance(model.model, HFAutoSeq2SeqLM)
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
+def test_train_model_causallm(disable_wip, set_cpu_device):
+    """Ensure that we can finetune a causal-lm model on some toy data for 1+
+    steps & run inference."""
+    train_kwargs = {
+        "base_model": HFAutoCausalLM.bootstrap(
+            model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
+        ),
+        "num_epochs": 1,
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                GenerationTrainRecord(
input="@foo what a cute dog!", output="no complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = FineTuning.train(**train_kwargs) + assert isinstance(model.model, HFAutoCausalLM) + # Ensure that we can get something out of it pred = model.run("@bar what a cute cat!") assert isinstance(pred, GeneratedTextResult) diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 8ce87ff2..907338d0 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -30,14 +30,16 @@ causal_lm_train_kwargs, seq2seq_lm_dummy_model, seq2seq_lm_train_kwargs, + set_cpu_device, ) import caikit_nlp # Indexes into the peft config dictionary to get the actual prompt tuning config DEFAULT_ADAPTER = "default" + ### Tests validating block interfaces and behavior -def test_save_and_reload_with_base_model(causal_lm_dummy_model): +def test_save_and_reload_with_base_model(causal_lm_dummy_model, set_cpu_device): """Ensure that we can save a model + its base to a tempdir and reload it.""" with tempfile.TemporaryDirectory() as model_dir: causal_lm_dummy_model.save(model_dir, save_base_model=True) @@ -109,7 +111,7 @@ def test_verbalizer_cannot_be_static(causal_lm_train_kwargs): ) -def test_train_model(causal_lm_train_kwargs): +def test_train_model(causal_lm_train_kwargs, set_cpu_device): """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" patch_kwargs = { "num_epochs": 1, @@ -138,7 +140,7 @@ def test_train_model(causal_lm_train_kwargs): assert isinstance(pred, GeneratedTextResult) -def test_train_model_classification_record(causal_lm_train_kwargs): +def test_train_model_classification_record(causal_lm_train_kwargs, set_cpu_device): """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" patch_kwargs = { "num_epochs": 1, diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py index 0b377e28..d7e5a748 100644 --- a/tests/resources/test_pretrained_model.py +++ b/tests/resources/test_pretrained_model.py @@ -128,6 +128,7 @@ def test_causal_lm_tok_output_correctness(models_cache_dir): max_source_length=100, max_target_length=100, verbalizer="{{input}}", + task_ids=0, ) input_tok = causal_lm.tokenizer.encode(sample.input) output_tok = causal_lm.tokenizer.encode(sample.output) @@ -170,6 +171,7 @@ def test_seq2seq_tokenize_func_contains_unwrapped_stream(models_cache_dir): max_source_length=100, max_target_length=100, verbalizer="{{input}}", + task_ids=0, ) tok_res = tok_func(GenerationTrainRecord(input="hello", output="world")) map_stream = SAMPLE_TRAINING_DATA.map(tok_func) @@ -195,6 +197,7 @@ def test_seq2seq_tok_output_correctness(models_cache_dir): max_source_length=20, max_target_length=20, verbalizer="{{input}}", + task_ids=0, ) input_tok = seq2seq.tokenizer.encode(sample.input) output_tok = seq2seq.tokenizer.encode(sample.output)