23 changes: 23 additions & 0 deletions docs/source-pytorch/common/trainer.rst
@@ -1176,6 +1176,29 @@ By setting to False, you have to add your own distributed sampler:
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
return dataloader

Custom samplers and automatic shuffling
---------------------------------------

When using a custom sampler, the Lightning Trainer still applies automatic shuffling during training.
If your sampler fully controls the iteration order (for example, to enforce a specific
or deterministic ordering), you can opt out of this behavior by setting
``disable_auto_shuffle=True`` on the sampler.

This is particularly important when ``use_distributed_sampler=True`` (the default), because Lightning wraps
custom samplers in a ``DistributedSamplerWrapper`` and passes the computed ``shuffle`` flag to that wrapper.

.. code-block:: python

    class InOrderSampler(torch.utils.data.Sampler):
        def __init__(self, dataset):
            self.dataset = dataset
            self.disable_auto_shuffle = True  # <-------- opt out of auto shuffle

        def __iter__(self):
            yield from range(len(self.dataset))

        def __len__(self):
            return len(self.dataset)
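
Such a sampler is then passed to the ``DataLoader`` as usual. A minimal sketch (assuming ``dataset`` and ``model``
are defined as in the surrounding examples) with the default ``use_distributed_sampler=True``:

.. code-block:: python

    sampler = InOrderSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    trainer = Trainer(use_distributed_sampler=True)  # the default
    trainer.fit(model, train_dataloaders=dataloader)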


val_check_interval
^^^^^^^^^^^^^^^^^^
17 changes: 13 additions & 4 deletions src/lightning/pytorch/CHANGELOG.md
@@ -6,22 +6,28 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

---

### Fixed
## [Unreleased] - YYYY-MM-DD

### Added

- Added support for custom samplers to opt out of automatic shuffling during training by setting `disable_auto_shuffle = True` on the sampler. ([#21449](https://github.com/Lightning-AI/pytorch-lightning/pull/21449))

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
-

### Deprecated

- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))


### Removed

---
- Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))


### Fixed

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))


- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))


@@ -31,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fix `_generate_seed_sequence_sampling` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))


---


## [2.6.0] - 2025-11-28

### Added
@@ -488,7 +488,11 @@ def _process_dataloader(
                category=PossibleUserWarning,
            )
    else:
        is_shuffled = True
        # during training, Lightning assumes data should be shuffled by default.
        # custom samplers can opt out by setting `disable_auto_shuffle = True`
        sampler = getattr(dataloader, "sampler", None)
        disable_auto_shuffle = getattr(sampler, "disable_auto_shuffle", False)
        is_shuffled = not disable_auto_shuffle

    # automatically add samplers
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
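
For reference, the opt-out detection above is a plain attribute lookup on the dataloader's sampler. A minimal
standalone sketch of the same check (illustrative names, not Lightning's internal API):

    from torch.utils.data import DataLoader, Sampler


    class InOrderSampler(Sampler):
        """Illustrative sampler that opts out of automatic shuffling."""

        def __init__(self, length):
            self.length = length
            self.disable_auto_shuffle = True

        def __iter__(self):
            return iter(range(self.length))

        def __len__(self):
            return self.length


    loader = DataLoader(list(range(8)), sampler=InOrderSampler(8), batch_size=2)
    sampler = getattr(loader, "sampler", None)
    is_shuffled = not getattr(sampler, "disable_auto_shuffle", False)
    print(is_shuffled)  # False -> Lightning would not request shuffling for this dataloader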
70 changes: 70 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -17,9 +17,11 @@

import pytest
import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.multiprocessing import ProcessRaisedException
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader, Dataset, Sampler

import lightning.pytorch as pl
import tests_pytorch.helpers.pipelines as tpipes
@@ -495,3 +497,71 @@ def on_train_batch_end(self, *args, **kwargs):
    gmin = trainer.callback_metrics["grad_sum_min"]
    gmax = trainer.callback_metrics["grad_sum_max"]
    assert torch.allclose(gmin, gmax)


@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("disabled_auto_shuffle", [None, False, True])
def test_custom_sampler_disable_auto_shuffle(tmp_path, disabled_auto_shuffle):
    """Test that a custom sampler can opt out of Lightning's automatic shuffling in DDP."""
    world_size = 2

    class IntegerDataset(Dataset):
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return idx

    class CustomInOrderSampler(Sampler):
        def __init__(self, dataset):
            self.dataset = dataset
            if disabled_auto_shuffle is not None:
                self.disable_auto_shuffle = disabled_auto_shuffle

        def __iter__(self):
            return iter(range(len(self.dataset)))

        def __len__(self):
            return len(self.dataset)

    class RecordingModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(1, 1)
            self.seen_indices = []

        def training_step(self, batch, batch_idx):
            # batch is a tensor of indices
            self.seen_indices.extend(batch.tolist())
            return torch.tensor(0.0, requires_grad=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    dataset = IntegerDataset()
    sampler = CustomInOrderSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

    model = RecordingModule()
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=world_size,
        strategy="ddp",
        max_steps=2,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    trainer.fit(model, train_dataloaders=dataloader)

    seen = model.seen_indices

    if disabled_auto_shuffle is True:
        # In-order distributed sampling: indices differ by world size
        diffs = [j - i for i, j in zip(seen[:-1], seen[1:])]
        assert all(d == world_size for d in diffs)
    else:
        # Order is no longer guaranteed
        diffs = [j - i for i, j in zip(seen[:-1], seen[1:])]
        assert any(d != world_size for d in diffs)