Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
83b3d69
Fix missing PEFT availability check for peft_config in exp trainers
albertvillanova Apr 27, 2026
32e8764
Fix missing PEFT availability check in BCO
albertvillanova Apr 27, 2026
34a141b
Fix missing PEFT availability check in CPO
albertvillanova Apr 27, 2026
8eb11c5
Fix missing PEFT availability check in KTO
albertvillanova Apr 27, 2026
0d6660d
Fix missing PEFT availability check in ORPO
albertvillanova Apr 27, 2026
3a62cf9
Fix missing PEFT availability check in PPO
albertvillanova Apr 27, 2026
3f3be7b
Add TypeError check to Distillation
albertvillanova Apr 27, 2026
f571155
Add TypeError check to PPO
albertvillanova Apr 27, 2026
d2664dc
Add TypeError check to TPO
albertvillanova Apr 27, 2026
e7242f3
Add TypeError check to BCO
albertvillanova Apr 27, 2026
1392892
Add TypeError check to CPO
albertvillanova Apr 27, 2026
811f87f
Add TypeError check to KTO
albertvillanova Apr 27, 2026
3d487c8
Add TypeError check to ORPO
albertvillanova Apr 27, 2026
ef4149d
Add PEFT validation to PRM
albertvillanova Apr 27, 2026
7179659
Add PEFT validation to OnlineDPO
albertvillanova Apr 27, 2026
d0a7562
Remove redundant duplicate condition check in Distillation
albertvillanova Apr 27, 2026
d236c55
Merge remote-tracking branch 'upstream/main' into fix-peft-validation…
albertvillanova Apr 27, 2026
9f2cd22
Add PEFT validation to SDFT
albertvillanova Apr 28, 2026
81cb6f2
Add PEFT validation to BaseSelfDistillation
albertvillanova Apr 28, 2026
81153c8
Add PEFT validation to SSD
albertvillanova Apr 28, 2026
70b898f
Merge remote-tracking branch 'upstream/main' into fix-peft-validation…
albertvillanova Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions trl/experimental/bco/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
import wandb
Expand Down Expand Up @@ -469,15 +469,21 @@ def __init__(
if isinstance(ref_model, str):
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **model_init_kwargs)

# PEFT
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if isinstance(model, PeftModel):
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
Expand Down
20 changes: 13 additions & 7 deletions trl/experimental/cpo/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
Expand Down Expand Up @@ -169,15 +169,21 @@ def __init__(
if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

# PEFT
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if isinstance(model, PeftModel):
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/distillation/distillation_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,16 @@ def __init__(

# ── PEFT ──
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
Comment thread
cursor[bot] marked this conversation as resolved.
model = get_peft_model(model, peft_config)

# ── Data collator ──
Expand Down
31 changes: 19 additions & 12 deletions trl/experimental/kto/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


if TYPE_CHECKING:
Expand Down Expand Up @@ -282,20 +282,27 @@ def __init__(
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# PEFT
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
)
if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
"and unload the existing adapter, save the resulting base model, and then pass that base model along "
"with the new `peft_config` to the trainer."
)
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if isinstance(model, PeftModel):
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
"and unload the existing adapter, save the resulting base model, and then pass that base model along "
"with the new `peft_config` to the trainer."
)
if is_peft_available() and isinstance(model, PeftModel) and ref_model is None:
# If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy
# of the "default" adapter, so that we can use it as the reference model during KTO training.
Expand Down
12 changes: 12 additions & 0 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ def __init__(
self.is_encoder_decoder = model.config.is_encoder_decoder
self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()

# PEFT
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
Comment thread
cursor[bot] marked this conversation as resolved.
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)

Expand Down
20 changes: 13 additions & 7 deletions trl/experimental/orpo/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
Expand Down Expand Up @@ -178,15 +178,21 @@ def __init__(
if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

# PEFT
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if isinstance(model, PeftModel):
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
Expand Down
18 changes: 12 additions & 6 deletions trl/experimental/ppo/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,12 +416,18 @@ def __init__(
"[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
)

# peft support
if not is_peft_available() and peft_config is not None:
raise ImportError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# PEFT
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if isinstance(self.policy_model, PeftModel):
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
Expand Down
14 changes: 13 additions & 1 deletion trl/experimental/prm/prm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


if is_peft_available():
from peft import PeftModel
from peft import PeftConfig, PeftModel

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -172,6 +172,18 @@ def __init__(
if train_dataset is None:
raise ValueError("`train_dataset` is required")

# PEFT
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)

Expand Down
12 changes: 12 additions & 0 deletions trl/experimental/tpo/tpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ def __init__(
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# PEFT
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"You passed `peft_config` but the `peft` library is not installed. "
"Install it with `pip install trl[peft]`."
)
if not isinstance(peft_config, PeftConfig):
raise TypeError(
f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), "
f"got {type(peft_config).__name__}."
)
if is_peft_available() and is_peft_model(model) and peft_config is not None:
raise ValueError(
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
Expand Down
Loading