diff --git a/docs/source-pytorch/data/access.rst b/docs/source-pytorch/data/access.rst index d849ded5bf0bf..93176e7668851 100644 --- a/docs/source-pytorch/data/access.rst +++ b/docs/source-pytorch/data/access.rst @@ -41,3 +41,47 @@ If you are using a :class:`~lightning.pytorch.utilities.CombinedLoader`. A flatt updated.append(new_dl) # it also allows you to easily replace the dataloaders combined_loader.flattened = updated + + +Reloading DataLoaders During Training +------------------------------------- + +Lightning provides two mechanisms for reloading dataloaders during training: + +**Automatic reload with** ``reload_dataloaders_every_n_epochs`` + +Set ``reload_dataloaders_every_n_epochs`` in the Trainer to automatically reload dataloaders at regular intervals: + +.. code-block:: python + + trainer = Trainer(reload_dataloaders_every_n_epochs=5) + +This is useful when your dataset changes periodically, such as in online learning scenarios. + +**Manual reload with** ``trainer.reload_dataloaders()`` + +For dynamic scenarios like curriculum learning or adaptive training strategies, use +:meth:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders` to trigger a reload +based on training metrics or other conditions: + +.. code-block:: python + + class CurriculumCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + if trainer.callback_metrics.get("train_loss", 1.0) < 0.5: + # Update datamodule parameters + trainer.datamodule.difficulty_level += 1 + # Trigger reload for next epoch + trainer.reload_dataloaders(train=True, val=True) + +Or directly from your LightningModule: + +.. code-block:: python + + class MyModel(LightningModule): + def on_train_batch_end(self, outputs, batch, batch_idx): + if self.trainer.callback_metrics.get("train_loss", 1.0) < 0.5: + self.trainer.datamodule.sequence_length += 10 + self.trainer.reload_dataloaders() + +The reload happens at the start of the next epoch, ensuring training state consistency. diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2bcb1d8f4b1fd..9226b1dfab02f 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) + def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]: + """Adapt checkpoint hyperparameters before instantiating the model class. + + This method allows for customization of hyperparameters loaded from a checkpoint when + using a different model class than the one used for training. For example, when loading + a checkpoint from a TrainingModule to use with an InferenceModule that has different + ``__init__`` parameters, you can remove or modify incompatible hyperparameters. + + Args: + subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict'). + This allows you to apply different hyperparameter adaptations depending on the context. + checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint. + + Returns: + Dictionary of adapted hyperparameters to be used for model instantiation. + + Example:: + + class MyCLI(LightningCLI): + def adapt_checkpoint_hparams( + self, subcommand: str, checkpoint_hparams: dict[str, Any] + ) -> dict[str, Any]: + # Only remove training-specific hyperparameters for non-fit subcommands + if subcommand != "fit": + checkpoint_hparams.pop("lr", None) + checkpoint_hparams.pop("weight_decay", None) + return checkpoint_hparams + + Note: + If subclass module mode is enabled and ``_class_path`` is present in the checkpoint + hyperparameters, you may need to modify it as well to point to your new module class. + + """ + return checkpoint_hparams + def _parse_ckpt_path(self) -> None: """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config.""" if not self.config.get("subcommand"): @@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None: hparams.pop("_instantiator", None) if not hparams: return + + # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook + hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams) + if not hparams: + return + if "_class_path" in hparams: hparams = { "class_path": hparams.pop("_class_path"), diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 6f947160ba9cb..b51f99b5f8b2d 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1206,6 +1206,60 @@ def print(self, *args: Any, **kwargs: Any) -> None: if self.local_rank == 0: print(*args, **kwargs) + def reload_dataloaders(self, train: bool = True, val: bool = False) -> None: + """Manually trigger a reload of dataloaders during training. + + This method allows dynamic reconfiguration of DataLoaders without exiting the ``fit()`` loop. + It's useful for curriculum learning, adaptive training strategies, or any scenario where + DataLoader parameters need to change based on training metrics or progress. + + The reload will occur at the start of the next epoch during training. + + Args: + train: If ``True``, reload the train dataloader. Default: ``True``. + val: If ``True``, reload the validation dataloader. Default: ``False``. + + Example:: + + # In a callback + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == 5: + # Update datamodule parameters + trainer.datamodule.sequence_length += 10 + # Trigger reload for next epoch + trainer.reload_dataloaders(train=True, val=True) + + # In a LightningModule + def on_train_batch_end(self, outputs, batch, batch_idx): + if self.trainer.callback_metrics.get('train_loss', 1.0) < 0.5: + self.trainer.datamodule.unroll_steps += 1 + self.trainer.reload_dataloaders() + + Raises: + RuntimeError: If called outside of a ``fit()`` call. + + .. note:: + + The actual reload happens at the beginning of the next training epoch, + not immediately when this method is called. This ensures training state + consistency and proper synchronization in distributed settings. + + """ + if not self.training: + raise RuntimeError( + "`trainer.reload_dataloaders()` can only be called during training (inside `trainer.fit()`)." + ) + + if train: + # Setting to -inf ensures _should_reload_train_dl returns True + self.fit_loop._last_train_dl_reload_epoch = float("-inf") + rank_zero_info("Train dataloader will be reloaded at the start of the next epoch.") + + if val: + # Setting to -inf ensures _should_reload_val_dl returns True + self.fit_loop.epoch_loop.val_loop._last_val_dl_reload_epoch = float("-inf") + rank_zero_info("Validation dataloader will be reloaded at the next validation check.") + """ Accelerator properties """ diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 6ff4bee264a7b..f61fcbe384f6f 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -495,6 +495,21 @@ def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None: self.layer = torch.nn.Linear(32, out_dim) +class AdaptHparamsModel(BoringModel): + """Simple model for testing adapt_checkpoint_hparams hook without dynamic neural network layers. + + This model stores hyperparameters as attributes without creating layers that would cause size mismatches when + hyperparameters are changed between fit and predict phases. + + """ + + def __init__(self, out_dim: int = 8, hidden_dim: int = 16) -> None: + super().__init__() + self.save_hyperparameters() + self.out_dim = out_dim + self.hidden_dim = hidden_dim + + def test_lightning_cli_ckpt_path_argument_hparams(cleandir): class CkptPathCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -562,6 +577,62 @@ def add_arguments_to_parser(self, parser): assert cli.model.layer.out_features == 4 +def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir): + """Test that the adapt_checkpoint_hparams hook is called and modifications are applied.""" + + class AdaptHparamsCLI(LightningCLI): + def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict: + """Remove out_dim and hidden_dim for non-fit subcommands.""" + if subcommand != "fit": + checkpoint_hparams.pop("out_dim", None) + checkpoint_hparams.pop("hidden_dim", None) + return checkpoint_hparams + + # First, create a checkpoint by running fit + cli_args = ["fit", "--model.out_dim=3", "--model.hidden_dim=6", "--trainer.max_epochs=1"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsCLI(AdaptHparamsModel) + + assert cli.config.fit.model.out_dim == 3 + assert cli.config.fit.model.hidden_dim == 6 + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + + # Test that predict uses adapted hparams (without out_dim and hidden_dim) + cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsCLI(AdaptHparamsModel) + + # Since we removed out_dim and hidden_dim for predict, the CLI values should be used + assert cli.config.predict.model.out_dim == 5 + assert cli.config.predict.model.hidden_dim == 10 + + +def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir): + """Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading.""" + + class AdaptHparamsEmptyCLI(LightningCLI): + def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict: + """Disable checkpoint hyperparameter loading.""" + return {} + + # First, create a checkpoint + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsEmptyCLI(AdaptHparamsModel) + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + + # Test that predict uses default values when hook returns empty dict + cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsEmptyCLI(AdaptHparamsModel) + + # Model should use default values (out_dim=8, hidden_dim=16) + assert cli.config_init.predict.model.out_dim == 8 + assert cli.config_init.predict.model.hidden_dim == 16 + + def test_lightning_cli_submodules(cleandir): class MainModule(BoringModel): def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1): diff --git a/tests/tests_pytorch/trainer/test_reload_dataloaders.py b/tests/tests_pytorch/trainer/test_reload_dataloaders.py new file mode 100644 index 0000000000000..770025ad7ffe9 --- /dev/null +++ b/tests/tests_pytorch/trainer/test_reload_dataloaders.py @@ -0,0 +1,342 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for manual dataloader reloading feature (issue #21448).""" + +import pytest +import torch +from torch.utils.data import DataLoader + +from lightning.pytorch import Callback, LightningDataModule, Trainer +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset + + +class ReloadTrackingDataModule(LightningDataModule): + """DataModule that tracks when train_dataloader is called.""" + + def __init__(self, sequence_length: int = 32): + super().__init__() + self.sequence_length = sequence_length + self.train_dataloader_call_count = 0 + self.val_dataloader_call_count = 0 + self._train_epochs_called_for = [] + self._val_epochs_called_for = [] + + def train_dataloader(self): + self.train_dataloader_call_count += 1 + if self.trainer is not None: + self._train_epochs_called_for.append(self.trainer.current_epoch) + return DataLoader(RandomDataset(self.sequence_length, 64), batch_size=8) + + def val_dataloader(self): + self.val_dataloader_call_count += 1 + if self.trainer is not None: + self._val_epochs_called_for.append(self.trainer.current_epoch) + return DataLoader(RandomDataset(self.sequence_length, 64), batch_size=8) + + +class ManualReloadCallback(Callback): + """Callback that triggers manual dataloader reload at specific epochs.""" + + def __init__(self, reload_at_epoch: int, reload_train: bool = True, reload_val: bool = False): + super().__init__() + self.reload_at_epoch = reload_at_epoch + self.reload_train = reload_train + self.reload_val = reload_val + + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == self.reload_at_epoch: + trainer.reload_dataloaders(train=self.reload_train, val=self.reload_val) + + +class MetricBasedReloadCallback(Callback): + """Callback that triggers reload based on training metrics (curriculum learning example).""" + + def __init__(self, loss_threshold: float = 0.5): + super().__init__() + self.loss_threshold = loss_threshold + self.reload_triggered = False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.reload_triggered: + loss = trainer.callback_metrics.get("train_loss", 1.0) + if isinstance(loss, torch.Tensor): + loss = loss.item() + if loss < self.loss_threshold: + # Update datamodule parameters before reload + if trainer.datamodule is not None: + trainer.datamodule.sequence_length += 10 + trainer.reload_dataloaders() + self.reload_triggered = True + + +def test_reload_dataloaders_outside_training_raises_error(): + """Test that calling reload_dataloaders outside of fit() raises RuntimeError.""" + trainer = Trainer(max_epochs=1) + + with pytest.raises(RuntimeError, match="can only be called during training"): + trainer.reload_dataloaders() + + +def test_manual_reload_train_dataloader(tmp_path): + """Test that manually triggering train dataloader reload works.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + model = TrackingModel() + dm = ReloadTrackingDataModule() + callback = ManualReloadCallback(reload_at_epoch=1, reload_train=True, reload_val=False) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Train dataloader should be called at epochs 0, 2 (after manual reload at epoch 1) + # Without the manual reload, it would only be called at epoch 0 + assert dm.train_dataloader_call_count >= 2, ( + f"Expected at least 2 train_dataloader calls, got {dm.train_dataloader_call_count}" + ) + + +def test_manual_reload_val_dataloader(tmp_path): + """Test that manually triggering validation dataloader reload works.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + model = TrackingModel() + dm = ReloadTrackingDataModule() + callback = ManualReloadCallback(reload_at_epoch=1, reload_train=False, reload_val=True) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Validation dataloader should be called multiple times due to manual reload + assert dm.val_dataloader_call_count >= 2, ( + f"Expected at least 2 val_dataloader calls, got {dm.val_dataloader_call_count}" + ) + + +def test_manual_reload_both_dataloaders(tmp_path): + """Test that manually triggering both train and val dataloader reload works.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + model = TrackingModel() + dm = ReloadTrackingDataModule() + callback = ManualReloadCallback(reload_at_epoch=1, reload_train=True, reload_val=True) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Both dataloaders should be called multiple times + assert dm.train_dataloader_call_count >= 2 + assert dm.val_dataloader_call_count >= 2 + + +def test_manual_reload_updates_datamodule_params(tmp_path): + """Test that datamodule parameters can be updated before manual reload.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + class ParamUpdateCallback(Callback): + def __init__(self): + super().__init__() + self.sequence_lengths_seen = [] + + def on_train_epoch_start(self, trainer, pl_module): + if trainer.datamodule is not None: + self.sequence_lengths_seen.append(trainer.datamodule.sequence_length) + + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == 1: + # Update datamodule parameters + trainer.datamodule.sequence_length = 64 + # Trigger reload + trainer.reload_dataloaders(train=True) + + model = TrackingModel() + dm = ReloadTrackingDataModule(sequence_length=32) + callback = ParamUpdateCallback() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # After epoch 1, sequence_length should have been updated to 64 + assert dm.sequence_length == 64 + # And it should have been seen in subsequent epochs + assert 64 in callback.sequence_lengths_seen or dm.train_dataloader_call_count >= 2 + + +def test_reload_dataloaders_from_lightning_module(tmp_path): + """Test that reload_dataloaders can be called from within the LightningModule.""" + + class ReloadingModel(BoringModel): + def __init__(self): + super().__init__() + self.reload_triggered = False + + def on_train_epoch_end(self): + if self.current_epoch == 1 and not self.reload_triggered: + self.trainer.reload_dataloaders(train=True) + self.reload_triggered = True + + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + model = ReloadingModel() + dm = ReloadTrackingDataModule() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Should have triggered reload + assert dm.train_dataloader_call_count >= 2 + + +def test_reload_dataloaders_multiple_times(tmp_path): + """Test that reload_dataloaders can be called multiple times.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + class MultiReloadCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + # Reload at every epoch + trainer.reload_dataloaders(train=True) + + model = TrackingModel() + dm = ReloadTrackingDataModule() + callback = MultiReloadCallback() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Train dataloader should be called at every epoch + # Initial load + 3 reloads (after epochs 0, 1, 2) + # Note: reload at epoch 3 end won't take effect since training ends + assert dm.train_dataloader_call_count >= 4 + + +def test_reload_dataloaders_with_reload_every_n_epochs(tmp_path): + """Test that manual reload works alongside reload_dataloaders_every_n_epochs.""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + model = TrackingModel() + dm = ReloadTrackingDataModule() + # Manual reload at epoch 0 (will reload at epoch 1 start) + callback = ManualReloadCallback(reload_at_epoch=0, reload_train=True) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + reload_dataloaders_every_n_epochs=3, # Would reload at epoch 3 + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Should have initial load + manual reload + possibly automatic reload + assert dm.train_dataloader_call_count >= 2 + + +def test_reload_dataloaders_default_args(tmp_path): + """Test reload_dataloaders with default arguments (train=True, val=False).""" + + class TrackingModel(BoringModel): + def validation_step(self, batch, batch_idx): + return super().validation_step(batch, batch_idx) + + class DefaultArgsCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == 1: + # Call with default args + trainer.reload_dataloaders() + + model = TrackingModel() + dm = ReloadTrackingDataModule() + callback = DefaultArgsCallback() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=4, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[callback], + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, datamodule=dm) + + # Train dataloader should be reloaded + assert dm.train_dataloader_call_count >= 2