diff --git a/trl/experimental/bco/bco_trainer.py b/trl/experimental/bco/bco_trainer.py index 9cdf47c65ee..a5a10725ae0 100644 --- a/trl/experimental/bco/bco_trainer.py +++ b/trl/experimental/bco/bco_trainer.py @@ -65,7 +65,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): import wandb @@ -469,15 +469,21 @@ def __init__( if isinstance(ref_model, str): ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/cpo/cpo_trainer.py b/trl/experimental/cpo/cpo_trainer.py index f3e0f39920e..432c0bfd070 100644 --- a/trl/experimental/cpo/cpo_trainer.py +++ b/trl/experimental/cpo/cpo_trainer.py @@ -61,7 +61,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): @@ -169,15 +169,21 @@ def __init__( if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 77dbcc6c08e..f80dd19358a 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -430,6 +430,16 @@ def __init__( # ── PEFT ── if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) model = get_peft_model(model, peft_config) # ── Data collator ── diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py index d47fa02637b..2bcc15210c2 100644 --- a/trl/experimental/kto/kto_trainer.py +++ b/trl/experimental/kto/kto_trainer.py @@ -67,7 +67,7 @@ from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if TYPE_CHECKING: @@ -282,20 +282,27 @@ def __init__( if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) if is_peft_available() and isinstance(model, PeftModel) and ref_model is None: # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy # of the "default" adapter, so that we can use it as the reference model during KTO training. diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index 4c7adef3b6d..5cb73fbcc83 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -282,6 +282,18 @@ def __init__( self.is_encoder_decoder = model.config.is_encoder_decoder self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py index 59d7636efb7..22e6a81fbf6 100644 --- a/trl/experimental/orpo/orpo_trainer.py +++ b/trl/experimental/orpo/orpo_trainer.py @@ -62,7 +62,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): @@ -178,15 +178,21 @@ def __init__( if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/ppo/ppo_trainer.py b/trl/experimental/ppo/ppo_trainer.py index 6366f987ec4..0d3e32f0a46 100644 --- a/trl/experimental/ppo/ppo_trainer.py +++ b/trl/experimental/ppo/ppo_trainer.py @@ -416,12 +416,18 @@ def __init__( "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." ) - # peft support - if not is_peft_available() and peft_config is not None: - raise ImportError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(self.policy_model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/prm/prm_trainer.py b/trl/experimental/prm/prm_trainer.py index 7b26b69bd82..c6bf17ad453 100644 --- a/trl/experimental/prm/prm_trainer.py +++ b/trl/experimental/prm/prm_trainer.py @@ -44,7 +44,7 @@ if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel logger = logging.get_logger(__name__) @@ -172,6 +172,18 @@ def __init__( if train_dataset is None: raise ValueError("`train_dataset` is required") + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5bf6095c2a0..a4be9b395fe 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -195,6 +195,18 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bd9abb95164..5484a2efb66 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -104,6 +104,18 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) diff --git a/trl/experimental/ssd/ssd_trainer.py b/trl/experimental/ssd/ssd_trainer.py index ea378753653..91a0275047a 100644 --- a/trl/experimental/ssd/ssd_trainer.py +++ b/trl/experimental/ssd/ssd_trainer.py @@ -135,6 +135,18 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to SSDTrainer. Pass either a base " diff --git a/trl/experimental/tpo/tpo_trainer.py b/trl/experimental/tpo/tpo_trainer.py index 544c0674cca..7195596e15c 100644 --- a/trl/experimental/tpo/tpo_trainer.py +++ b/trl/experimental/tpo/tpo_trainer.py @@ -350,6 +350,18 @@ def __init__( if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "