
Commit f3f6605

Fix StochasticWeightAveraging with infinite epochs (#21396)
* Implement special case for max_epochs == -1
* Add testing
* Update changelog

Co-authored-by: Bhimraj Yadav <[email protected]>
1 parent 3876cc5 commit f3f6605

3 files changed (+47 additions, -7 deletions)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fix `StochasticWeightAveraging` with infinite epochs ([#21396](https://github.com/Lightning-AI/pytorch-lightning/pull/21396))
+
+
 - Fix `_generate_seed_sequence_sampling` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))
 
 
src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 14 additions & 7 deletions
@@ -139,6 +139,8 @@ def swa_start(self) -> int:
 
     @property
     def swa_end(self) -> int:
+        if self._max_epochs == -1:
+            return float("inf")  # type: ignore[return-value]
         return self._max_epochs - 1  # 0-based
 
     @staticmethod
@@ -163,12 +165,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
+            if trainer.max_epochs == -1:
+                raise MisconfigurationException(
+                    "SWA with `swa_epoch_start` as a float is not supported when `max_epochs=-1`. "
+                    "Please provide `swa_epoch_start` as an integer."
+                )
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
 
         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
+        if self._model_contains_batch_norm and trainer.max_epochs != -1:
             # virtually increase max_epochs to perform batch norm update on latest epoch.
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs += 1
@@ -243,7 +250,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
+        if self._max_epochs != -1 and trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
             assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
@@ -267,17 +274,17 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
     @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # the trainer increases the current epoch before this hook is called
-        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
+        if self._model_contains_batch_norm and self._max_epochs != -1 and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch - 1 == self.swa_end:
-            # Last SWA epoch. Transfer weights from average model to pl_module
-            assert self._average_model is not None
-            self.transfer_weights(self._average_model, pl_module)
+        elif trainer.current_epoch - 1 == self.swa_end or self._max_epochs == -1:
+            # Last SWA epoch or infinite training. Transfer weights from average model to pl_module
+            if self._average_model is not None:
+                self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
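
Taken together, these changes let the callback run under open-ended training: `swa_end` becomes unbounded, the BatchNorm extra-epoch trick is skipped, and the averaged weights are transferred in `on_train_end`. A minimal usage sketch follows (illustrative only, not part of this commit; the LightningModule is a placeholder):

import lightning.pytorch as pl
from lightning.pytorch.callbacks import StochasticWeightAveraging

# SWA with infinite training: swa_epoch_start must be an integer, and some other
# stopping condition (here max_steps) has to end the run.
swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=2)
trainer = pl.Trainer(max_epochs=-1, max_steps=1000, callbacks=[swa])
# trainer.fit(MyLightningModule())  # placeholder module; averaged weights are copied back at train end

# With max_epochs=-1, passing swa_epoch_start as a float (e.g. 0.8) now raises
# MisconfigurationException at fit start, as implemented above.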

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 30 additions & 0 deletions
@@ -387,5 +387,35 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
         trainer.fit(model)
 
 
+def test_swa_with_infinite_epochs_and_batchnorm(tmp_path):
+    """Test that SWA works correctly with max_epochs=-1 (infinite training) and BatchNorm."""
+    model = SwaTestModel(batchnorm=True)
+    swa_callback = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        max_epochs=-1,
+        max_steps=30,  # Use max_steps as stopping condition
+        limit_train_batches=5,
+        limit_val_batches=0,
+        callbacks=[swa_callback],
+        logger=False,
+    )
+    assert trainer.max_epochs == -1
+    assert trainer.fit_loop.max_epochs == -1
+
+    trainer.fit(model)
+    assert trainer.current_epoch >= 5
+    assert trainer.global_step == 30
+    assert trainer.max_epochs == -1
+
+    # Verify SWA was actually applied (update_parameters should have been called)
+    # SWA starts at epoch 2, so with 6 epochs (0-5), we should have 4 updates (epochs 2, 3, 4, 5)
+    assert swa_callback.n_averaged is not None
+    assert swa_callback.n_averaged > 0, "SWA should have updated parameters"
+
+
 def _backward_patch(trainer: Trainer) -> AbstractContextManager:
     return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
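
For a quick sanity check of the new `swa_end` behavior outside a full training run, a hedged snippet (illustrative only; it sets the private `_max_epochs` attribute that `on_fit_start` normally copies from `trainer.max_epochs`):

from lightning.pytorch.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)
swa._max_epochs = -1  # normally assigned from trainer.max_epochs in on_fit_start
assert swa.swa_end == float("inf")  # no fixed "last" SWA epoch under infinite training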
