Fix: Override hparams via CLI #21455
base: master
Changes from all commits
26ce013
ead2d79
d658b76
89a59be
81bc423
7f3e84e
fb86acf
0100d65
```diff
@@ -405,7 +405,7 @@ def __init__(
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
         self.setup_parser(run, main_kwargs, subparser_kwargs)
         self.parse_arguments(self.parser, args)
-        self._parse_ckpt_path()
+        self._parse_ckpt_path(self.parser, args)
 
         self.subcommand = self.config["subcommand"] if run else None
 
```
```diff
@@ -560,8 +560,18 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)
 
-    def _parse_ckpt_path(self) -> None:
-        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
+    def _parse_ckpt_path(self, parser: LightningArgumentParser, args: ArgsType) -> None:
+        """Parses the checkpoint path, loads hyperparameters, and injects them as new defaults.
+
+        If `ckpt_path` is provided, this method:
+        1. Loads hyperparameters from the checkpoint file.
+        2. Sets them as new default values for the specific subcommand parser.
+        3. Re-runs argument parsing.
+
+        This ensures the correct priority order:
+        __init__ defaults < ckpt hparams < cfg file < CLI args
+
+        """
         if not self.config.get("subcommand"):
             return
         ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
@@ -576,12 +586,16 @@ def _parse_ckpt_path(self) -> None:
                 "class_path": hparams.pop("_class_path"),
                 "dict_kwargs": hparams,
             }
-        hparams = {self.config.subcommand: {"model": hparams}}
+        hparams = {"model": hparams}
         try:
-            self.config = self.parser.parse_object(hparams, self.config)
-        except SystemExit:
-            sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
-            raise
+            if parser._subcommands_action is None:
+                return
+            subparser = parser._subcommands_action._name_parser_map[self.config.subcommand]
+            subparser.set_defaults(hparams)
+        except KeyError as ex:
+            parser.error(str(ex), ex)
+
+        self.parse_arguments(parser, args)
```
Contributor
Like this, the entire parsing is done twice. The speed is not a problem, since training/predicting is normally significantly slower than the parsing. However, there might be some edge case where the first parsing fails even though this second one would succeed. But I am fine with leaving it like this. If there are problematic edge cases, then we can fix them when they happen.
```diff
 
     def _dump_config(self) -> None:
         if hasattr(self, "config_dump"):
```
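To make the new priority chain concrete, here is a minimal, self-contained sketch using plain jsonargparse (which `LightningArgumentParser` builds on). It is not the Lightning implementation above — the parser, argument names, and checkpoint dictionary are invented for illustration — but it shows why injecting checkpoint hparams via `set_defaults` and re-parsing keeps `__init__` defaults below the ckpt values while still letting explicit arguments win:

```python
from jsonargparse import ArgumentParser

# Hypothetical parser standing in for the subcommand parser used in _parse_ckpt_path.
parser = ArgumentParser()
parser.add_argument("--model.hidden_dim", type=int, default=64)  # __init__ default
parser.add_argument("--model.dropout", type=float, default=0.1)  # __init__ default

# Pretend these values were read from the checkpoint's saved hyperparameters.
ckpt_hparams = {"model.hidden_dim": 256, "model.dropout": 0.5}
parser.set_defaults(ckpt_hparams)  # ckpt hparams replace the __init__ defaults

# Re-parsing afterwards lets explicit arguments override the injected defaults,
# while anything not mentioned falls back to the checkpoint values.
cfg = parser.parse_args(["--model.hidden_dim=512"])
assert cfg.model.hidden_dim == 512  # explicit argument wins
assert cfg.model.dropout == 0.5     # restored from the checkpoint
```

Config files provided via `--config` sit between the two in the PR's ordering: they override the injected defaults but can still be overridden by later CLI flags, which is exactly the `__init__ defaults < ckpt hparams < cfg file < CLI args` chain from the docstring.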
```diff
@@ -556,9 +556,8 @@ def add_arguments_to_parser(self, parser):
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)
 
-    assert isinstance(cli.model, BoringCkptPathSubclass)
+    assert not isinstance(cli.model, BoringCkptPathSubclass)
     assert cli.model.hidden_dim == 8
     assert cli.model.extra is True
```
Comment on lines -559 to -561

Contributor
Why did this change? Old behavior should not change unless there was a bug. Was there a bug? If not, revert back.

Author
Yes, the old behavior was effectively a bug: it violated the standard priority […] Previously, explicit flags like […] This also complements PR #21408: while that PR handles parameter adaptation, this one ensures the class itself can be overridden. Together, they enable a clean training -> inference workflow with different classes.

Contributor
One drawback of setting hparams as defaults is that the typical left-to-right parsing priority of command line arguments is not respected. For example, a command like […] Ignoring the order of command line arguments, the reason for including the […]
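The drawback the reviewer describes can be sketched in isolation: because the checkpoint values become parser defaults rather than parsed arguments, an explicit flag beats them no matter where it appears on the command line. A toy example with invented argument names (plain jsonargparse, not the Lightning CLI itself):

```python
from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model.dropout", type=float, default=0.1)

# Stand-in for hyperparameters loaded from --ckpt_path and injected as defaults.
parser.set_defaults({"model.dropout": 0.5})

# Whether the user writes the flag before or after --ckpt_path, the outcome is the
# same: the explicit value wins and the default (the ckpt value) only fills gaps,
# so the usual left-to-right precedence of command line arguments does not apply.
assert parser.parse_args(["--model.dropout=0.9"]).model.dropout == 0.9
assert parser.parse_args([]).model.dropout == 0.5
```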
```diff
     assert cli.model.layer.out_features == 4
 
```
```diff
@@ -588,6 +587,151 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
     assert isinstance(cli.model.submodule2, BoringModel)
 
 
+class DemoModel(BoringModel):
+    def __init__(
+        self,
+        num_classes: int = 10,
+        learning_rate: float = 0.01,
+        dropout: float = 0.1,
+        backbone_hidden_dim: int = 128,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.num_classes = num_classes
+        self.learning_rate = learning_rate
+        self.dropout = dropout
+        self.backbone_hidden_dim = backbone_hidden_dim
+
+
+def test_lightning_cli_args_override_checkpoint_hparams(cleandir):
+    """
+    Check priority: ckpt hparams < CLI args
+
+    Scenario:
+    1. Save a checkpoint with specific `dropout` and `backbone_hidden_dim`.
+    2. Load the checkpoint, but explicitly override `learning_rate` and `backbone_hidden_dim`.
+    3. Verify that `num_classes` and `dropout` are restored from the ckpt,
+       but `learning_rate` and `backbone_hidden_dim` are updated from the CLI args.
+    """
+
+    # --- Phase 1: Create a base checkpoint ---
+    orig_hidden_dim = 256
+    orig_dropout = 0.5
+
+    save_args = [
+        "fit",
+        f"--model.dropout={orig_dropout}",
+        f"--model.backbone_hidden_dim={orig_hidden_dim}",
+        "--trainer.devices=1",
+        "--trainer.max_steps=1",
+        "--trainer.limit_train_batches=1",
+        "--trainer.limit_val_batches=1",
+        "--trainer.default_root_dir=./",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + save_args):
+        cli = LightningCLI(DemoModel)
+
+    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))
+
+    # --- Phase 2: Predict with CLI overrides ---
+    new_lr = 0.123
+    new_hidden_dim = 512
+    override_args = [
+        "predict",
+        "--trainer.devices=1",
+        f"--model.learning_rate={new_lr}",
+        f"--model.backbone_hidden_dim={new_hidden_dim}",
+        f"--ckpt_path={checkpoint_path}",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + override_args):
+        new_cli = LightningCLI(DemoModel)
+
+    # --- Phase 3: Assertions ---
+    assert new_cli.model.learning_rate == new_lr, (
+        f"CLI override failed! Expected LR {new_lr}, got {new_cli.model.learning_rate}"
+    )
+    assert new_cli.model.dropout == orig_dropout, (
+        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
+    )
+    assert new_cli.model.backbone_hidden_dim == new_hidden_dim, (
+        f"CLI override failed! Expected dim {new_hidden_dim}, got {new_cli.model.backbone_hidden_dim}"
+    )
+
+
+def test_lightning_cli_config_priority_over_checkpoint_hparams(cleandir):
+    """
+    Test the full priority hierarchy:
+    ckpt hparams < config < CLI args
+
+    Scenario:
+    1. Save a checkpoint with specific `num_classes`, `learning_rate`, and `dropout`.
+    2. Load the checkpoint, but explicitly override:
+       - num_classes: via config and CLI
+       - learning_rate: via config
+    3. Verify that:
+       - num_classes comes from the CLI args
+       - learning_rate comes from the config
+       - dropout is restored from the ckpt
+    """
+    orig_classes = 60_000
+    orig_lr = 1e-4
+    orig_dropout = 0.01
+    save_args = [
+        "fit",
+        f"--model.num_classes={orig_classes}",
+        f"--model.learning_rate={orig_lr}",
+        f"--model.dropout={orig_dropout}",
+        "--trainer.devices=1",
+        "--trainer.max_steps=1",
+        "--trainer.limit_train_batches=1",
+        "--trainer.limit_val_batches=1",
+        "--trainer.default_root_dir=./",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + save_args):
+        cli = LightningCLI(DemoModel)
+
+    cfg_lr = 2e-5
+    config = f"""
+    model:
+      num_classes: 1000
+      learning_rate: {cfg_lr}
+    """
+
+    config_path = Path("config.yaml")
+    config_path.write_text(config)
+
+    checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))
+
+    cli_classes = 1024
+    cli_args = [
+        "predict",
+        f"--config={config_path}",
+        f"--model.num_classes={cli_classes}",
+        "--trainer.devices=1",
+        f"--ckpt_path={checkpoint_path}",
+    ]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        new_cli = LightningCLI(DemoModel)
+
+    assert new_cli.model.num_classes == cli_classes, (
+        f"CLI priority failed! Expected num_classes {cli_classes}, got {new_cli.model.num_classes}"
+    )
+    assert new_cli.model.learning_rate == cfg_lr, (
+        f"Config override failed! Expected LR {cfg_lr}, got {new_cli.model.learning_rate}"
+    )
+    assert new_cli.model.dropout == orig_dropout, (
+        f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
+    )
+    assert new_cli.model.backbone_hidden_dim == 128, (
+        f"Checkpoint restoration failed! Expected dim {128}, got {new_cli.model.backbone_hidden_dim}"
+    )
+
+
 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
 def test_lightning_cli_torch_modules(cleandir):
     class TestModule(BoringModel):
```
How come you added this? The `if` condition in line 575 is supposed to already return when no subcommands are defined. Maybe the condition is not robust enough, since it would not work when a model has a `subcommand` parameter. Still, if there are no subcommands, why execute the lines from 577 up to here? Maybe better instead to change line 575 to `if parser._subcommands_action is None or not self.config.get("subcommand"):`.
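Rendered as code, the reviewer's suggestion would fold the subcommands check into the existing early-return guard near the top of `_parse_ckpt_path` (a sketch of the proposed edit, not a committed change):

```python
# Proposed combined guard for _parse_ckpt_path (sketch): bail out early both when
# the CLI was built without subcommands and when no subcommand was parsed, instead
# of checking parser._subcommands_action only later, right before set_defaults.
if parser._subcommands_action is None or not self.config.get("subcommand"):
    return
```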