diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 80d68b83534b4..81c87fb295bc5 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -1176,6 +1176,29 @@ By setting to False, you have to add your own distributed sampler:
         dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
         return dataloader
 
+Custom samplers and automatic shuffling
+---------------------------------------
+
+When you use a custom sampler, the Lightning Trainer still applies automatic shuffling during training.
+If your sampler fully controls the iteration order (for example, to enforce a specific
+or deterministic ordering), you can opt out of this behavior by setting
+``disable_auto_shuffle=True`` on the sampler.
+
+This is particularly important when ``use_distributed_sampler=True`` (the default), because Lightning wraps custom samplers in a ``DistributedSamplerWrapper`` and forwards the ``shuffle`` argument to it.
+
+.. code-block:: python
+
+    class InOrderSampler(torch.utils.data.Sampler):
+        def __init__(self, dataset):
+            self.dataset = dataset
+            self.disable_auto_shuffle = True  # <-------- opt out of auto shuffle
+
+        def __iter__(self):
+            yield from range(len(self.dataset))
+
+        def __len__(self):
+            return len(self.dataset)
+
 val_check_interval
 ^^^^^^^^^^^^^^^^^^
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 343a79c76b17f..4ca0cc7ab9189 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -6,22 +6,28 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
-### Fixed
+## [Unreleased] - YYYY-MM-DD
+
+### Added
+
+- Added support for custom samplers to opt out of automatic shuffling during training by setting `disable_auto_shuffle = True` on the sampler. ([#21449](https://github.com/Lightning-AI/pytorch-lightning/pull/21449))
 
-- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
--
 
 ### Deprecated
 
 - Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))
 
+
 ### Removed
 
----
 - Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))
 
+
 ### Fixed
 
+- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+
 - Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))
 
 
@@ -31,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fix `_generate_seed_sequence_sampling` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))
 
+---
+
+
 ## [2.6.0] - 2025-11-28
 
 ### Added
 
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 240dae6296c1f..4798d3f277826 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -488,7 +488,11 @@ def _process_dataloader(
                 category=PossibleUserWarning,
             )
         else:
-            is_shuffled = True
+            # during training, Lightning assumes data should be shuffled by default.
+            # custom samplers can opt out by setting `disable_auto_shuffle = True`
+            sampler = getattr(dataloader, "sampler", None)
+            disable_auto_shuffle = getattr(sampler, "disable_auto_shuffle", False)
+            is_shuffled = not disable_auto_shuffle
 
     # automatically add samplers
     dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 6373985687ad3..d083b936c9a87 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -17,9 +17,11 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.multiprocessing import ProcessRaisedException
 from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.utils.data import DataLoader, Dataset, Sampler
 
 import lightning.pytorch as pl
 import tests_pytorch.helpers.pipelines as tpipes
@@ -495,3 +497,71 @@ def on_train_batch_end(self, *args, **kwargs):
     gmin = trainer.callback_metrics["grad_sum_min"]
     gmax = trainer.callback_metrics["grad_sum_max"]
     assert torch.allclose(gmin, gmax)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize("disabled_auto_shuffle", [None, False, True])
+def test_custom_sampler_disable_auto_shuffle(tmp_path, disabled_auto_shuffle):
+    """Test that a custom sampler can opt out of Lightning's automatic shuffling in DDP."""
+    world_size = 2
+
+    class IntegerDataset(Dataset):
+        def __len__(self):
+            return 16
+
+        def __getitem__(self, idx):
+            return idx
+
+    class CustomInOrderSampler(Sampler):
+        def __init__(self, dataset):
+            self.dataset = dataset
+            if disabled_auto_shuffle is not None:
+                self.disable_auto_shuffle = disabled_auto_shuffle
+
+        def __iter__(self):
+            return iter(range(len(self.dataset)))
+
+        def __len__(self):
+            return len(self.dataset)
+
+    class RecordingModule(pl.LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(1, 1)
+            self.seen_indices = []
+
+        def training_step(self, batch, batch_idx):
+            # batch is a tensor of indices
+            self.seen_indices.extend(batch.tolist())
+            return torch.tensor(0.0, requires_grad=True)
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+
+    dataset = IntegerDataset()
+    sampler = CustomInOrderSampler(dataset)
+    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
+
+    model = RecordingModule()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=world_size,
+        strategy="ddp",
+        max_steps=2,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model, train_dataloaders=dataloader)
+
+    seen = model.seen_indices
+
+    if disabled_auto_shuffle is True:
+        # In-order distributed sampling: indices differ by world size
+        diffs = [j - i for i, j in zip(seen[:-1], seen[1:])]
+        assert all(d == world_size for d in diffs)
+    else:
+        # Order is no longer guaranteed
+        diffs = [j - i for i, j in zip(seen[:-1], seen[1:])]
+        assert any(d != world_size for d in diffs)
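
For reference, a minimal end-to-end usage sketch of the behavior documented in the trainer.rst hunk above. It is illustrative only: it assumes this patch is applied, and it reuses Lightning's BoringModel and RandomDataset demo helpers so the script stays self-contained.

    import torch
    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


    class InOrderSampler(torch.utils.data.Sampler):
        """Yields indices 0..N-1 in order and opts out of Lightning's automatic shuffling."""

        def __init__(self, dataset):
            self.dataset = dataset
            self.disable_auto_shuffle = True  # honored by the data_connector change in this patch

        def __iter__(self):
            yield from range(len(self.dataset))

        def __len__(self):
            return len(self.dataset)


    if __name__ == "__main__":
        dataset = RandomDataset(32, 64)  # 64 samples with feature size 32, matching BoringModel
        dataloader = DataLoader(dataset, sampler=InOrderSampler(dataset), batch_size=8)

        # With the default use_distributed_sampler=True, a distributed strategy wraps the custom
        # sampler in DistributedSamplerWrapper; because disable_auto_shuffle is set, shuffle=False
        # is forwarded and each rank iterates its shard of indices in order.
        trainer = Trainer(max_epochs=1, accelerator="cpu", devices=2, strategy="ddp_spawn", logger=False)
        trainer.fit(BoringModel(), train_dataloaders=dataloader)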
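
A small standalone sketch of the opt-out resolution added in data_connector.py: the two getattr calls fall back to the old behavior (shuffle stays on) for any dataloader or sampler that does not define the attribute, so only samplers that explicitly set disable_auto_shuffle = True change anything. The helper name resolve_training_shuffle is hypothetical; it just mirrors the added lines outside the Trainer.

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


    def resolve_training_shuffle(dataloader) -> bool:
        # hypothetical helper mirroring the logic added to _process_dataloader for the training stage
        sampler = getattr(dataloader, "sampler", None)
        disable_auto_shuffle = getattr(sampler, "disable_auto_shuffle", False)
        return not disable_auto_shuffle


    dataset = TensorDataset(torch.arange(8))

    # a plain sampler without the attribute keeps the existing behavior: automatic shuffling stays on
    plain = DataLoader(dataset, sampler=SequentialSampler(dataset))
    assert resolve_training_shuffle(plain) is True

    # a sampler that opts out turns automatic shuffling off
    opted_out = SequentialSampler(dataset)
    opted_out.disable_auto_shuffle = True
    assert resolve_training_shuffle(DataLoader(dataset, sampler=opted_out)) is False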