Draft
Commits (45)
958f95e  feat: introduce `logger_map` property. (GdoongMathew, Oct 19, 2025)
711fb4f  revert trainer.logger change. (GdoongMathew, Oct 19, 2025)
791dbdc  add tests. (GdoongMathew, Oct 19, 2025)
db51529  add test. (GdoongMathew, Oct 19, 2025)
7cb1382  add test. (GdoongMathew, Oct 19, 2025)
b1ea7d3  fix pylint (GdoongMathew, Oct 19, 2025)
6bbc98d  fix pylint (GdoongMathew, Oct 19, 2025)
56ea5e8  add test. (GdoongMathew, Oct 19, 2025)
d031985  refactor loggers setter. (GdoongMathew, Oct 19, 2025)
e8d3695  fix pylint. (GdoongMathew, Oct 19, 2025)
01aaa41  _ListMap integration. (GdoongMathew, Oct 21, 2025)
e5a38ed  Merge branch 'master' into feat/logger_dict (GdoongMathew, Oct 22, 2025)
022fa92  fix: fix unittests. (GdoongMathew, Oct 22, 2025)
c309642  fix pylint. (GdoongMathew, Oct 22, 2025)
67b888f  add reverse impl. (GdoongMathew, Oct 22, 2025)
b3a3a70  Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logg… (GdoongMathew, Oct 22, 2025)
3e9e398  implement list methods. (GdoongMathew, Oct 22, 2025)
fa83ab2  implement get method. (GdoongMathew, Oct 22, 2025)
085f167  adding notes. (GdoongMathew, Oct 22, 2025)
0d8f725  refactor (GdoongMathew, Oct 22, 2025)
a393ae3  Merge branch 'master' into feat/logger_dict (GdoongMathew, Oct 22, 2025)
0e14e09  docs (GdoongMathew, Oct 23, 2025)
2d9f419  test: add additional unittests. (GdoongMathew, Oct 23, 2025)
b78daea  fix: fix delete implementation. (GdoongMathew, Oct 23, 2025)
c371b20  docs: fix doctest. (GdoongMathew, Oct 23, 2025)
41f4311  add unittest case. (GdoongMathew, Oct 23, 2025)
172ceb3  fix: fix mypy (GdoongMathew, Oct 23, 2025)
fffe03b  test (GdoongMathew, Oct 23, 2025)
c45fa9b  fix mypy (GdoongMathew, Oct 24, 2025)
ec39fe5  fix mypy (GdoongMathew, Oct 24, 2025)
69a2ef3  Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logg… (GdoongMathew, Oct 24, 2025)
a2709c2  ref: refactor __delitem__ (GdoongMathew, Oct 24, 2025)
9d0d39d  fix: mypy (GdoongMathew, Oct 24, 2025)
01b9247  Merge branch 'master' into feat/logger_dict (GdoongMathew, Oct 24, 2025)
d1526f7  Merge branch 'master' into feat/logger_dict (GdoongMathew, Oct 31, 2025)
5edf4b1  fix type annotation (GdoongMathew, Oct 31, 2025)
76b5311  fix typecheck (GdoongMathew, Oct 31, 2025)
8da3ea4  fix typecheck (GdoongMathew, Oct 31, 2025)
9aefdde  fix typecheck (GdoongMathew, Oct 31, 2025)
535345e  ignore override (GdoongMathew, Oct 31, 2025)
8266afb  Merge branch 'master' into feat/logger_dict (GdoongMathew, Nov 4, 2025)
3211f15  refactor __eq__ (GdoongMathew, Nov 4, 2025)
2dff765  refactor __setitem__ (GdoongMathew, Nov 10, 2025)
483ca4a  fix insert implementation and add unittests (GdoongMathew, Nov 10, 2025)
6ceb93a  Merge branch 'master' into feat/logger_dict (GdoongMathew, Nov 10, 2025)
src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import Any, Optional, Union

from lightning_utilities.core.apply_func import apply_to_collection
@@ -82,6 +82,8 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
)
logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment]
self.trainer.loggers = [logger_]
elif isinstance(logger, Mapping):
self.trainer.loggers = logger
elif isinstance(logger, Iterable):
self.trainer.loggers = list(logger)
else:
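A note on the branch order in the hunk above: a Mapping is itself an Iterable (over its keys), so the new Mapping check has to come before the Iterable check, otherwise a dict of loggers would be flattened into a list of its keys. A minimal sketch of that dispatch, using a hypothetical helper name that is not part of the PR:

from collections.abc import Iterable, Mapping

def _split_loggers(logger):
    # Illustration only: mirror the Mapping-before-Iterable order used above.
    if isinstance(logger, Mapping):
        return list(logger.keys()), list(logger.values())  # keep the keys for logger_map
    if isinstance(logger, Iterable):
        return [], list(logger)  # plain list of loggers: no keys
    return [], [logger]  # single logger instance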
24 changes: 21 additions & 3 deletions src/lightning/pytorch/trainer/trainer.py
@@ -23,7 +23,7 @@
import logging
import math
import os
from collections.abc import Generator, Iterable
from collections.abc import Generator, Iterable, Mapping
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Optional, Union
@@ -1643,8 +1643,26 @@ def loggers(self) -> list[Logger]:
return self._loggers

@loggers.setter
def loggers(self, loggers: Optional[list[Logger]]) -> None:
self._loggers = loggers if loggers else []
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None:
self._logger_keys: list[Union[str, int]]
if isinstance(loggers, Mapping):
self._loggers = list(loggers.values())
self._logger_keys = list(loggers.keys())
else:
self._loggers = loggers if loggers else []
self._logger_keys = []

@property
def logger_map(self) -> dict[Union[str, int], Logger]:
"""A mapping of logger keys to :class:`~lightning.pytorch.loggers.logger.Logger` used.

.. code-block:: python
tb_logger = trainer.logger_map.get("tensorboard", None)
if tb_logger:
tb_logger.log_hyperparams({"lr": 0.001})

"""
return dict(zip(self._logger_keys, self._loggers))

@property
def callback_metrics(self) -> _OUT_DICT:
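For readers skimming the diff, here is a minimal usage sketch of the mapping-based logger configuration and the new `logger_map` property, assuming the branch in this PR; it is illustrative only and not part of the change itself.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

trainer = Trainer(
    logger={
        "tensorboard": TensorBoardLogger(save_dir="logs/"),
        "csv": CSVLogger(save_dir="logs/"),
    }
)

# `trainer.loggers` stays an ordered list, so existing list-based code keeps working.
assert trainer.loggers == list(trainer.logger_map.values())

# Loggers can now be looked up by key instead of by position.
tb_logger = trainer.logger_map.get("tensorboard")
if tb_logger is not None:
    tb_logger.log_hyperparams({"lr": 0.001})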
37 changes: 37 additions & 0 deletions tests/tests_pytorch/trainer/properties/test_loggers.py
@@ -23,29 +23,47 @@ def test_trainer_loggers_property():
"""Test for correct initialization of loggers in Trainer."""
logger1 = CustomLogger()
logger2 = CustomLogger()
logger3 = CustomLogger()

# trainer.loggers should be a copy of the input list
trainer = Trainer(logger=[logger1, logger2])

assert trainer.loggers == [logger1, logger2]
assert trainer.logger_map == {}

# trainer.loggers should create a list of size 1
trainer = Trainer(logger=logger1)

assert trainer.logger == logger1
assert trainer.loggers == [logger1]
assert trainer.logger_map == {}

trainer.loggers.append(logger2)
assert trainer.loggers == [logger1, logger2]
assert trainer.logger_map == {}

# trainer.loggers should be a list of size 1 holding the default logger
trainer = Trainer(logger=True)

assert trainer.loggers == [trainer.logger]
assert isinstance(trainer.logger, TensorBoardLogger)

trainer = Trainer(logger={"log1": logger1, "log2": logger2})
assert trainer.logger == logger1
assert trainer.loggers == [logger1, logger2]
assert isinstance(trainer.logger_map, dict)
assert trainer.logger_map == {"log1": logger1, "log2": logger2}

trainer.loggers.append(logger3)
assert trainer.loggers == [logger1, logger2, logger3]
assert trainer.logger_map == {"log1": logger1, "log2": logger2}


def test_trainer_loggers_setters():
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
logger1 = CustomLogger()
logger2 = CustomLogger()
logger3 = CustomLogger()

trainer = Trainer()
assert type(trainer.logger) is TensorBoardLogger
@@ -59,22 +77,40 @@ def test_trainer_loggers_setters():
trainer.logger = None
assert trainer.logger is None
assert trainer.loggers == []
assert isinstance(trainer.loggers, list)
assert trainer.logger_map == {}

# Test setters for trainer.loggers
trainer.loggers = [logger1, logger2]
assert trainer.loggers == [logger1, logger2]
assert isinstance(trainer.loggers, list)
assert trainer.logger_map == {}

trainer.loggers = [logger1]
assert trainer.loggers == [logger1]
assert trainer.logger == logger1
assert trainer.logger_map == {}

trainer.loggers = []
assert trainer.loggers == []
assert trainer.logger is None
assert isinstance(trainer.loggers, list)
assert trainer.logger_map == {}

trainer.loggers = None
assert trainer.loggers == []
assert trainer.logger is None
assert isinstance(trainer.loggers, list)
assert trainer.logger_map == {}

trainer.loggers = {"log1": logger1, "log2": logger2}
assert trainer.loggers == [logger1, logger2]
assert isinstance(trainer.loggers, list)
assert isinstance(trainer.logger_map, dict)
assert trainer.logger_map == {"log1": logger1, "log2": logger2}

trainer.loggers.append(logger3)
assert trainer.logger_map == {"log1": logger1, "log2": logger2}


@pytest.mark.parametrize(
@@ -94,3 +94,4 @@ def test_no_logger(tmp_path, logger_value):
assert trainer.logger is None
assert trainer.loggers == []
assert trainer.log_dir == str(tmp_path)
assert trainer.logger_map == {}
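One behavioral caveat that the assertions above pin down: `logger_map` is built from the keys captured when the `loggers` setter receives a Mapping, so mutating `trainer.loggers` in place afterwards does not update the map. A small illustrative sketch of that behavior, using stand-in CSV loggers rather than the test's CustomLogger:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

logger1, logger2, logger3 = CSVLogger("logs/a"), CSVLogger("logs/b"), CSVLogger("logs/c")

trainer = Trainer(logger={"log1": logger1, "log2": logger2})
trainer.loggers.append(logger3)  # mutates only the underlying list

# The list reflects the append, but logger_map still holds only the original keys.
assert logger3 in trainer.loggers
assert logger3 not in trainer.logger_map.values()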