28 changes: 21 additions & 7 deletions src/lightning/pytorch/cli.py
@@ -405,7 +405,7 @@ def __init__(
        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser, args)
-       self._parse_ckpt_path()
+       self._parse_ckpt_path(self.parser, args)

        self.subcommand = self.config["subcommand"] if run else None

@@ -560,8 +560,18 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
        else:
            self.config = parser.parse_args(args)

-   def _parse_ckpt_path(self) -> None:
-       """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
+   def _parse_ckpt_path(self, parser: LightningArgumentParser, args: ArgsType) -> None:
+       """Parses the checkpoint path, loads hyperparameters, and injects them as new defaults.
+
+       If `ckpt_path` is provided, this method:
+       1. Loads hyperparameters from the checkpoint file.
+       2. Sets them as new default values for the specific subcommand parser.
+       3. Re-runs argument parsing.
+
+       This ensures the correct priority order:
+       __init__ defaults < ckpt hparams < cfg file < CLI args
+
+       """
        if not self.config.get("subcommand"):
            return
        ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
@@ -576,12 +586,16 @@ def _parse_ckpt_path(self) -> None:
"class_path": hparams.pop("_class_path"),
"dict_kwargs": hparams,
}
hparams = {self.config.subcommand: {"model": hparams}}
hparams = {"model": hparams}
try:
self.config = self.parser.parse_object(hparams, self.config)
except SystemExit:
if parser._subcommands_action is None:
return
Comment on lines +591 to +592
@mauvilsa (Contributor), Jan 15, 2026:
How come you added this? The `if` condition in line 575 is supposed to already return when no subcommands are defined. Maybe the condition is not robust enough, since it would not work when a model has a subcommand parameter. Still, if there are no subcommands, why execute the lines from 577 up to here? Maybe better instead to change line 575 to `if parser._subcommands_action is None or not self.config.get("subcommand"):`.

+               subparser = parser._subcommands_action._name_parser_map[self.config.subcommand]
+               subparser.set_defaults(hparams)
+           except KeyError as ex:
-               sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
-               raise
+               parser.error(str(ex), ex)
+           self.parse_arguments(parser, args)
Contributor comment:

Like this, the entire parsing is done twice. The speed is not a problem, since training/predicting is normally significantly slower than the parsing. However, there might be some edge case where the first parsing fails, even though this second one would succeed. But I am fine with leaving it like this. If there are problematic edge cases, then we can fix them when they happen.
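
The change relies on a general jsonargparse behavior: values injected with set_defaults() act like parser defaults, so anything parsed later from a config file or from the command line overrides them, which is what produces the documented priority order. A minimal sketch of that mechanism, assuming only that jsonargparse is installed (the argument name is invented for illustration):

from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.01)  # stands in for an __init__ default
parser.set_defaults({"learning_rate": 1e-4})  # stands in for the ckpt hparams

cfg = parser.parse_args([])  # nothing on the CLI -> the injected default wins
assert cfg.learning_rate == 1e-4

cfg = parser.parse_args(["--learning_rate=0.5"])  # an explicit CLI argument wins over the injected default
assert cfg.learning_rate == 0.5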


    def _dump_config(self) -> None:
        if hasattr(self, "config_dump"):
148 changes: 146 additions & 2 deletions tests/tests_pytorch/test_cli.py
@@ -556,9 +556,8 @@ def add_arguments_to_parser(self, parser):
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)

-   assert isinstance(cli.model, BoringCkptPathSubclass)
+   assert not isinstance(cli.model, BoringCkptPathSubclass)
    assert cli.model.hidden_dim == 8
-   assert cli.model.extra is True
Comment on lines -559 to -561
Contributor comment:

Why did this change? Old behaviors should not change unless there was a bug. Was there a bug? If not, revert.

Author reply:

Yes, the old behavior was effectively a bug. It violated the standard priority Checkpoint < CLI args.

Previously, explicit flags like --model=BoringCkptPathModel were ignored in favor of the checkpoint class. The updated test confirms that the CLI argument now correctly takes precedence.

This also complements PR #21408: while that PR handles parameter adaptation, this ensures the class itself can be overridden. Together, they enable a clean Training -> Inference workflow with different classes.

Contributor reply:

One drawback of setting hparams as defaults is that the typical left-to-right parsing priority of command line arguments is not respected. For example, a command like --model=BoringCkptPathModel --ckpt_path=... can be confusing, because --ckpt_path appears after --model, and arguments to the right usually take precedence. Since the checkpoint specifies a particular model, it would be logical for it to override the model argument. It might make sense to require that if --ckpt_path is provided, it should be the first argument after the subcommand. This would maintain the convention that arguments on the right have higher priority. However, enforcing --ckpt_path as the first argument could be difficult to implement cleanly.

Ignoring the order of command line arguments, the reason for including the --model=BoringCkptPathModel argument in this test is that the model is required. Without this argument, parsing fails due to the missing model. In fact, this is a case where the initial parsing fails, but the second parsing (after applying checkpoint hparams defaults) succeeds. Ideally, if a checkpoint is provided, specifying a model argument should not be mandatory, since the checkpoint already contains one. The model argument could be optional and used to override the checkpoint's model if given, but not required. However, implementing this might be tricky. The parser must still mark the model as required; otherwise, the output of --help would be misleading. One possible solution is to temporarily patch the subparsers during parsing to make the model argument optional, and then, if the model is not provided, call parser.error.
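
A rough sketch of that last suggestion (hypothetical, not part of this PR): the helper name is invented, and it assumes an argparse-style action whose `required` flag can be toggled.

# Keep `model` marked required so --help stays accurate, relax it only while
# parsing, and fall back to parser.error() when neither the CLI nor a
# checkpoint provides a model.
def parse_with_optional_model(parser, args, has_ckpt_hparams):
    model_action = next(a for a in parser._actions if a.dest == "model")
    originally_required = model_action.required
    model_action.required = False  # temporarily optional during this parse
    try:
        cfg = parser.parse_args(args)
    finally:
        model_action.required = originally_required  # restore for --help and later parses
    if cfg.get("model") is None and not has_ckpt_hparams:
        parser.error("'model' is required when no --ckpt_path is given")
    return cfg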

    assert cli.model.layer.out_features == 4


@@ -588,6 +587,151 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
    assert isinstance(cli.model.submodule2, BoringModel)


class DemoModel(BoringModel):
    def __init__(
        self,
        num_classes: int = 10,
        learning_rate: float = 0.01,
        dropout: float = 0.1,
        backbone_hidden_dim: int = 128,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.dropout = dropout
        self.backbone_hidden_dim = backbone_hidden_dim


def test_lightning_cli_args_override_checkpoint_hparams(cleandir):
    """
    Check priority: ckpt hparams < CLI args

    Scenario:
    1. Save a checkpoint with specific `dropout` and `backbone_hidden_dim`
    2. Load the checkpoint, but explicitly override `learning_rate` and `backbone_hidden_dim`
    3. Verify that `num_classes` and `dropout` are restored from the ckpt,
       but `learning_rate` and `backbone_hidden_dim` are updated from the CLI args.
    """

    # --- Phase 1: Create a base checkpoint ---
    orig_hidden_dim = 256
    orig_dropout = 0.5

    save_args = [
        "fit",
        f"--model.dropout={orig_dropout}",
        f"--model.backbone_hidden_dim={orig_hidden_dim}",
        "--trainer.devices=1",
        "--trainer.max_steps=1",
        "--trainer.limit_train_batches=1",
        "--trainer.limit_val_batches=1",
        "--trainer.default_root_dir=./",
    ]

    with mock.patch("sys.argv", ["any.py"] + save_args):
        cli = LightningCLI(DemoModel)

    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))

    # --- Phase 2: Predict with CLI overrides ---
    new_lr = 0.123
    new_hidden_dim = 512
    override_args = [
        "predict",
        "--trainer.devices=1",
        f"--model.learning_rate={new_lr}",
        f"--model.backbone_hidden_dim={new_hidden_dim}",
        f"--ckpt_path={checkpoint_path}",
    ]

    with mock.patch("sys.argv", ["any.py"] + override_args):
        new_cli = LightningCLI(DemoModel)

    # --- Phase 3: Assertions ---
    assert new_cli.model.learning_rate == new_lr, (
        f"CLI override failed! Expected LR {new_lr}, got {new_cli.model.learning_rate}"
    )

    assert new_cli.model.dropout == orig_dropout, (
        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
    )
    assert new_cli.model.backbone_hidden_dim == new_hidden_dim, (
        f"CLI override failed! Expected dim {new_hidden_dim}, got {new_cli.model.backbone_hidden_dim}"
    )


def test_lightning_cli_config_priority_over_checkpoint_hparams(cleandir):
    """
    Test the full priority hierarchy:
    ckpt hparams < config file < CLI args

    Scenario:
    1. Save a checkpoint with specific `num_classes`, `learning_rate` and `dropout`
    2. Load the checkpoint, but explicitly override:
       num_classes: via config file and CLI arg
       learning_rate: via config file
    3. Verify that:
       num_classes comes from the CLI arg
       learning_rate comes from the config file
       dropout is restored from the checkpoint
    """
    orig_classes = 60_000
    orig_lr = 1e-4
    orig_dropout = 0.01
    save_args = [
        "fit",
        f"--model.num_classes={orig_classes}",
        f"--model.learning_rate={orig_lr}",
        f"--model.dropout={orig_dropout}",
        "--trainer.devices=1",
        "--trainer.max_steps=1",
        "--trainer.limit_train_batches=1",
        "--trainer.limit_val_batches=1",
        "--trainer.default_root_dir=./",
    ]

    with mock.patch("sys.argv", ["any.py"] + save_args):
        cli = LightningCLI(DemoModel)

    cfg_lr = 2e-5
    config = f"""
model:
  num_classes: 1000
  learning_rate: {cfg_lr}
"""

    config_path = Path("config.yaml")
    config_path.write_text(config)

    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))

    cli_classes = 1024
    cli_args = [
        "predict",
        f"--config={config_path}",
        f"--model.num_classes={cli_classes}",
        "--trainer.devices=1",
        f"--ckpt_path={checkpoint_path}",
    ]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        new_cli = LightningCLI(DemoModel)

    assert new_cli.model.num_classes == cli_classes, (
        f"CLI priority failed! Expected num_classes {cli_classes}, got {new_cli.model.num_classes}"
    )
    assert new_cli.model.learning_rate == cfg_lr, (
        f"Config override failed! Expected LR {cfg_lr}, got {new_cli.model.learning_rate}"
    )
    assert new_cli.model.dropout == orig_dropout, (
        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
    )
    assert new_cli.model.backbone_hidden_dim == 128, (
        f"Checkpoint restoration failed! Expected dim {128}, got {new_cli.model.backbone_hidden_dim}"
    )


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
def test_lightning_cli_torch_modules(cleandir):
    class TestModule(BoringModel):