diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 22cf899496ed6..065dab60dcaec 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix `_generate_seed_sequence_sampling` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))
 
+- Fix `ThroughputMonitor` callback emitting warnings too frequently ([#21453](https://github.com/Lightning-AI/pytorch-lightning/pull/21453))
+
+
 
 ---
 
diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index d38928d33de75..ecd5349a3eb68 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -89,6 +89,7 @@ def __init__(
         self._lengths: dict[RunningStage, int] = {}
         self._samples: dict[RunningStage, int] = {}
         self._batches: dict[RunningStage, int] = {}
+        self._module_has_flops: bool | None = None
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -133,14 +134,15 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
         if self.length_fn is not None:
             self._lengths[stage] += self.length_fn(batch)
 
-        if hasattr(pl_module, "flops_per_batch"):
-            flops_per_batch = pl_module.flops_per_batch
-        else:
-            rank_zero_warn(
-                "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"
-                f" in {type(pl_module).__name__} to compute the FLOPs."
-            )
-            flops_per_batch = None
+        if self._module_has_flops is None:
+            self._module_has_flops = hasattr(pl_module, "flops_per_batch")
+            if not self._module_has_flops:
+                rank_zero_warn(
+                    "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"
+                    f" in {type(pl_module).__name__} to compute the FLOPs."
+                )
+
+        flops_per_batch = pl_module.flops_per_batch if self._module_has_flops else None
 
         self._samples[stage] += self.batch_size_fn(batch)
         self._batches[stage] += 1
diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
index 93bfe4e844c3a..d213f167fdbcc 100644
--- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -1,3 +1,4 @@
+import warnings
 from unittest import mock
 from unittest.mock import ANY, Mock, call
 
@@ -482,3 +483,31 @@
             batch_num = 1
         if end_train_timings_idx < len(timings):
             cur_train += timings[end_train_timings_idx] - timings[start_train_timings_idx]
+
+
+def test_throughput_monitor_warn_once():
+    monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=False,
+        callbacks=[monitor],
+        max_epochs=1,
+        limit_train_batches=2,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        trainer.fit(model)
+
+    throughput_warnings = [
+        warn
+        for warn in w
+        if "When using the `ThroughputMonitor`, you need to define a `flops_per_batch`" in warn.message.args[0]
+    ]
+
+    assert len(throughput_warnings) == 1, "Expected exactly one warning about missing `flops_per_batch`."
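
Note (not part of the diff): the warning this PR deduplicates asks users to expose a `flops_per_batch` attribute or property on their `LightningModule`. The sketch below shows one way a user-side module might provide it so the `ThroughputMonitor` can log FLOP-based throughput; the class name, batch size, and the rough 2-FLOPs-per-parameter-per-sample estimate are illustrative assumptions, not values taken from this PR.

```python
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import ThroughputMonitor


class FlopsAwareModel(LightningModule):
    """Hypothetical module that exposes ``flops_per_batch`` for the ThroughputMonitor."""

    def __init__(self, batch_size: int = 4):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Rough placeholder: ~2 FLOPs per parameter per sample for a forward pass,
        # scaled by the expected batch size. A measured value (e.g. from a FLOP
        # counter) would be used in practice for meaningful throughput numbers.
        n_params = sum(p.numel() for p in self.parameters())
        self.flops_per_batch = 2 * n_params * batch_size

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Because the attribute exists, `hasattr(pl_module, "flops_per_batch")` is True on the
# first `_update` call, `_module_has_flops` is cached as True, and the warning exercised
# by `test_throughput_monitor_warn_once` is never emitted.
monitor = ThroughputMonitor(batch_size_fn=lambda batch: batch[0].shape[0])
```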