41 changes: 41 additions & 0 deletions src/lightning/pytorch/cli.py
@@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
else:
self.config = parser.parse_args(args)

def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
"""Adapt checkpoint hyperparameters before instantiating the model class.

This method allows for customization of hyperparameters loaded from a checkpoint when
using a different model class than the one used for training. For example, when loading
a checkpoint from a TrainingModule to use with an InferenceModule that has different
``__init__`` parameters, you can remove or modify incompatible hyperparameters.

Args:
subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
This allows you to apply different hyperparameter adaptations depending on the context.
checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

Returns:
Dictionary of adapted hyperparameters to be used for model instantiation.

Example::

class MyCLI(LightningCLI):
def adapt_checkpoint_hparams(
self, subcommand: str, checkpoint_hparams: dict[str, Any]
) -> dict[str, Any]:
# Only remove training-specific hyperparameters for non-fit subcommands
if subcommand != "fit":
checkpoint_hparams.pop("lr", None)
checkpoint_hparams.pop("weight_decay", None)
return checkpoint_hparams

Note:
If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
hyperparameters, you may need to modify it as well to point to your new module class.

"""
return checkpoint_hparams
Contributor:
Not really related to this new feature, but there is also my comment in #21116 (comment), which nobody responded to. Maybe by default fit should not use the hparams from the checkpoint?

Also, #21255 (comment) could be related.

I am not really sure what to do here.

Contributor:
@arrdel any comment on this?

Contributor:
Actually, it seems #21455 would fix this comment, I think.
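
Worth noting: the new hook already lets a user opt out per subcommand without changing the default behavior, since returning an empty dict makes `_parse_ckpt_path` skip the checkpoint hyperparameters entirely (see the early return added below). A minimal sketch, with a hypothetical CLI subclass name:

from lightning.pytorch.cli import LightningCLI


class NoCkptHparamsOnFitCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
        # An empty dict triggers the early return added to _parse_ckpt_path,
        # so fit keeps the defaults and any values given on the command line.
        if subcommand == "fit":
            return {}
        return checkpoint_hparams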


def _parse_ckpt_path(self) -> None:
"""If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
if not self.config.get("subcommand"):
@@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return

# Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
if not hparams:
return

if "_class_path" in hparams:
hparams = {
"class_path": hparams.pop("_class_path"),
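The Note in the docstring above mentions that, in subclass module mode, the `_class_path` entry stored with the checkpoint hyperparameters may also need to be remapped. A minimal sketch of such an override, assuming a hypothetical `my_project.InferenceModule` that accepts only a subset of the training module's `__init__` arguments:

from typing import Any

from lightning.pytorch.cli import LightningCLI


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        if subcommand == "predict":
            # Subclass module mode stores the original class under "_class_path";
            # repoint it at the new module class (hypothetical import path).
            if "_class_path" in checkpoint_hparams:
                checkpoint_hparams["_class_path"] = "my_project.InferenceModule"
            # Drop training-only hyperparameters the inference module does not accept.
            for key in ("lr", "weight_decay"):
                checkpoint_hparams.pop(key, None)
        return checkpoint_hparams
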
71 changes: 71 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -495,6 +495,21 @@ def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
self.layer = torch.nn.Linear(32, out_dim)


class AdaptHparamsModel(BoringModel):
"""Simple model for testing adapt_checkpoint_hparams hook without dynamic neural network layers.

This model stores hyperparameters as attributes without creating layers that would cause size mismatches when
hyperparameters are changed between fit and predict phases.

"""

def __init__(self, out_dim: int = 8, hidden_dim: int = 16) -> None:
super().__init__()
self.save_hyperparameters()
self.out_dim = out_dim
self.hidden_dim = hidden_dim


def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
class CkptPathCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
@@ -562,6 +577,62 @@ def add_arguments_to_parser(self, parser):
assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
"""Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""

class AdaptHparamsCLI(LightningCLI):
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
"""Remove out_dim and hidden_dim for non-fit subcommands."""
if subcommand != "fit":
checkpoint_hparams.pop("out_dim", None)
checkpoint_hparams.pop("hidden_dim", None)
Contributor (on lines +587 to +588):
From a testing perspective there is no difference between out_dim and hidden_dim. It might be better if one is popped and the other not, so that both cases are tested?

return checkpoint_hparams

# First, create a checkpoint by running fit
cli_args = ["fit", "--model.out_dim=3", "--model.hidden_dim=6", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(AdaptHparamsModel)

assert cli.config.fit.model.out_dim == 3
assert cli.config.fit.model.hidden_dim == 6

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses adapted hparams (without out_dim and hidden_dim)
cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(AdaptHparamsModel)

# Since we removed out_dim and hidden_dim for predict, the CLI values should be used
assert cli.config.predict.model.out_dim == 5
assert cli.config.predict.model.hidden_dim == 10
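
Regarding the review comment above: a sketch of a variant where only one hyperparameter is popped, so that both cases are covered in one test, the adapted key (CLI value should win) and the untouched key (value should come from the checkpoint). It assumes the same imports and the AdaptHparamsModel defined earlier in this file, and that non-popped checkpoint hyperparameters end up in the parsed config:

def test_adapt_checkpoint_hparams_hook_pop_single_key(cleandir):
    """Pop only out_dim; hidden_dim should still be taken from the checkpoint."""

    class AdaptSingleKeyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
            if subcommand != "fit":
                checkpoint_hparams.pop("out_dim", None)  # adapted: CLI value should win
            return checkpoint_hparams  # hidden_dim is kept from the checkpoint

    # Create a checkpoint with non-default hyperparameters
    cli_args = ["fit", "--model.out_dim=3", "--model.hidden_dim=6", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptSingleKeyCLI(AdaptHparamsModel)

    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

    # out_dim is popped, so the CLI value is used; hidden_dim is not, so the checkpoint value is used
    cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptSingleKeyCLI(AdaptHparamsModel)

    assert cli.config.predict.model.out_dim == 5
    assert cli.config.predict.model.hidden_dim == 6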


def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
"""Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""

class AdaptHparamsEmptyCLI(LightningCLI):
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
"""Disable checkpoint hyperparameter loading."""
return {}

# First, create a checkpoint
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses default values when hook returns empty dict
cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

# Model should use default values (out_dim=8, hidden_dim=16)
assert cli.config_init.predict.model.out_dim == 8
assert cli.config_init.predict.model.hidden_dim == 16


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):