diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 22cf899496ed6..058aeae8124ae 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `LightningDataModule.load_from_checkpoint` to restore the datamodule subclass and hyperparameters ([#21478](https://github.com/Lightning-AI/pytorch-lightning/pull/21478)) +- Fixed ``RichModelSummary`` model size display formatting ([#21467](https://github.com/Lightning-AI/pytorch-lightning/pull/21467)) - Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357)) diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index 2843806ca595a..e287361e7412a 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -17,7 +17,7 @@ from lightning.pytorch.callbacks import ModelSummary from lightning.pytorch.utilities.imports import _RICH_AVAILABLE -from lightning.pytorch.utilities.model_summary import get_human_readable_count +from lightning.pytorch.utilities.model_summary import get_formatted_model_size, get_human_readable_count class RichModelSummary(ModelSummary): @@ -105,8 +105,9 @@ def summarize( console.print(table) parameters = [] - for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: + for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters]: parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10)) + parameters.append("{:<{}}".format(get_formatted_model_size(model_size), 10)) grid = Table.grid(expand=True) grid.add_column() diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py index af385bb1a9b39..4fd22e39babc5 100644 --- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py @@ -70,3 +70,34 @@ def example_input_array(self) -> Any: # assert that the input summary data was converted correctly args, _ = mock_table_add_row.call_args_list[0] assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]") + + +@RunIf(rich=True) +def test_rich_summary_model_size_formatting(): + """Ensure model_size uses get_formatted_model_size, not get_human_readable_count.""" + from io import StringIO + + from rich.console import Console + + model_summary = RichModelSummary() + model = BoringModel() + summary = summarize(model) + summary_data = summary._get_summary_data() + + output = StringIO() + console = Console(file=output, force_terminal=True) + + with mock.patch("rich.get_console", return_value=console): + model_summary.summarize( + summary_data=summary_data, + total_parameters=1, + trainable_parameters=1, + model_size=5500.0, + total_training_modes=summary.total_training_modes, + total_flops=1, + ) + + result = output.getvalue() + # model_size=5500.0 should display as "5,500.000" (formatted), not "5.5 K" (human readable count) + assert "5,500.000" in result + assert "5.5 K" not in result