2 changes: 2 additions & 0 deletions src/lightning/pytorch/callbacks/__init__.py
@@ -15,6 +15,7 @@
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks.checkpoint import Checkpoint
from lightning.pytorch.callbacks.device_stats_monitor import DeviceStatsMonitor
from lightning.pytorch.callbacks.device_summary import DeviceSummary
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from lightning.pytorch.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
@@ -42,6 +43,7 @@
"Callback",
"Checkpoint",
"DeviceStatsMonitor",
"DeviceSummary",
"EarlyStopping",
"GradientAccumulationScheduler",
"LambdaCallback",
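With these two additions the callback joins the public API surface. A minimal import sketch, assuming this PR is installed:

```python
from lightning.pytorch.callbacks import DeviceSummary  # new public import path

callback = DeviceSummary()  # show_warnings defaults to True
```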
104 changes: 104 additions & 0 deletions src/lightning/pytorch/callbacks/device_summary.py
@@ -0,0 +1,104 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Device Summary
==============

Logs information about available and used devices (GPU, TPU, etc.) at the start of training.

"""

from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn


class DeviceSummary(Callback):
r"""Logs information about available and used devices at the start of training.

This callback prints the availability and usage status of GPUs (CUDA/MPS) and TPUs.
It also warns if a device is available but not being used.

Args:
show_warnings: Whether to show warnings when available devices are not used.
Defaults to ``True``.

Example::

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import DeviceSummary
>>> # Default behavior - shows device info and warnings
>>> trainer = Trainer(callbacks=[DeviceSummary()])
>>> # Suppress device availability warnings
>>> trainer = Trainer(callbacks=[DeviceSummary(show_warnings=False)])
        >>> # Disable the device summary entirely via the Trainer flag
        >>> trainer = Trainer(enable_device_summary=False)

"""

def __init__(self, show_warnings: bool = True) -> None:
self._show_warnings = show_warnings
self._logged = False

@override
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Log device information at the start of any training stage.

The device summary is only logged once per Trainer instance, even if setup is called multiple times (e.g., for
fit then test).

"""
if self._logged:
return
self._logged = True
self._log_device_info(trainer)

def _log_device_info(self, trainer: "pl.Trainer") -> None:
"""Log information about available and used devices."""
if CUDAAccelerator.is_available():
gpu_available = True
gpu_type = " (cuda)"
elif MPSAccelerator.is_available():
gpu_available = True
gpu_type = " (mps)"
else:
gpu_available = False
gpu_type = ""

gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

if not self._show_warnings:
return

        if (CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator)) or (
            MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator)
        ):
rank_zero_warn(
"GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.",
category=PossibleUserWarning,
)

if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")
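A minimal usage sketch of the new callback in isolation, assuming this PR is installed; the expected log lines follow from `_log_device_info` above and are shown for a CPU-only host (actual output depends on the hardware detected):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import DeviceSummary
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(
    max_epochs=1,
    limit_train_batches=1,
    callbacks=[DeviceSummary(show_warnings=False)],
    enable_device_summary=False,  # keep the connector from adding a second instance
)
trainer.fit(BoringModel())
# Expected rank-zero log lines from _log_device_info on a CPU-only machine:
#   GPU available: False, used: False
#   TPU available: False, using: 0 TPU cores
```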
19 changes: 19 additions & 0 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -25,6 +25,7 @@
from lightning.pytorch.callbacks import (
Callback,
Checkpoint,
DeviceSummary,
ModelCheckpoint,
ModelSummary,
ProgressBar,
@@ -57,6 +58,7 @@ def on_trainer_init(
enable_progress_bar: bool,
default_root_dir: Optional[str],
enable_model_summary: bool,
enable_device_summary: bool,
max_time: Optional[Union[str, timedelta, dict[str, int]]] = None,
) -> None:
# init folder paths for checkpoint + weights save callbacks
@@ -81,6 +83,9 @@
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary)

# configure the DeviceSummary callback
self._configure_device_summary_callback(enable_device_summary)

self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks)

@@ -129,6 +134,20 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
self.trainer.callbacks.append(model_summary)

def _configure_device_summary_callback(self, enable_device_summary: bool) -> None:
if not enable_device_summary:
return

device_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, DeviceSummary)]
if device_summary_cbs:
rank_zero_info(
f"Trainer already configured with device summary callbacks: {device_summary_cbs}."
" Skipping setting a default `DeviceSummary` callback."
)
return

self.trainer.callbacks.append(DeviceSummary())

def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBar)]
if len(progress_bars) > 1:
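The connector wiring mirrors the existing `_configure_model_summary_callback` pattern: a user-supplied instance takes precedence, and the flag disables the callback entirely. A behavior sketch under this diff:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import DeviceSummary

# Default: the connector appends a DeviceSummary automatically.
trainer = Trainer()
assert any(isinstance(cb, DeviceSummary) for cb in trainer.callbacks)

# User-supplied instance: the connector logs that one is already configured
# and skips adding the default.
mine = DeviceSummary(show_warnings=False)
trainer = Trainer(callbacks=[mine])
assert [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)] == [mine]

# Flag off: no DeviceSummary is attached at all.
trainer = Trainer(enable_device_summary=False)
assert not any(isinstance(cb, DeviceSummary) for cb in trainer.callbacks)
```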
17 changes: 15 additions & 2 deletions src/lightning/pytorch/trainer/trainer.py
@@ -116,6 +116,7 @@ def __init__(
enable_checkpointing: Optional[bool] = None,
enable_progress_bar: Optional[bool] = None,
enable_model_summary: Optional[bool] = None,
enable_device_summary: Optional[bool] = None,
accumulate_grad_batches: int = 1,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
@@ -238,6 +239,10 @@ def __init__(
enable_model_summary: Whether to enable model summarization by default.
Default: ``True``.

enable_device_summary: Whether to show device availability information (GPU, TPU, etc.)
at the start of training. Set to ``False`` to suppress device printout.
Default: ``True``.

accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer.
Default: 1.

@@ -358,6 +363,12 @@ def __init__(
" Model summary can impact raw speed so it is disabled in barebones mode."
)
enable_model_summary = False
if enable_device_summary:
raise ValueError(
f"`Trainer(barebones=True, enable_device_summary={enable_device_summary!r})` was passed."
" Device summary can impact raw speed so it is disabled in barebones mode."
)
enable_device_summary = False
if num_sanity_val_steps is not None and num_sanity_val_steps != 0:
raise ValueError(
f"`Trainer(barebones=True, num_sanity_val_steps={num_sanity_val_steps!r})` was passed."
@@ -384,6 +395,7 @@
" - Checkpointing: `Trainer(enable_checkpointing=True)`",
" - Progress bar: `Trainer(enable_progress_bar=True)`",
" - Model summary: `Trainer(enable_model_summary=True)`",
" - Device summary: `Trainer(enable_device_summary=True)`",
" - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`,"
" `LightningModule.log(...)`, `LightningModule.log_dict(...)`",
" - Sanity checking: `Trainer(num_sanity_val_steps>0)`",
@@ -408,6 +420,8 @@
log_every_n_steps = 50
if enable_model_summary is None:
enable_model_summary = True
if enable_device_summary is None:
enable_device_summary = True
if num_sanity_val_steps is None:
num_sanity_val_steps = 2

@@ -450,6 +464,7 @@ def __init__(
enable_progress_bar,
default_root_dir,
enable_model_summary,
enable_device_summary,
max_time,
)

@@ -485,8 +500,6 @@
)
self._detect_anomaly: bool = detect_anomaly

setup._log_device_info(self)

self.should_stop = False
self.state = TrainerState()

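The new flag interacts with barebones mode the same way as the other `enable_*` flags, per the checks above: leaving it at its default `None` silently disables the summary, while passing `True` explicitly raises. A sketch:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import DeviceSummary

# Flag left at its default None: barebones silently coerces it to False.
trainer = Trainer(barebones=True)
assert not any(isinstance(cb, DeviceSummary) for cb in trainer.callbacks)

# Passing True explicitly conflicts with barebones and raises.
try:
    Trainer(barebones=True, enable_device_summary=True)
except ValueError as err:
    assert "enable_device_summary" in str(err)
```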
145 changes: 145 additions & 0 deletions tests/tests_pytorch/callbacks/test_device_summary.py
@@ -0,0 +1,145 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import DeviceSummary
from lightning.pytorch.demos.boring_classes import BoringModel


def test_device_summary_enabled_by_default(tmp_path):
"""Test that DeviceSummary callback is enabled by default."""
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=1,
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
)
device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
assert len(device_summary_callbacks) == 1


def test_device_summary_disabled(tmp_path):
"""Test that DeviceSummary callback can be disabled."""
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=1,
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
enable_device_summary=False,
)
device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
assert len(device_summary_callbacks) == 0


def test_device_summary_custom_callback(tmp_path):
"""Test that custom DeviceSummary callback is used when provided."""
custom_callback = DeviceSummary(show_warnings=False)
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=1,
callbacks=[custom_callback],
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
)
device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
assert len(device_summary_callbacks) == 1
assert device_summary_callbacks[0] is custom_callback


def test_device_summary_logs_once(tmp_path):
"""Test that DeviceSummary only logs once per Trainer instance."""
model = BoringModel()
callback = DeviceSummary()

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
callbacks=[callback],
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
enable_device_summary=False, # Don't add default callback
)

with mock.patch.object(callback, "_log_device_info") as mock_log:
trainer.fit(model)
assert mock_log.call_count == 1

# Run validation - should not log again
trainer.validate(model)
assert mock_log.call_count == 1


@mock.patch("lightning.pytorch.callbacks.device_summary.rank_zero_info")
def test_device_summary_output(mock_info, tmp_path):
"""Test that DeviceSummary logs expected information."""
model = BoringModel()
callback = DeviceSummary()

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=1,
callbacks=[callback],
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
enable_device_summary=False,
)

trainer.fit(model)

# Check that GPU and TPU info was logged
calls = [str(call) for call in mock_info.call_args_list]
gpu_logged = any("GPU available" in call for call in calls)
tpu_logged = any("TPU available" in call for call in calls)
assert gpu_logged
assert tpu_logged


def test_device_summary_show_warnings_disabled(tmp_path):
"""Test that warnings can be suppressed."""
callback = DeviceSummary(show_warnings=False)
assert callback._show_warnings is False


def test_device_summary_barebones_mode_raises(tmp_path):
"""Test that enable_device_summary raises error in barebones mode."""
with pytest.raises(ValueError, match="barebones=True, enable_device_summary"):
Trainer(
default_root_dir=tmp_path,
barebones=True,
enable_device_summary=True,
)


def test_device_summary_barebones_mode_disabled(tmp_path):
"""Test that DeviceSummary is disabled in barebones mode."""
trainer = Trainer(
default_root_dir=tmp_path,
barebones=True,
)
device_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, DeviceSummary)]
assert len(device_summary_callbacks) == 0