diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index 92a72d9b..3d69e7ad 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -37,4 +37,4 @@ training_data_limit: add_model_name_here: 10000 runtime: - library: caikit_nlp + library: caikit_nlp \ No newline at end of file diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index 5ed8ef66..7954517b 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -13,14 +13,15 @@ # limitations under the License. """This module contains prompt tuning through PEFT""" # Standard -from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import gc import json import os +import tempfile # Third Party -from accelerate import Accelerator +from datasets import Dataset +from datasets import IterableDataset as TransformersIterableDataset from peft import ( MultitaskPromptTuningConfig, PeftConfig, @@ -30,12 +31,9 @@ TaskType, get_peft_model, ) -from torch.optim import AdamW -from torch.utils.data import DataLoader -from tqdm import tqdm +from peft.utils.other import fsdp_auto_wrap_policy from transformers import AutoModelForCausalLM, default_data_collator from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.optimization import get_linear_schedule_with_warmup import numpy as np import torch import transformers @@ -59,12 +57,7 @@ PromptOutputModelType, TuningConfig, ) -from ...resources.pretrained_model import ( - HFAutoCausalLM, - HFAutoSeq2SeqLM, - PretrainedModelBase, -) -from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper +from ...resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype from ...toolkit.task_specific_utils import convert_to_generation_record from ...toolkit.text_generation.model_run_utils import ( @@ -72,6 +65,14 @@ generate_text_func, generate_text_func_stream, ) +from ...toolkit.text_generation.training_utils import ( + ALLOWED_TRAINING_ARGS, + collect_trainer_arguments, + infer_max_steps, + launch_training, + preprocess_function, +) +from ...toolkit.torch_run import get_torch_elastic_launch_config from ...toolkit.trainer_utils import validate_training_data from ...toolkit.verbalizer_utils import render_verbalizer from .peft_config import TuningType, get_peft_config, resolve_base_model @@ -336,8 +337,8 @@ def train( Max length of input sequences being considered. Default: 256. max_target_length: int Max length of target sequences being predicted. Default: 128. - accumulate_steps: int - Number of steps to use for gradient accumulation. Default: 1. + accumulate_steps: int (DEPRECATED) + Optional, number of steps to use for gradient accumulation. Default: None. torch_dtype: str TODO: Optional[Union[torch.dtype, str]] Data type to use for training/inference of the underlying text generation model. @@ -356,18 +357,47 @@ def train( "", len(train_stream) > 0, "train_stream cannot be empty" ) + if accumulate_steps: + log.warning( + "", + "accumulate_steps parameter is DEPRECATED and will be removed in future. \ + This parameter is also not getting used internally anymore", + ) + # Configure random seed transformers.set_seed(seed) # NOTE: Following can be uncommented to allow full determinism # but it can have impact on performance. 
# transformers.enable_full_determinism(seed) - # HACK - These things can't be passed through the train API currently - - metric = kwargs.get("metric") + torch_dtype = get_torch_dtype(torch_dtype) + # Coerce the passed model into a resource; if we have one, this is a noop + # TODO: When splitting up this mono-module, use the configured resource + # type of the concrete class to bootstrap base_model = resolve_base_model(base_model, cls, torch_dtype) + base_model_name = base_model._model_name + + # Enable gradient checkpointing on base model + # PeftModel checks if the base_model has gradient checkpointing + # enabled and then configures the tensors it creates with appropriate + # setting. If we do not enable this, then we will get `tensor 0` requires + # grad error, where `tensor 0` is created by peft + base_model.model.gradient_checkpointing_enable() + + # Get config of the base model + base_model_config = base_model.get_config() + + # Remove _name_or_path field as a model can be + # saved in different location but still same + del base_model_config["_name_or_path"] + error.value_check( + "", + "_name_or_path" not in base_model_config, + "_name_or_path needs to be removed from config!", + ) + task_type, output_model_types, peft_config, tuning_type = get_peft_config( tuning_type, tuning_config, @@ -377,6 +407,8 @@ def train( verbalizer, ) + log.debug("Peft config [%s]", peft_config) + # Check if data is within limit allowed for this module and model validate_training_data( train_stream, @@ -384,12 +416,8 @@ def train( cls.MODULE_ID, ) - # Coerce the passed model into a resource; if we have one, this is a noop - # TODO: When splitting up this mono-module, use the configured resource - # type of the concrete class to bootstrap - torch_dtype = get_torch_dtype(torch_dtype) - train_stream = train_stream.map(convert_to_generation_record) + if val_stream: error.value_check( "", len(val_stream) > 0, "val_stream cannot be empty" @@ -397,19 +425,6 @@ def train( val_stream = val_stream.map(convert_to_generation_record) - # Convert our datastreams -> data loaders by disguising them as PyTorch iterable datasets - train_dataloader, val_dataloader = cls.create_dataloaders_from_stream( - base_model=base_model, - task_type=task_type, - train_stream=train_stream, - verbalizer=verbalizer, - validation_stream=val_stream or None, - batch_size=batch_size, - max_source_length=max_source_length, - max_target_length=max_target_length, - ) - - log.debug("Peft config [%s]", peft_config) # FIXME: Should only do following line for causal LM (and bloomz?) - check that is the case if isinstance(base_model, HFAutoCausalLM): base_model.model.config.d_model = 1024 @@ -419,34 +434,130 @@ def train( # Convert our Peft model (not just the underlying # transformers model) to the right underlying type. 
device = cls._get_device(device) - cls.convert_peft_model_to_type(device, peft_model, torch_dtype) - - training_loss_tracker = cls._execute_train_loop( - peft_model, - num_epochs, - train_dataloader, - device, - eval_dataloader=val_dataloader, - metric=metric, - learning_rate=learning_rate, + + # cls.convert_peft_model_to_type(device, peft_model, torch_dtype) + + ## Generate data loader from stream + training_dataset: Union[ + Dataset, TransformersIterableDataset + ] = preprocess_function( + base_model=base_model, + train_stream=train_stream, tokenizer=base_model.tokenizer, - accumulate_steps=accumulate_steps, - silence_progress_bars=silence_progress_bars, - torch_dtype=torch_dtype, + max_source_length=max_source_length, + max_target_length=max_target_length, + shuffle=True, + use_iterable_dataset=False, + random_seed=cls.RANDOM_SEED, + task_ids=0, ) - # Get config of the base model - base_model_config = base_model.get_config() + # Filter **training_arguments to only process allowed ones + filtered_training_arguments = { + k: v for k, v in kwargs.items() if k in ALLOWED_TRAINING_ARGS + } - # Remove _name_or_path field as a model can be - # saved in different location but still same - del base_model_config["_name_or_path"] - error.value_check( - "", - "_name_or_path" not in base_model_config, - "_name_or_path needs to be removed from config!", + extra_training_args = set(kwargs.keys()).difference( + filtered_training_arguments.keys() ) + if extra_training_args: + log.warning( + "", + f"{extra_training_args} parameter(s) not allowed by \ + {cls.__name__} currently and will be ignored!", + ) + + if num_epochs < 1: + log.warning( + "", + f"Number of epochs configured is {num_epochs} which is less than minimum 1. \ + No training will be performed", + ) + + return PeftPromptTuning( + tokenizer=base_model.tokenizer, + model=peft_model, + base_model_config=base_model_config, + base_model_name=base_model_name, + verbalizer=verbalizer, + task_type=task_type, + tuning_type=tuning_type, + output_model_types=output_model_types, + training_metadata={"loss": []}, + ) + + + processing_configuration = {} + + # Conditionally enable sharding if multiple GPUs available + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + processing_configuration = { + "fsdp": "full_shard auto_wrap offload", + "fsdp_config": { + # NOTE: Every transformers model has `_no_split_modules` property that can be + # leveraged to identify the layers to split. This seems to be a decent + # "default" behavior unless we want to optimize further. We will start with + # this generic approach, since it allows us to handle variety + # of models and iterate on it, based on what we encounter. + "fsdp_transformer_layer_cls_to_wrap": base_model._model._no_split_modules, + # We need to use the original parameters for peft because we have mixed values + # for require_grads in our parameters, which otherwise breaks layer flattening + # in FSDP. 
+ "use_orig_params": "True", + # Recommended configs for peft + "fsdp_sharding_strategy": 1, + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_backward_prefetch_policy": "BACKWARD_PRE" + }, + } + + # Open an intermediate checkpoint directory until we've bootstrapped + # our model or we've early exited (if epochs < 1) + with tempfile.TemporaryDirectory() as checkpoint_dir: + + # Formulate training arguments + training_args = collect_trainer_arguments( + torch_dtype, + checkpoint_dir, + batch_size, + num_epochs, + cls.RANDOM_SEED, + learning_rate, + max_steps=infer_max_steps(num_epochs, batch_size, training_dataset), + silence_progress_bars=silence_progress_bars, + # NOTE: following can override above arguments in order + **filtered_training_arguments, + **processing_configuration, + ) + + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + launch_config = get_torch_elastic_launch_config() + + training_loss_history = torch.distributed.launcher.api.elastic_launch( + launch_config, launch_training + )( + peft_model, + training_dataset, + training_args, + checkpoint_dir, + base_model, + ) + # NOTE: We are currently only storing the loss information from + # rank 0, i.e main process. training_loss_history is dictionary containing + # rank of the process as key + training_loss_history = training_loss_history[0] + else: + # Use HF Trainer to kick off training on either + # CPU or GPU + training_loss_history = launch_training( + peft_model, + training_dataset, + training_args, + checkpoint_dir, + base_model, + ) + # Wrap up the trained model in a class instance return cls( tokenizer=base_model.tokenizer, @@ -457,7 +568,7 @@ def train( task_type=task_type, tuning_type=tuning_type, output_model_types=output_model_types, - training_metadata=training_loss_tracker, + training_metadata={"loss": training_loss_history}, # TODO: Export other training params to model as well ) @@ -717,83 +828,6 @@ def get_exportable_prompt_vectors( return prompt_dict - @classmethod - def create_dataloaders_from_stream( - cls, - base_model: "caikit_nlp.resources.pretrained_model.base.PretrainedModelBase", - task_type: str, - train_stream: DataStream[GenerationTrainRecord], - verbalizer: str, - batch_size: int, - max_source_length: int, - max_target_length: int, - validation_stream: Union[DataStream[GenerationTrainRecord], None] = None, - collate_fn: Callable = None, - ) -> Tuple[DataLoader]: - """Build PyTorch data loaders around training and (optionally) evaluation DataStreams. - - Args: - base_model: caikit_nlp.resources.pretrained_model.base.PretrainedModelBase - Base resource model used for underlying generation. - task_type: str - Str indicating which task is being accomplished; currently used for determining - tokenization / preprocessing behavior. - train_stream: DataStream[GenerationTrainRecord] - Data to be used for training the prompt vectors of the generation model. - verbalizer: str - Verbalizer template with which we will render text at both train & inference time. - batch_size: int - Batch size to be used for train/eval data loaders. - max_source_length: int - Maximum length to be used for tokenized sequences. - max_target_length: int - Max length of target sequences being predicted. - validation_stream: Union[DataStream[GenerationTrainRecord], None] - Data to be used for validation throughout the train process or None. - collate_fn: Callable - Function to be used for forming batches via lists of dataset inputs. 
- - Returns: - Tuple[torch.utils.data.DataLoader] - Training & evaluation datastreams for the provided data, respectively. If no - validation_stream is provided, the returned loader for validation_stream will - be None. - """ - if collate_fn is None: - # collate_fn -> pads and maps our inputs to PyTorch vectors - collate_fn = cls._get_collate_fn(base_model.tokenizer, task_type) - - # Grab the data loaders for this task. - # NOTE: Currently we do not expose the buffer size and we - # default to loading the whole dataset into memory - train_dataloader = cls._get_data_loaders_from_stream( - base_model, - train_stream, - base_model.tokenizer, - batch_size, - collate_fn, - verbalizer, - max_source_length, - max_target_length, - shuffle=True, - ) - if validation_stream is not None: - val_dataloader = cls._get_data_loaders_from_stream( - base_model, - validation_stream, - base_model.tokenizer, - batch_size, - collate_fn, - verbalizer, - max_source_length, - max_target_length, - shuffle=False, - ) - else: - val_dataloader = None - - return train_dataloader, val_dataloader - @classmethod def create_hf_tuning_config( cls, @@ -919,251 +953,6 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable: # want to set labels ourselves. TODO: centralize collator management. return default_data_collator - @staticmethod - def _get_data_loaders_from_stream( - base_model: PretrainedModelBase, - train_stream: DataStream[GenerationTrainRecord], - tokenizer: AutoTokenizer, - batch_size: int, - collate_fn: Callable, - verbalizer: str, - max_source_length: int, - max_target_length: int, - shuffle: bool, - ) -> DataLoader: - """Get the data loaders for train / evaluation. - Args: - base_model: caikit_nlp.resources.pretrained_model.base.PretrainedModelBase - Base resource model used for underlying generation. - train_stream: DataStream[GenerationTrainRecord] - Data to be used for training the prompt vectors of the generation model. - tokenizer: AutoTokenizer - Model tokenizer to be used in preprocessing, i.e., when we iterate over our data. - batch_size: int - Batch sized to be used when building the DataLoader around the stream. - collate_fn: Callable - Function to be used for forming batches via lists of dataset inputs. - verbalizer: str - Verbalizer template to be used for formatting data. This template may use brackets - to indicate where fields from the data model TrainGenerationRecord must be rendered. - max_source_length: int - Max length of sequences being considered. - max_target_length: int - Max length of target sequences being predicted. - shuffle: bool - Indicates whether or not the stream should reshuffle upon reentry. - - Returns: - torch.utils.data.DataLoader - DataLoader to be used for training / evaluating the stream data. 
- """ - (tokenize_function, _,) = base_model.build_task_tokenize_closure( - tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0 - ) - mapped_stream = train_stream.map(tokenize_function) - # TODO: Deprecate and remove stream wrapper & use trainer - wrapped_stream = SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle) - dataloader = DataLoader( - wrapped_stream, collate_fn=collate_fn, batch_size=batch_size - ) - - return dataloader - - @classmethod - def _execute_train_loop( - cls, - model: PeftModel, - num_epochs: int, - train_dataloader: DataLoader, - device: str, - eval_dataloader: Union[DataLoader, None] = None, - metric: Optional[Callable] = None, - learning_rate: int = 1e-3, - tokenizer: Union[AutoTokenizer, None] = None, - accumulate_steps: int = 1, - silence_progress_bars: bool = True, - torch_dtype: "torch.dtype" = torch.float32, - ) -> None: - """Execute the core training logic for training the prompt vectors on the frozen model. - Note that this is done by reference. - - Args: - model: PeftModel - Underlying model being leveraged for text generation via prompt tuning. - num_epochs: int - Number of epochs to train. - train_dataloader: torch.utils.data.DataLoader - DataLoader to be used for loading training data. - device: str - Device to be used for training the model. - eval_dataloader: Union[DataLoader, None]. - DataLoader to be used for loading eval data or None. - metric: Union[Callable, None] - Function to be used for evaluating data if an eval data loader is provided. - Default: None. - learning_rate: float - Learning rate to be used while tuning prompt vectors. Default: 1e-3. - tokenizer: Union[AutoTokenizer, None] - Tokenizer for default evaluation; only used if no metric is provided and we have - an eval dataloader. - TODO - remove this can likely be removed. - accumulate_steps: int - Number of steps to use for gradient accumulation. Default: 1. - silence_progress_bars: bool - Silences TQDM progress bars. Default: True - torch_dtype: torch.dtype - Dtype to be used for training. Default: torch.float32 - - Returns: - training_metadata: Dict - Metadata computed during training - """ - optimizer = AdamW(params=model.parameters(), lr=learning_rate) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=(len(train_dataloader) * num_epochs), - ) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - - if torch_dtype == torch.float16: - mixed_precision = "fp16" - elif ( - torch.cuda.is_available() - and torch.cuda.is_bf16_supported() - and torch_dtype == torch.bfloat16 - ): - mixed_precision = "bf16" - else: - mixed_precision = "no" - - accelerator = Accelerator( - gradient_accumulation_steps=accumulate_steps, - device_placement=True, - mixed_precision=mixed_precision, - ) - - # Disable cache for training - model.config.use_cache = False - - # Below would send all the data and model to - # configured device and convert them to required dtypes - model, optimizer, new_train_dataloader, lr_scheduler = accelerator.prepare( - model, - optimizer, - train_dataloader, - lr_scheduler, - ) - - training_loss_tracker = [] - - step_count = 1 - - for epoch in range(num_epochs): - step_loss_log = {} - model.train() - total_loss = 0 - tqdm_loader = tqdm(new_train_dataloader, disable=silence_progress_bars) - for batch in tqdm_loader: - - tqdm_loader.set_description("Epoch: {}".format(epoch)) - - # TODO Can this dict comprehension always replace "batch.to(device)" for us? 
- try: - with accelerator.accumulate(model): - outputs = model(**batch) - loss = outputs.loss - # We are converting loss to float explicitely for later use - # keeping it in tensor form can potentially cause memory issues - loss_float = loss.detach().float().item() - total_loss += loss_float - accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - step_loss_log[step_count] = loss_float - step_count += 1 - except ( - torch.cuda.OutOfMemoryError # pylint: disable=catching-non-exception - ): - error( - "", - MemoryError("Not enough memory available for training!"), - ) - - log.info("", {"loss": loss_float, "epoch": epoch}) - - for step, loss_val in step_loss_log.items(): - - # Below is added to be propagated and stored as training_metadata - training_loss_tracker.append( - { - "epoch": epoch, - "step": step, - "value": loss_val, - "timestamp": datetime.isoformat(datetime.now()), - } - ) - - if eval_dataloader is not None: - model.eval() - - if metric is not None: - for _, batch in enumerate( - tqdm(eval_dataloader, disable=silence_progress_bars) - ): - batch.to(device) - with torch.no_grad(): - outputs = model(**batch) - predictions = outputs.logits.argmax(dim=-1) - references = batch["labels"] - metric.add_batch( - predictions=predictions, - references=references, - ) - eval_metric = metric.compute() - - log.info("epoch %s: %s", epoch, eval_metric) - else: - eval_loss = 0 - # TODO Can we get away with not maintaining eval_preds? - eval_preds = [] - for _, batch in enumerate( - tqdm(eval_dataloader, disable=silence_progress_bars) - ): - batch = {k: v.to(device) for k, v in batch.items()} - with torch.no_grad(): - outputs = model(**batch) - loss = outputs.loss - eval_loss += loss.detach().float() - - if tokenizer is not None: - eval_preds.extend( - tokenizer.batch_decode( - torch.argmax(outputs.logits, -1) - .detach() - .cpu() - .numpy(), - skip_special_tokens=True, - ) - ) - - eval_epoch_loss = eval_loss / len(train_dataloader) - eval_ppl = torch.exp(eval_epoch_loss) - train_epoch_loss = total_loss / len(eval_dataloader) - train_ppl = torch.exp(train_epoch_loss) - log.debug( - "epoch %s: %s %s %s %s", - epoch, - train_ppl, - train_epoch_loss, - eval_ppl, - eval_epoch_loss, - ) - return {"loss": training_loss_tracker} - @classmethod def _filter_params_for_prompt_config(cls, prompt_config, params): """Utility function to filter out required parameters for prompt_config diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index e0558d4a..65934c77 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -70,6 +70,9 @@ class TextGeneration(ModuleBase): # Below list is taken from # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments + # FIXME: Temporarily disable duplicate code check here as + # we will remove below code in next iteration when we consolidate HF Trainer + # pylint: disable=duplicate-code allowed_training_args = { "weight_decay", "adam_beta1", diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py index eba74744..6adad903 100644 --- a/caikit_nlp/resources/pretrained_model/base.py +++ b/caikit_nlp/resources/pretrained_model/base.py @@ -279,6 +279,7 @@ def get_trainer( train_dataset: IterableDataset, eval_dataset: Union[IterableDataset, None] = None, optimizers=(None, None), + 
model=None, **kwargs, ): """ @@ -303,6 +304,10 @@ def get_trainer( "optimizers": optimizers, "eval_dataset": eval_dataset, } + # If extra model is provided, we will configure trainer + # with that model + if model: + return LoggingTrainer(model, training_args, **trainer_arguments) return LoggingTrainer(self._model, training_args, **trainer_arguments) diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py index 598b0136..089dcfd9 100644 --- a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py +++ b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py @@ -98,6 +98,7 @@ def get_trainer( train_dataset: IterableDataset, eval_dataset: Union[IterableDataset, None] = None, optimizers=(None, None), + model=None, **kwargs ): """ @@ -128,6 +129,11 @@ def get_trainer( # "generation_max_length": max_target_length, } + # If extra model is provided, we will configure trainer + # with that model + if model: + return LoggingTrainer(model, training_args, **trainer_arguments) + return LoggingTrainer(self._model, training_args, **trainer_arguments) def _get_data_collator(self, **kwargs): diff --git a/caikit_nlp/toolkit/text_generation/training_utils.py b/caikit_nlp/toolkit/text_generation/training_utils.py new file mode 100644 index 00000000..92904367 --- /dev/null +++ b/caikit_nlp/toolkit/text_generation/training_utils.py @@ -0,0 +1,251 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utility script that contains logic for training""" + +# Standard +from typing import List, Optional, Union + +# Third Party +from datasets import Dataset +from datasets import IterableDataset as TransformersIterableDataset +from transformers import AutoTokenizer +import torch + +# First Party +from caikit.core.data_model import DataStream +from caikit.core.exceptions import error_handler +import alog + +# Local +from ...data_model import GenerationTrainRecord +from ...resources.pretrained_model import PretrainedModelBase + +log = alog.use_channel("TXTGEN_TRN_UTLS") +error = error_handler.get(log) + +# Below list is taken from +# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments +ALLOWED_TRAINING_ARGS = { + "weight_decay", + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "max_grad_norm", + "lr_scheduler_type", + "warmup_ratio", + "warmup_steps", + "use_ipex", + "disable_tqdm", + "label_names", + "optim", + "optim_args", + "group_by_length", + "dataloader_pin_memory", + "gradient_checkpointing", + "full_determinism", +} + +# Create trainer arguments +def collect_trainer_arguments( + torch_dtype, + output_dir, + batch_size, + num_epochs, + random_seed, + learning_rate, + max_steps, + silence_progress_bars=True, + **kwargs +): + """Utility function to return processed HF Trainer argument dictionary""" + + # NOTE: Following is not exhaustive list of all parameters + # for all dtypes + if torch_dtype == torch.float16: + dtype_based_params = { + "fp16": True, + } + elif torch_dtype == torch.bfloat16: + dtype_based_params = { + "bf16": True, + } + else: + # default to float32 + dtype_based_params = {} + + return { + # trainer settings + "output_dir": output_dir, + # NOTE: We have disabled evaluation for now + "do_eval": False, + "do_train": True, + "no_cuda": not torch.cuda.is_available(), + # NOTE: This is explicitly set to false since it will + # negatively impact the performance + "full_determinism": False, + # logging configuration + "logging_strategy": "steps", + "logging_steps": 1, # logging at every step + "disable_tqdm": silence_progress_bars, + # computation configurations + "seed": random_seed, + "per_device_train_batch_size": batch_size, + "per_device_eval_batch_size": batch_size, + "num_train_epochs": num_epochs, + "learning_rate": learning_rate, + "weight_decay": 0.01, + "save_total_limit": 3, + "gradient_checkpointing": True, + # huggingface configurations + "push_to_hub": False, + # dataset configurations + "remove_unused_columns": False, + "dataloader_pin_memory": False, + # Required for iterable dataset + "max_steps": max_steps, + # others + # NOTE: Below would automatically adjust the batch size if the provided + # batch size doesn't fit in memory + "auto_find_batch_size": True, + **dtype_based_params, + **kwargs, + } + + +def preprocess_function( + base_model: PretrainedModelBase, + train_stream: DataStream[GenerationTrainRecord], + tokenizer: AutoTokenizer, + max_source_length: int, + max_target_length: int, + shuffle: bool, + use_iterable_dataset: bool, + random_seed: int, + task_ids: Optional[List[int]] = None, +): + """Pre-process each example to get it prepared for training.""" + dataset_type = TransformersIterableDataset if use_iterable_dataset else Dataset + log.debug("Loading dataset class: [%s]", dataset_type.__name__) + fn_kwargs = { + "tokenizer": tokenizer, + "max_source_length": max_source_length, + "max_target_length": max_target_length, + } + if task_ids is not None: + fn_kwargs["task_ids"] = task_ids + + # 
TODO: Add check for empty training stream
+    dataset = dataset_type.from_generator(
+        get_record, gen_kwargs={"train_stream": train_stream}
+    )
+    mapped_dataset = dataset.map(
+        base_model.tokenize_function,
+        fn_kwargs=fn_kwargs,
+        # For now, we hardcode to False, since causal LM chunking is not exposed yet
+        batched=False,
+        # batched=base_model.REQUIRES_TOKEN_UNWRAPPING,
+        # Drop the input / output columns; we need to do this for dimensions to play
+        # happily when operating on batched inputs for causal language modeling.
+        remove_columns=["input", "output"],
+    )
+
+    if shuffle:
+        log.debug("Shuffling the dataset")
+        return mapped_dataset.shuffle(seed=random_seed)
+
+    return mapped_dataset
+
+
+def launch_training(
+    base_model,
+    training_dataset,
+    training_args,
+    checkpoint_dir,
+    caikit_resource=None,
+    tokenizer=None,
+) -> List:
+    """Utility function to wrap the trainer and execute training.
+
+    Returns the trainer's log history (one loss record per step)."""
+
+    # If we have a caikit resource, grab the trainer through it
+    if caikit_resource is not None:
+        trainer = caikit_resource.get_trainer(
+            train_dataset=training_dataset, model=base_model, **training_args
+        )
+
+    else:
+        # If no caikit resource is provided, fetch the trainer from the base model
+        if hasattr(base_model, "get_trainer"):
+            trainer = base_model.get_trainer(
+                train_dataset=training_dataset, **training_args
+            )
+        else:
+            error("", "could not resolve trainer. Check base model type!")
+
+    # Start training via Trainer.train function
+    result = trainer.train()
+
+    # Log the output of the training. This will include stats about training
+    log.info("", "Training completed. Summary: {}".format(result))
+
+    # Save the trained model to the checkpoint directory. This is done because the
+    # model may remain sharded across devices after training, in which case we would
+    # have to rely on the trainer's `prediction_step`, which does not expose a
+    # `generate`-like API and would therefore be incompatible with the `run` function.
+    trainer.save_state()
+    trainer.save_model(checkpoint_dir)
+
+    # Save the tokenizer explicitly
+    if hasattr(base_model, "tokenizer"):
+        base_model.tokenizer.save_pretrained(checkpoint_dir)
+    elif tokenizer:
+        tokenizer.save_pretrained(checkpoint_dir)
+    else:
+        log.warning(
+            "",
+            "Cannot save tokenizer as not available to train function.",
+        )
+
+    # Return the log history; when training is started in distributed fashion,
+    # the elastic launcher automatically keys the result by rank.
+    return trainer.state.log_history
+
+
+def infer_max_steps(
+    num_epochs: int,
+    batch_size: int,
+    training_dataset: Union[Dataset, TransformersIterableDataset],
+):
+    # Calculate the number of samples that we have
+    if isinstance(training_dataset, Dataset):
+        data_len = len(training_dataset)
+    else:
+        data_len = 0
+        for _ in training_dataset:
+            data_len += 1
+    # Figure out how many batches we'll have per epoch
+    num_batches = data_len // batch_size
+    # Assume drop_last=False; in general, this doesn't really matter.
+    # We mostly do this to avoid strange behavior when the dataset
+    # size is smaller than the batch size.
+    if num_batches * batch_size != data_len:
+        num_batches += 1
+    num_steps = num_batches * num_epochs
+    log.debug("Number of inferred steps: [%s]", num_steps)
+    return num_steps
+
+
+def get_record(train_stream):
+    for data in train_stream:
+        yield {"input": data.input, "output": data.output}
diff --git a/caikit_nlp/toolkit/torch_run.py b/caikit_nlp/toolkit/torch_run.py
index 3a8879c8..9e78830e 100644
--- a/caikit_nlp/toolkit/torch_run.py
+++ b/caikit_nlp/toolkit/torch_run.py
@@ -21,6 +21,7 @@
 # Standard
 import os
+import uuid
 
 # Third Party
 from torch import cuda
@@ -64,16 +65,17 @@ def determine_local_world_size():
 
 
 def get_torch_elastic_launch_config(
-    master_addr: str,
-    master_port: str,
     start_method: str = "spawn",
-    max_restarts=3,
+    max_restarts=1,
 ) -> LaunchConfig:
     # Constants; we assume everything executes on the same node
     min_nodes = 1
     max_nodes = 1
     rdzv_configs = {"rank": 0}
+    run_id = str(uuid.uuid4())
+
+    log.debug("", f"run_id: {run_id}")
 
     nproc_per_node = determine_local_world_size()
 
@@ -96,9 +98,14 @@ def get_torch_elastic_launch_config(
         max_nodes=max_nodes,
         nproc_per_node=nproc_per_node,
         start_method=start_method,
+        # rdzv_backend="c10d",
+        # rdzv_endpoint="localhost:0",
         rdzv_backend="static",
-        rdzv_endpoint=f"{master_addr}:{master_port}",
+        rdzv_endpoint="localhost:29500",
+        run_id=run_id,
         rdzv_configs=rdzv_configs,
         tee=Std.ALL,
-        max_restarts=max_restarts,
+        # TODO: Make this configurable
+        log_dir="./torch_log/",
+        # max_restarts=max_restarts,
     )
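Usage sketch (not part of the patch): the snippet below shows how the helpers added in training_utils.py compose on the single-process path, mirroring the updated PeftPromptTuning.train above. The model id, the toy data stream, the bootstrap call, and passing the raw HF model where train() passes the wrapped PEFT model are illustrative assumptions, not part of this change.

import tempfile

import torch

from caikit.core.data_model import DataStream
from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.resources.pretrained_model import HFAutoCausalLM
from caikit_nlp.toolkit.text_generation.training_utils import (
    collect_trainer_arguments,
    infer_max_steps,
    launch_training,
    preprocess_function,
)

# Assumed: bootstrap a causal LM resource and build a tiny training stream
base_model = HFAutoCausalLM.bootstrap("bigscience/bloom-560m")
train_stream = DataStream.from_iterable(
    [GenerationTrainRecord(input="hello", output="world")]
)

# Tokenize / map the stream into an in-memory HF Dataset
training_dataset = preprocess_function(
    base_model=base_model,
    train_stream=train_stream,
    tokenizer=base_model.tokenizer,
    max_source_length=256,
    max_target_length=128,
    shuffle=True,
    use_iterable_dataset=False,
    random_seed=73,
    task_ids=0,
)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Build the HF TrainingArguments dictionary used by the trainer
    training_args = collect_trainer_arguments(
        torch.float32,
        checkpoint_dir,
        batch_size=8,
        num_epochs=1,
        random_seed=73,
        learning_rate=3e-2,
        max_steps=infer_max_steps(1, 8, training_dataset),
    )
    # Single-process path; launch_training returns the per-step loss history.
    # train() passes the PEFT-wrapped model here; the raw HF model is used only
    # to keep this sketch short.
    loss_history = launch_training(
        base_model.model, training_dataset, training_args, checkpoint_dir, base_model
    )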