diff --git a/caikit_nlp/modules/text_classification/classification_prompt_tuning.py b/caikit_nlp/modules/text_classification/classification_prompt_tuning.py
new file mode 100644
index 00000000..bb9ac83c
--- /dev/null
+++ b/caikit_nlp/modules/text_classification/classification_prompt_tuning.py
@@ -0,0 +1,240 @@
+# Standard
+from typing import List, Optional, Union
+import os
+
+# First Party
+from caikit.core.data_model import DataStream
+from caikit.core.modules import (
+    ModuleBase,
+    ModuleConfig,
+    ModuleLoader,
+    ModuleSaver,
+    module,
+)
+from caikit.core.toolkit import error_handler, wip_decorator
+from caikit.interfaces.nlp.data_model import (
+    ClassificationResult,
+    ClassificationResults,
+    ClassificationTrainRecord,
+)
+from caikit.interfaces.nlp.tasks import TextClassificationTask
+import alog
+
+# Local
+from ...data_model import TuningConfig
+from ...toolkit.task_specific_utils import get_sorted_unique_class_labels
+from ..text_generation import PeftPromptTuning
+
+log = alog.use_channel("CLASSIFICATION_PROMPT")
+error = error_handler.get(log)
+
+# TODO: try to refactor this into a smaller module
+# pylint: disable=too-many-lines,too-many-instance-attributes
+@module(
+    id="6713731b-160b-4sc5-8df4-167126e2cd11",
+    name="Classification Peft Tuning",
+    version="0.1.0",
+    task=TextClassificationTask,
+)
+class ClassificationPeftPromptTuning(ModuleBase):
+
+    _DETECT_DEVICE = "__DETECT__"
+
+    def __init__(
+        self,
+        classifier: PeftPromptTuning,
+        unique_class_labels: List[str],
+    ):
+        super().__init__()
+        error.type_check(
+            "",
+            PeftPromptTuning,
+            classifier=classifier,
+        )
+        error.type_check(
+            "",
+            List,
+            unique_class_labels=unique_class_labels,
+        )
+        self.classifier = classifier
+        self.unique_class_labels = unique_class_labels
+
+    @classmethod
+    @wip_decorator.work_in_progress(
+        category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING
+    )
+    def train(
+        cls,
+        base_model: str,  # TODO: Union[str, PretrainedModelBase]
+        train_stream: DataStream[ClassificationTrainRecord],
+        tuning_config: TuningConfig,
+        val_stream: DataStream[ClassificationTrainRecord] = None,
+        device: str = _DETECT_DEVICE,  # TODO: Union[int, str]
+        tuning_type: str = "PROMPT_TUNING",  # TODO: Union[str, TuningType]
+        num_epochs: int = 20,
+        lr: float = 0.3,
+        verbalizer: str = "{{input}}",
+        batch_size: int = 8,
+        max_source_length: int = 256,
+        max_target_length: int = 128,
+        accumulate_steps: int = 32,
+        torch_dtype: str = None,  # TODO: Optional[Union[torch.dtype, str]]
+        silence_progress_bars: bool = True,
+        **kwargs,
+    ) -> "ClassificationPeftPromptTuning":
+        """Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model
+        to refine a text generation model.
+
+        Args:
+            base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
+                Base resource model used for underlying generation.
+            train_stream: DataStream[ClassificationTrainRecord]
+                Data to be used for training the prompt vectors of the generation model.
+            tuning_config: TuningConfig
+                Additional model tuning configurations to be considered for prompt vector
+                initialization and training behavior.
+            val_stream: Optional[DataStream[ClassificationTrainRecord]]
+                Data to be used for validation throughout the train process, or None.
+            device: str
+                Device to be used for training the model. Default: cls._DETECT_DEVICE, which
+                will fall back to "cuda" if available, else None.
+            tuning_type: str
+                Type of PEFT tuning config which we would like to build.
+            num_epochs: int
+                Number of epochs to tune the prompt vectors. Default: 20.
+            lr: float
+                Learning rate to be used while tuning prompt vectors. Default: 0.3.
+            verbalizer: str
+                Verbalizer template to be used for formatting data at train and inference time.
+                This template may use brackets to indicate where fields from the data model
+                GenerationTrainRecord must be rendered. Default: "{{input}}", i.e., the raw text.
+            batch_size: int
+                Batch size to be used for training / evaluation data. Default: 8.
+            max_source_length: int
+                Max length of input sequences being considered. Default: 256.
+            max_target_length: int
+                Max length of target sequences being predicted. Default: 128.
+            accumulate_steps: int
+                Number of steps to use for gradient accumulation. Default: 32.
+            torch_dtype: str
+                TODO: Optional[Union[torch.dtype, str]]
+                Data type to use for training/inference of the underlying text generation model.
+                If no value is provided, we pull from torch_dtype in config. If an in-memory
+                resource is provided which does not match the specified data type, the model
+                underpinning the resource will be converted in place to the correct torch dtype.
+            silence_progress_bars: bool
+                Silences TQDM progress bars at train time. Default: True.
+        Returns:
+            ClassificationPeftPromptTuning
+                Instance of this class with tuned prompt vectors.
+        """
+
+        unique_class_labels = get_sorted_unique_class_labels(train_stream)
+        # Wrap up the trained model in a class instance
+        return cls(
+            classifier=PeftPromptTuning.train(
+                base_model,
+                train_stream,
+                tuning_config,
+                val_stream,
+                device,
+                tuning_type,
+                num_epochs,
+                lr,
+                verbalizer,
+                batch_size,
+                max_source_length,
+                max_target_length,
+                accumulate_steps,
+                torch_dtype,
+                silence_progress_bars,
+                **kwargs,
+            ),
+            unique_class_labels=unique_class_labels,
+            # TODO: Export other training params to model as well
+        )
+
+    # TODO: enable passing save_base_model flag as argument when supported by caikit
+    @wip_decorator.work_in_progress(
+        category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING
+    )
+    def save(self, model_path: str, save_base_model: bool = False):
+        """Save the classification model.
+
+        Args:
+            model_path: str
+                Folder in which to save the classification prompt tuning model
+            save_base_model: bool
+                Save the base model along with the prompts in the model_path provided.
+                Default: False
+        """
+        saver = ModuleSaver(self, model_path=model_path)
+        with saver:
+            saver.save_module(
+                self.classifier, "artifacts", save_base_model=save_base_model
+            )
+            saver.update_config(
+                {
+                    "unique_class_labels": self.unique_class_labels,
+                }
+            )
+
+    @classmethod
+    @wip_decorator.work_in_progress(
+        category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING
+    )
+    def load(cls, model_path: str) -> "ClassificationPeftPromptTuning":
+        """Load a classification model.
+
+        Args:
+            model_path: str
+                Path to the model to be loaded.
+
+        Returns:
+            ClassificationPeftPromptTuning
+                Instance of this class.
+        """
+        config = ModuleConfig.load(os.path.abspath(model_path))
+        loader = ModuleLoader(model_path)
+        classifier = loader.load_module("artifacts")
+        return cls(
+            classifier=classifier,
+            unique_class_labels=config.unique_class_labels,
+        )
+
+    # TODO: Currently only single-label classification is supported,
+    # hence this will always return a list of 1 element.
+    # Support for multi-label may be added in the future.
+    def run(
+        self,
+        text: str,
+        device: Optional[Union[str, int]] = _DETECT_DEVICE,
+        max_new_tokens: int = 20,
+        min_new_tokens: int = 0,
+    ) -> ClassificationResults:
+        """Run the classifier model.
+
+        Args:
+            text: str
+                Input string to be classified.
+            device: Optional[Union[str, int]]
+                Device on which we should run inference; by default, we use the detected device.
+            max_new_tokens: int
+                The maximum number of tokens to generate for the class label.
+                Default: 20
+            min_new_tokens: int
+                The minimum number of tokens to generate.
+                Default: 0, which means no minimum.
+
+        Returns:
+            ClassificationResults
+        """
+        gen_result = self.classifier.run(text, device, max_new_tokens, min_new_tokens)
+        # Either return a supported class label or None
+        label = (
+            gen_result.generated_text
+            if gen_result.generated_text in self.unique_class_labels
+            else None
+        )
+
+        return ClassificationResults(results=[ClassificationResult(label=label)])
diff --git a/caikit_nlp/toolkit/task_specific_utils.py b/caikit_nlp/toolkit/task_specific_utils.py
index 0d42fd35..fcc3e9de 100644
--- a/caikit_nlp/toolkit/task_specific_utils.py
+++ b/caikit_nlp/toolkit/task_specific_utils.py
@@ -39,3 +39,22 @@ def convert_to_generation_record(train_record):
             and GenerationTrainRecord are supported"
         ),
     )
+
+
+def get_sorted_unique_class_labels(data_stream):
+    """Get the sorted list of unique class labels from a data stream of
+    ClassificationTrainRecord.
+
+    Args:
+        data_stream: DataStream[ClassificationTrainRecord]
+            Data stream of ClassificationTrainRecord from which to extract unique class labels
+    Returns:
+        unique_labels: List[str]
+            Sorted list containing the unique set of classes discovered in the data stream
+    """
+    labels_data_stream = data_stream.map(lambda item: item.labels)
+    unique_labels = set()
+    for label_list in labels_data_stream:
+        for label in label_list:
+            unique_labels.add(label)
+
+    return sorted(unique_labels)
diff --git a/tests/modules/text_classification/test_classification_prompt_tuning.py b/tests/modules/text_classification/test_classification_prompt_tuning.py
new file mode 100644
index 00000000..0dcc2e78
--- /dev/null
+++ b/tests/modules/text_classification/test_classification_prompt_tuning.py
@@ -0,0 +1,158 @@
+"""Tests for the classification prompt tuning module
+"""
+# Standard
+import os
+import tempfile
+
+# Third Party
+import pytest
+import torch
+
+# First Party
+from caikit.interfaces.nlp.data_model import (
+    ClassificationResults,
+    ClassificationTrainRecord,
+)
+import caikit
+
+# Local
+from caikit_nlp.modules.text_classification.classification_prompt_tuning import (
+    ClassificationPeftPromptTuning,
+)
+from caikit_nlp.modules.text_generation.peft_prompt_tuning import PeftPromptTuning
+from tests.fixtures import causal_lm_dummy_model, causal_lm_train_kwargs
+import caikit_nlp
+
+####################
+## train/run ##
+####################
+
+
+def test_train_model(causal_lm_train_kwargs):
+    """Ensure that we can train a model on some toy data for 1+ steps"""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+    model = ClassificationPeftPromptTuning.train(**causal_lm_train_kwargs)
+    # Test fallback to float32 behavior if this machine doesn't support bfloat16
+    assert model.classifier.model.dtype is torch.float32
+    assert isinstance(model, ClassificationPeftPromptTuning)
+
+
+# TODO: add a test for scores in the future when implemented
+def test_run_classification_model(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model,
+        unique_class_labels=["LABEL_0", "LABEL_1", "LABEL_2"],
+    )
+    output = classifier_model.run("Text does not matter")
+    assert isinstance(output, ClassificationResults)
+    # Returns supported class labels or None
+    classifier_model.unique_class_labels.append(None)
+    assert output.results[0].label in classifier_model.unique_class_labels
+    assert output.results[0].score is None
+
+
+def test_train_run_model_classification_record(causal_lm_train_kwargs):
+    """Ensure that we can train a model on some toy data for 1+ steps & run inference."""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+    model = ClassificationPeftPromptTuning.train(**causal_lm_train_kwargs)
+    # Test fallback to float32 behavior if this machine doesn't support bfloat16
+    assert model.classifier.model.dtype is torch.float32
+    assert isinstance(model, ClassificationPeftPromptTuning)
+    output = model.run("Text does not matter")
+    assert isinstance(output, ClassificationResults)
+    assert model.unique_class_labels == ["complaint", "no complaint"]
+    # Returns supported class labels or None
+    model.unique_class_labels.append(None)
+    assert output.results[0].label in model.unique_class_labels
+    assert output.results[0].score is None
+
+
+####################
+## save/load ##
+####################
+
+
+def test_save(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
+    )
+    with tempfile.TemporaryDirectory() as model_dir:
+        classifier_model.save(model_dir)
+        assert os.path.exists(os.path.join(model_dir, "config.yml"))
+        assert os.path.exists(os.path.join(model_dir, "artifacts", "config.yml"))
+
+
+def test_save_and_reload_with_base_model(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
+    )
+    with tempfile.TemporaryDirectory() as model_dir:
+        classifier_model.save(model_dir, save_base_model=True)
+        model_load = caikit_nlp.load(model_dir)
+        assert isinstance(model_load, ClassificationPeftPromptTuning)
+        assert isinstance(model_load.classifier, PeftPromptTuning)
+        assert model_load.unique_class_labels == ["label1", "label2"]
+
+
+def test_save_and_reload_without_base_model(causal_lm_dummy_model):
+    """Ensure that if we don't save the base model, we get the expected behavior."""
+    with tempfile.TemporaryDirectory() as model_dir:
+        causal_lm_dummy_model.save(model_dir, save_base_model=False)
+        # For now, if we are missing the base model at load time, we throw ValueError
+        with pytest.raises(ValueError):
+            caikit_nlp.load(model_dir)
+
+
+####################
+## save/load/run ##
+####################
+
+
+def test_save_reload_and_run_with_base_model(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
+    )
+    with tempfile.TemporaryDirectory() as model_dir:
+        classifier_model.save(model_dir, save_base_model=True)
+        model_load = caikit_nlp.load(model_dir)
+        assert isinstance(model_load, ClassificationPeftPromptTuning)
+        assert isinstance(model_load.classifier, PeftPromptTuning)
+        assert model_load.unique_class_labels == ["label1", "label2"]
+        output = model_load.run("Text does not matter")
+        assert isinstance(output, ClassificationResults)
+        # Returns supported class labels or None
+        model_load.unique_class_labels.append(None)
+        assert output.results[0].label in model_load.unique_class_labels
+        assert output.results[0].score is None
diff --git a/tests/toolkit/test_task_specific_utils.py b/tests/toolkit/test_task_specific_utils.py
index 98b9f8d0..361e19e2 100644
--- a/tests/toolkit/test_task_specific_utils.py
+++ b/tests/toolkit/test_task_specific_utils.py
@@ -16,11 +16,15 @@
 import pytest
 
 # First Party
+from caikit.core.data_model import DataStream
 from caikit.interfaces.nlp.data_model import ClassificationTrainRecord
 
 # Local
 from caikit_nlp.data_model import GenerationTrainRecord
-from caikit_nlp.toolkit.task_specific_utils import convert_to_generation_record
+from caikit_nlp.toolkit.task_specific_utils import (
+    convert_to_generation_record,
+    get_sorted_unique_class_labels,
+)
 
 
 def test_convert_classification_train_record_to_generation_record():
@@ -56,3 +60,17 @@ def test_convert_to_generation_record_gives_error_with_unsupported_type():
     string_record = "test record"
     with pytest.raises(TypeError):
         convert_to_generation_record(string_record)
+
+
+def test_get_sorted_unique_class_labels():
+    # Sample train data
+    sample_data = [
+        ClassificationTrainRecord(text="foo bar", labels=["label1"]),
+        ClassificationTrainRecord(
+            text="foo bar", labels=["label1", "label2", "label3"]
+        ),
+    ]
+    output_labels = ["label1", "label2", "label3"]
+    sample_stream = DataStream.from_iterable(sample_data)
+    class_labels = get_sorted_unique_class_labels(sample_stream)
+    assert output_labels == class_labels
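
For reviewers, a minimal end-to-end usage sketch of the new module, mirroring the toy data used in the tests above. The base model path, the TuningConfig values, and the save directory are illustrative assumptions, not part of this diff:

import caikit
import caikit_nlp
from caikit.interfaces.nlp.data_model import ClassificationTrainRecord
from caikit_nlp.data_model import TuningConfig
from caikit_nlp.modules.text_classification.classification_prompt_tuning import (
    ClassificationPeftPromptTuning,
)

# Toy training stream; each record carries the text and its class label(s)
train_stream = caikit.core.data_model.DataStream.from_iterable(
    [
        ClassificationTrainRecord(text="@foo what a cute dog!", labels=["no complaint"]),
        ClassificationTrainRecord(text="@bar this is the worst idea ever.", labels=["complaint"]),
    ]
)

model = ClassificationPeftPromptTuning.train(
    base_model="path/to/base/model",  # assumed local CausalLM or Seq2seq model
    train_stream=train_stream,
    tuning_config=TuningConfig(num_virtual_tokens=8),  # assumed minimal config
    num_epochs=1,
    verbalizer="Tweet text : {{input}} Label : ",
    device="cpu",
)

model.save("sample_classifier", save_base_model=True)  # illustrative path
reloaded = caikit_nlp.load("sample_classifier")

# run() generates text and maps it back to a known class label, or None
result = reloaded.run("@baz mine stopped working after a day")
print(result.results[0].label)  # e.g. "complaint", "no complaint", or None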