diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py
index dd96c045d8366..9e07cf28a72f0 100644
--- a/src/lightning/pytorch/callbacks/__init__.py
+++ b/src/lightning/pytorch/callbacks/__init__.py
@@ -15,6 +15,7 @@
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.callbacks.checkpoint import Checkpoint
 from lightning.pytorch.callbacks.device_stats_monitor import DeviceStatsMonitor
+from lightning.pytorch.callbacks.device_summary import DeviceSummary
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
 from lightning.pytorch.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
@@ -42,6 +43,7 @@
     "Callback",
     "Checkpoint",
     "DeviceStatsMonitor",
+    "DeviceSummary",
     "EarlyStopping",
     "GradientAccumulationScheduler",
     "LambdaCallback",
diff --git a/src/lightning/pytorch/callbacks/device_summary.py b/src/lightning/pytorch/callbacks/device_summary.py
new file mode 100644
index 0000000000000..5c916e9d53929
--- /dev/null
+++ b/src/lightning/pytorch/callbacks/device_summary.py
@@ -0,0 +1,104 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Device Summary
+==============
+
+Logs information about available and used devices (GPU, TPU, etc.) at the start of training.
+
+"""
+
+from typing_extensions import override
+
+import lightning.pytorch as pl
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
+from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
+
+
+class DeviceSummary(Callback):
+    r"""Logs information about available and used devices at the start of training.
+
+    This callback prints the availability and usage status of GPUs (CUDA/MPS) and TPUs.
+    It also warns if a device is available but not being used.
+
+    Args:
+        show_warnings: Whether to show warnings when available devices are not used.
+            Defaults to ``True``.
+
+    Example::
+
+        >>> from lightning.pytorch import Trainer
+        >>> from lightning.pytorch.callbacks import DeviceSummary
+        >>> # Default behavior - shows device info and warnings
+        >>> trainer = Trainer(callbacks=[DeviceSummary()])
+        >>> # Suppress device availability warnings
+        >>> trainer = Trainer(callbacks=[DeviceSummary(show_warnings=False)])
+        >>> # Disable the default device summary entirely via the Trainer flag
+        >>> trainer = Trainer(callbacks=[], enable_device_summary=False)
+
+    """
+
+    def __init__(self, show_warnings: bool = True) -> None:
+        self._show_warnings = show_warnings
+        self._logged = False
+
+    @override
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
+        """Log device information at the start of any stage.
+
+        The device summary is only logged once per Trainer instance, even if setup is called multiple times (e.g., for
+        fit then test).
+
+        """
+        if self._logged:
+            return
+        self._logged = True
+        self._log_device_info(trainer)
+
+    def _log_device_info(self, trainer: "pl.Trainer") -> None:
+        """Log information about available and used devices."""
+        if CUDAAccelerator.is_available():
+            gpu_available = True
+            gpu_type = " (cuda)"
+        elif MPSAccelerator.is_available():
+            gpu_available = True
+            gpu_type = " (mps)"
+        else:
+            gpu_available = False
+            gpu_type = ""
+
+        gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
+        rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
+
+        num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
+        rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")
+
+        if not self._show_warnings:
+            return
+
+        if (
+            CUDAAccelerator.is_available()
+            and not isinstance(trainer.accelerator, CUDAAccelerator)
+            or MPSAccelerator.is_available()
+            and not isinstance(trainer.accelerator, MPSAccelerator)
+        ):
+            rank_zero_warn(
+                "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.",
+                category=PossibleUserWarning,
+            )
+
+        if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
+            rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")
diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py
index 48f7ea86048ed..f8b33e56b8e36 100644
--- a/src/lightning/pytorch/trainer/connectors/callback_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -25,6 +25,7 @@
 from lightning.pytorch.callbacks import (
     Callback,
     Checkpoint,
+    DeviceSummary,
     ModelCheckpoint,
     ModelSummary,
     ProgressBar,
@@ -57,6 +58,7 @@ def on_trainer_init(
         enable_progress_bar: bool,
         default_root_dir: Optional[str],
         enable_model_summary: bool,
+        enable_device_summary: bool,
         max_time: Optional[Union[str, timedelta, dict[str, int]]] = None,
     ) -> None:
         # init folder paths for checkpoint + weights save callbacks
@@ -81,6 +83,9 @@
         # configure the ModelSummary callback
         self._configure_model_summary_callback(enable_model_summary)
 
+        # configure the DeviceSummary callback
+        self._configure_device_summary_callback(enable_device_summary)
+
         self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
         _validate_callbacks_list(self.trainer.callbacks)
 
@@ -129,6 +134,20 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
             self.trainer.callbacks.append(model_summary)
 
+    def _configure_device_summary_callback(self, enable_device_summary: bool) -> None:
+        if not enable_device_summary:
+            return
+
+        device_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, DeviceSummary)]
+        if device_summary_cbs:
+            rank_zero_info(
+                f"Trainer already configured with device summary callbacks: {device_summary_cbs}."
+                " Skipping setting a default `DeviceSummary` callback."
+            )
+            return
+
+        self.trainer.callbacks.append(DeviceSummary())
+
     def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
         progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBar)]
         if len(progress_bars) > 1:
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 6f947160ba9cb..3b75193f15e37 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -116,6 +116,7 @@ def __init__(
         enable_checkpointing: Optional[bool] = None,
         enable_progress_bar: Optional[bool] = None,
         enable_model_summary: Optional[bool] = None,
+        enable_device_summary: Optional[bool] = None,
         accumulate_grad_batches: int = 1,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
@@ -238,6 +239,10 @@ def __init__(
             enable_model_summary: Whether to enable model summarization by default.
                 Default: ``True``.
 
+            enable_device_summary: Whether to show device availability information (GPU, TPU, etc.)
+                at the start of training. Set to ``False`` to suppress the device info printout.
+                Default: ``True``.
+
             accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer.
                 Default: 1.
 
@@ -358,6 +363,12 @@ def __init__(
                     " Model summary can impact raw speed so it is disabled in barebones mode."
                 )
             enable_model_summary = False
+            if enable_device_summary:
+                raise ValueError(
+                    f"`Trainer(barebones=True, enable_device_summary={enable_device_summary!r})` was passed."
+                    " Device summary can impact raw speed so it is disabled in barebones mode."
+                )
+            enable_device_summary = False
             if num_sanity_val_steps is not None and num_sanity_val_steps != 0:
                 raise ValueError(
                     f"`Trainer(barebones=True, num_sanity_val_steps={num_sanity_val_steps!r})` was passed."
@@ -384,6 +395,7 @@ def __init__(
                 " - Checkpointing: `Trainer(enable_checkpointing=True)`",
                 " - Progress bar: `Trainer(enable_progress_bar=True)`",
                 " - Model summary: `Trainer(enable_model_summary=True)`",
+                " - Device summary: `Trainer(enable_device_summary=True)`",
                 " - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`,"
                 " `LightningModule.log(...)`, `LightningModule.log_dict(...)`",
                 " - Sanity checking: `Trainer(num_sanity_val_steps>0)`",
@@ -408,6 +420,8 @@ def __init__(
                 log_every_n_steps = 50
             if enable_model_summary is None:
                 enable_model_summary = True
+            if enable_device_summary is None:
+                enable_device_summary = True
             if num_sanity_val_steps is None:
                 num_sanity_val_steps = 2
 
@@ -450,6 +464,7 @@ def __init__(
             enable_progress_bar,
             default_root_dir,
             enable_model_summary,
+            enable_device_summary,
             max_time,
         )
 
@@ -485,8 +500,6 @@ def __init__(
         )
         self._detect_anomaly: bool = detect_anomaly
 
-        setup._log_device_info(self)
-
         self.should_stop = False
         self.state = TrainerState()
 
diff --git a/tests/tests_pytorch/callbacks/test_device_summary.py b/tests/tests_pytorch/callbacks/test_device_summary.py
new file mode 100644
index 0000000000000..038a8185c78b8
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_device_summary.py
@@ -0,0 +1,145 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+import pytest
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import DeviceSummary
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+def test_device_summary_enabled_by_default(tmp_path):
+    """Test that DeviceSummary callback is enabled by default."""
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
+    assert len(device_summary_callbacks) == 1
+
+
+def test_device_summary_disabled(tmp_path):
+    """Test that DeviceSummary callback can be disabled."""
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_device_summary=False,
+    )
+    device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
+    assert len(device_summary_callbacks) == 0
+
+
+def test_device_summary_custom_callback(tmp_path):
+    """Test that custom DeviceSummary callback is used when provided."""
+    custom_callback = DeviceSummary(show_warnings=False)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=1,
+        callbacks=[custom_callback],
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
+    assert len(device_summary_callbacks) == 1
+    assert device_summary_callbacks[0] is custom_callback
+
+
+def test_device_summary_logs_once(tmp_path):
+    """Test that DeviceSummary only logs once per Trainer instance."""
+    model = BoringModel()
+    callback = DeviceSummary()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[callback],
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_device_summary=False,  # Don't add default callback
+    )
+
+    with mock.patch.object(callback, "_log_device_info") as mock_log:
+        trainer.fit(model)
+        assert mock_log.call_count == 1
+
+        # Run validation - should not log again
+        trainer.validate(model)
+        assert mock_log.call_count == 1
+
+
+@mock.patch("lightning.pytorch.callbacks.device_summary.rank_zero_info")
+def test_device_summary_output(mock_info, tmp_path):
+    """Test that DeviceSummary logs expected information."""
+    model = BoringModel()
+    callback = DeviceSummary()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=1,
+        callbacks=[callback],
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_device_summary=False,
+    )
+
+    trainer.fit(model)
+
+    # Check that GPU and TPU info was logged
+    calls = [str(call) for call in mock_info.call_args_list]
+    gpu_logged = any("GPU available" in call for call in calls)
+    tpu_logged = any("TPU available" in call for call in calls)
+    assert gpu_logged
+    assert tpu_logged
+
+
+def test_device_summary_show_warnings_disabled():
+    """Test that warnings can be suppressed."""
+    callback = DeviceSummary(show_warnings=False)
+    assert callback._show_warnings is False
+
+
+def test_device_summary_barebones_mode_raises(tmp_path):
+    """Test that enable_device_summary raises an error in barebones mode."""
+    with pytest.raises(ValueError, match="barebones=True, enable_device_summary"):
+        Trainer(
+            default_root_dir=tmp_path,
+            barebones=True,
+            enable_device_summary=True,
+        )
+
+
+def test_device_summary_barebones_mode_disabled(tmp_path):
+    """Test that DeviceSummary is disabled in barebones mode."""
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        barebones=True,
+    )
+    device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
+    assert len(device_summary_callbacks) == 0
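
A minimal usage sketch of the behavior introduced above, assuming this patch is applied (`DeviceSummary` and the `enable_device_summary` flag do not exist without it)::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import DeviceSummary

    # Default: the callback connector appends a DeviceSummary() automatically.
    trainer = Trainer(max_epochs=1)

    # User-provided instance: the connector detects it and skips adding a default one.
    trainer = Trainer(callbacks=[DeviceSummary(show_warnings=False)], max_epochs=1)

    # Fully silent: no device summary callback is added at all.
    trainer = Trainer(enable_device_summary=False, max_epochs=1)

This mirrors the existing ModelSummary handling: a user-supplied instance takes precedence, and the flag only controls whether the default callback is appended.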