docs/source-pytorch/common/trainer.rst (21 additions, 0 deletions)

@@ -1176,6 +1176,27 @@ By setting to False, you have to add your own distributed sampler:
        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
        return dataloader

Custom samplers and automatic shuffling
---------------------------------------

When using a custom sampler, the Lightning Trainer still applies automatic shuffling
during training. If your sampler fully controls the iteration order (for example, to
enforce a specific or deterministic ordering), you can opt out of this behavior by
setting ``disable_auto_shuffle = True`` on the sampler:

.. code-block:: python

    class InOrderSampler(torch.utils.data.Sampler):
        def __init__(self, dataset):
            self.dataset = dataset
            self.disable_auto_shuffle = True

        def __iter__(self):
            yield from range(len(self.dataset))

        def __len__(self):
            return len(self.dataset)
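
A sampler flagged this way keeps its exact iteration order when returned from
``train_dataloader``. A minimal usage sketch (``MyDataset`` stands in for your own
dataset):

.. code-block:: python

    def train_dataloader(self):
        dataset = MyDataset()
        # Lightning will not inject shuffling around this sampler
        return DataLoader(dataset, batch_size=32, sampler=InOrderSampler(dataset))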


val_check_interval
^^^^^^^^^^^^^^^^^^
src/lightning/pytorch/CHANGELOG.md (5 additions, 0 deletions)

@@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

---

### Added

- Added support for custom samplers to opt out of automatic shuffling during training by setting `disable_auto_shuffle = True` on the sampler. ([#21449](https://github.com/Lightning-AI/pytorch-lightning/pull/21449))


### Fixed

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
@@ -488,7 +488,10 @@ def _process_dataloader(
             category=PossibleUserWarning,
         )
     else:
-        is_shuffled = True
+        # custom samplers may explicitly disable Lightning's automatic shuffling
+        # to preserve their intended iteration order
+        disable_auto_shuffle = getattr(dataloader.sampler, "disable_auto_shuffle", False)
+        is_shuffled = not disable_auto_shuffle
 
     # automatically add samplers
     dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
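
The `getattr` fallback keeps the previous behavior for samplers that do not define the
attribute. A minimal sketch of the check in isolation (`PlainSampler` is hypothetical,
not part of this change):

```python
import torch


class PlainSampler(torch.utils.data.Sampler):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        yield from range(len(self.dataset))

    def __len__(self):
        return len(self.dataset)


sampler = PlainSampler(range(10))
# attribute absent: getattr returns the False default, so is_shuffled stays True
assert getattr(sampler, "disable_auto_shuffle", False) is False

sampler.disable_auto_shuffle = True
# attribute present: is_shuffled becomes False and the sampler's order is preserved
assert getattr(sampler, "disable_auto_shuffle", False) is True
```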