diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 2bcb1d8f4b1fd..10f9ec8b290ea 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -405,7 +405,7 @@ def __init__(
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
         self.setup_parser(run, main_kwargs, subparser_kwargs)
         self.parse_arguments(self.parser, args)
-        self._parse_ckpt_path()
+        self._parse_ckpt_path(self.parser, args)
 
         self.subcommand = self.config["subcommand"] if run else None
 
@@ -560,8 +560,18 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
-    def _parse_ckpt_path(self) -> None:
-        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
+    def _parse_ckpt_path(self, parser: LightningArgumentParser, args: ArgsType) -> None:
+        """Parse the checkpoint path, load its hyperparameters, and inject them as new parser defaults.
+
+        If ``ckpt_path`` is provided, this method:
+        1. Loads hyperparameters from the checkpoint file.
+        2. Sets them as new default values on the parser of the active subcommand.
+        3. Re-runs argument parsing.
+
+        This ensures the correct priority order:
+        __init__ defaults < ckpt hparams < cfg file < CLI args
+
+        """
         if not self.config.get("subcommand"):
             return
         ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
@@ -576,12 +586,16 @@ def _parse_ckpt_path(self) -> None:
                 "class_path": hparams.pop("_class_path"),
                 "dict_kwargs": hparams,
             }
-        hparams = {self.config.subcommand: {"model": hparams}}
+        hparams = {"model": hparams}
         try:
-            self.config = self.parser.parse_object(hparams, self.config)
-        except SystemExit:
+            if parser._subcommands_action is None:
+                return
+            subparser = parser._subcommands_action._name_parser_map[self.config.subcommand]
+            subparser.set_defaults(hparams)
+        except KeyError as ex:
             sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
-            raise
+            parser.error(str(ex), ex)
+        self.parse_arguments(parser, args)
 
     def _dump_config(self) -> None:
         if hasattr(self, "config_dump"):
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 094a21a5b932c..a332bccfc8713 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -556,9 +556,8 @@ def add_arguments_to_parser(self, parser):
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)
 
-    assert isinstance(cli.model, BoringCkptPathSubclass)
+    assert not isinstance(cli.model, BoringCkptPathSubclass)
     assert cli.model.hidden_dim == 8
-    assert cli.model.extra is True
     assert cli.model.layer.out_features == 4
 
 
@@ -588,6 +587,151 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
     assert isinstance(cli.model.submodule2, BoringModel)
 
 
+class DemoModel(BoringModel):
+    def __init__(
+        self,
+        num_classes: int = 10,
+        learning_rate: float = 0.01,
+        dropout: float = 0.1,
+        backbone_hidden_dim: int = 128,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.num_classes = num_classes
+        self.learning_rate = learning_rate
+        self.dropout = dropout
+        self.backbone_hidden_dim = backbone_hidden_dim
+
+
+def test_lightning_cli_args_override_checkpoint_hparams(cleandir):
+    """
+    Check priority: ckpt hparams < CLI args.
+
+    Scenario:
+    1. Save a checkpoint with specific `dropout` and `backbone_hidden_dim`.
+    2. Load the checkpoint, but explicitly override `learning_rate` and `backbone_hidden_dim`.
+    3. Verify that `num_classes` and `dropout` are restored from the checkpoint,
+       but `learning_rate` and `backbone_hidden_dim` are updated from the CLI args.
+    """
+
+    # --- Phase 1: Create a base checkpoint ---
+    orig_hidden_dim = 256
+    orig_dropout = 0.5
+
+    save_args = [
+        "fit",
+        f"--model.dropout={orig_dropout}",
+        f"--model.backbone_hidden_dim={orig_hidden_dim}",
+        "--trainer.devices=1",
+        "--trainer.max_steps=1",
+        "--trainer.limit_train_batches=1",
+        "--trainer.limit_val_batches=1",
+        "--trainer.default_root_dir=./",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + save_args):
+        cli = LightningCLI(DemoModel)
+
+    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))
+
+    # --- Phase 2: Predict with CLI overrides ---
+    new_lr = 0.123
+    new_hidden_dim = 512
+    override_args = [
+        "predict",
+        "--trainer.devices=1",
+        f"--model.learning_rate={new_lr}",
+        f"--model.backbone_hidden_dim={new_hidden_dim}",
+        f"--ckpt_path={checkpoint_path}",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + override_args):
+        new_cli = LightningCLI(DemoModel)
+
+    # --- Phase 3: Assertions ---
+    assert new_cli.model.learning_rate == new_lr, (
+        f"CLI override failed! Expected LR {new_lr}, got {new_cli.model.learning_rate}"
+    )
+
+    assert new_cli.model.dropout == orig_dropout, (
+        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
+    )
+    assert new_cli.model.backbone_hidden_dim == new_hidden_dim, (
+        f"CLI override failed! Expected dim {new_hidden_dim}, got {new_cli.model.backbone_hidden_dim}"
+    )
+
+
+def test_lightning_cli_config_priority_over_checkpoint_hparams(cleandir):
+    """
+    Test the full priority hierarchy:
+    ckpt hparams < config file < CLI args.
+
+    Scenario:
+    1. Save a checkpoint with specific `num_classes`, `learning_rate` and `dropout`.
+    2. Load the checkpoint, but explicitly override:
+       num_classes: via the config file and the CLI
+       learning_rate: via the config file
+    3. Verify that:
+       num_classes comes from the CLI args
+       learning_rate comes from the config file
+       dropout is restored from the checkpoint
+
+    """
+    orig_classes = 60_000
+    orig_lr = 1e-4
+    orig_dropout = 0.01
+    save_args = [
+        "fit",
+        f"--model.num_classes={orig_classes}",
+        f"--model.learning_rate={orig_lr}",
+        f"--model.dropout={orig_dropout}",
+        "--trainer.devices=1",
+        "--trainer.max_steps=1",
+        "--trainer.limit_train_batches=1",
+        "--trainer.limit_val_batches=1",
+        "--trainer.default_root_dir=./",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + save_args):
+        cli = LightningCLI(DemoModel)
+
+    cfg_lr = 2e-5
+    config = f"""
+    model:
+      num_classes: 1000
+      learning_rate: {cfg_lr}
+    """
+
+    config_path = Path("config.yaml")
+    config_path.write_text(config)
+
+    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))
+
+    cli_classes = 1024
+    cli_args = [
+        "predict",
+        f"--config={config_path}",
+        f"--model.num_classes={cli_classes}",
+        "--trainer.devices=1",
+        f"--ckpt_path={checkpoint_path}",
+    ]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        new_cli = LightningCLI(DemoModel)
+
+    assert new_cli.model.num_classes == cli_classes, (
+        f"CLI priority failed! Expected num_classes {cli_classes}, got {new_cli.model.num_classes}"
+    )
+    assert new_cli.model.learning_rate == cfg_lr, (
+        f"Config override failed! Expected LR {cfg_lr}, got {new_cli.model.learning_rate}"
+    )
+    assert new_cli.model.dropout == orig_dropout, (
+        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
+    )
+    assert new_cli.model.backbone_hidden_dim == 128, (
+        f"Checkpoint restoration failed! Expected dim {128}, got {new_cli.model.backbone_hidden_dim}"
+    )
+
+
 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
 def test_lightning_cli_torch_modules(cleandir):
     class TestModule(BoringModel):
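
The change above fixes the resolution order at `__init__` defaults < ckpt hparams < cfg file < CLI args. Below is a minimal, illustrative sketch (not part of the diff) of how that order plays out from a user script. The checkpoint and config paths are placeholders, and `DemoModel` here only mirrors the model added in the new tests; it is not shipped by Lightning.

```python
from unittest import mock

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel


class DemoModel(BoringModel):
    # Mirrors the DemoModel defined in the new tests above (illustrative only).
    def __init__(self, learning_rate: float = 0.01, dropout: float = 0.1, backbone_hidden_dim: int = 128):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dropout = dropout
        self.backbone_hidden_dim = backbone_hidden_dim


# Assumed setup: a checkpoint previously saved with dropout=0.5 and backbone_hidden_dim=256,
# and a config.yaml that sets model.learning_rate. Both paths below are placeholders.
args = [
    "predict",
    "--ckpt_path=some/checkpoint.ckpt",  # its hparams become the parser defaults
    "--config=config.yaml",              # config-file values override the ckpt hparams
    "--model.backbone_hidden_dim=512",   # explicit CLI args override both
    "--trainer.devices=1",
]

with mock.patch("sys.argv", ["cli.py"] + args):
    cli = LightningCLI(DemoModel)

# Expected resolution after parsing:
#   dropout             -> 0.5 (restored from the checkpoint hparams)
#   learning_rate       -> value from config.yaml
#   backbone_hidden_dim -> 512 (CLI override)
```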