Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `LightningDataModule.load_from_checkpoint` to restore the datamodule subclass and hyperparameters ([#21478](https://github.com/Lightning-AI/pytorch-lightning/pull/21478))

- Fixed ``RichModelSummary`` model size display formatting ([#21467](https://github.com/Lightning-AI/pytorch-lightning/pull/21467))

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))

Expand Down
5 changes: 3 additions & 2 deletions src/lightning/pytorch/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_summary import get_human_readable_count
from lightning.pytorch.utilities.model_summary import get_formatted_model_size, get_human_readable_count


class RichModelSummary(ModelSummary):
Expand Down Expand Up @@ -105,8 +105,9 @@ def summarize(
console.print(table)

parameters = []
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters]:
parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
parameters.append("{:<{}}".format(get_formatted_model_size(model_size), 10))

grid = Table.grid(expand=True)
grid.add_column()
Expand Down
31 changes: 31 additions & 0 deletions tests/tests_pytorch/callbacks/test_rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,34 @@ def example_input_array(self) -> Any:
# assert that the input summary data was converted correctly
args, _ = mock_table_add_row.call_args_list[0]
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]")


@RunIf(rich=True)
def test_rich_summary_model_size_formatting():
"""Ensure model_size uses get_formatted_model_size, not get_human_readable_count."""
from io import StringIO

from rich.console import Console

model_summary = RichModelSummary()
model = BoringModel()
summary = summarize(model)
summary_data = summary._get_summary_data()

output = StringIO()
console = Console(file=output, force_terminal=True)

with mock.patch("rich.get_console", return_value=console):
model_summary.summarize(
summary_data=summary_data,
total_parameters=1,
trainable_parameters=1,
model_size=5500.0,
total_training_modes=summary.total_training_modes,
total_flops=1,
)

result = output.getvalue()
# model_size=5500.0 should display as "5,500.000" (formatted), not "5.5 K" (human readable count)
assert "5,500.000" in result
assert "5.5 K" not in result
Loading