-
Notifications
You must be signed in to change notification settings - Fork 38
Add metrics_log_dir option to LoggingConfig #992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1f1b41c
7fc7bab
0612103
15c7827
5dcd298
9890c8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| import io | ||
| import json | ||
| import logging | ||
| import os | ||
| from typing import Any | ||
|
|
||
# Name of the JSONL file written inside the configured logging directory.
METRICS_FILENAME = "metrics.jsonl"
|
|
||
|
|
||
class DiskMetricLogger:
    """Logs scalar metrics to a JSONL file on disk.

    Each line in the file is a JSON object with a "step" key and scalar metric
    key-value pairs. On construction, any existing file is read to determine
    the high-water mark (maximum step already logged). Subsequent calls to
    ``log`` with a step at or below that mark are skipped with a warning,
    which makes this logger safe for job resumption.

    Non-JSON-serializable values (e.g. images, tensors) are silently dropped.

    May also be used as a context manager; the file is closed on exit.
    """

    def __init__(self, directory: str | os.PathLike):
        """Create *directory* if needed and open the metrics file for append.

        Args:
            directory: directory in which the metrics JSONL file lives.
        """
        os.makedirs(directory, exist_ok=True)
        self._path = os.path.join(directory, METRICS_FILENAME)
        self._high_water_mark: int | None = None
        self._file: io.TextIOWrapper | None = None
        self._read_high_water_mark()
        # Append mode so a resumed job extends the existing file; explicit
        # utf-8 keeps the file independent of the platform default encoding.
        self._file = open(self._path, "a", encoding="utf-8")

    def _read_high_water_mark(self) -> None:
        """Read the existing file (if any) to find the max step logged."""
        if not os.path.exists(self._path):
            return
        with open(self._path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    # Tolerate corrupt lines (e.g. a partial write from a
                    # killed job) so resumption still works.
                    continue
                step = record.get("step")
                if isinstance(step, int) and (
                    self._high_water_mark is None or step > self._high_water_mark
                ):
                    self._high_water_mark = step

    def log(self, data: dict[str, Any], step: int) -> None:
        """Log scalar metrics for a given step.

        If ``step`` is at or below the high-water mark from a previous run,
        the call is skipped and a warning is emitted. Non-serializable values
        are dropped; if nothing serializable remains, no line is written.
        """
        if self._high_water_mark is not None and step <= self._high_water_mark:
            logging.warning(
                "DiskMetricLogger: skipping log for step %d "
                "(at or below high-water mark %d)",
                step,
                self._high_water_mark,
            )
            return
        scalars = _extract_serializable(data)
        if not scalars:
            return
        record = {"step": step, **scalars}
        assert self._file is not None
        self._file.write(json.dumps(record) + "\n")
        # Flush per line so records are visible on disk immediately and
        # survive an abrupt job termination.
        self._file.flush()

    def close(self) -> None:
        """Close the underlying file; repeated calls are no-ops."""
        if self._file is not None:
            self._file.close()
            self._file = None

    def __enter__(self) -> "DiskMetricLogger":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()
|
|
||
|
|
||
| def _extract_serializable(data: dict[str, Any]) -> dict[str, Any]: | ||
| """Return only JSON-serializable entries from *data*.""" | ||
| result: dict[str, Any] = {} | ||
| for key, value in data.items(): | ||
| if isinstance(value, int | float | bool): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From agent review:
Apparently this may lead to parsing errors for certain versions of I guess NaN / Inf values are also not included in tests, might be worthwhile to add them.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a test for NaN/Inf round-tripping through read_metrics, with a docstring noting the strict-parser caveat. |
||
| result[key] = value | ||
| elif isinstance(value, str): | ||
| result[key] = value | ||
| else: | ||
| try: | ||
| json.dumps(value) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From agent review:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, renamed to _extract_serializable. |
||
| except (TypeError, ValueError, OverflowError): | ||
| logging.debug( | ||
| f"DiskMetricLogger: skipping non-serializable key '{key}'" | ||
| ) | ||
| else: | ||
| result[key] = value | ||
| return result | ||
|
|
||
|
|
||
def read_metrics(directory: str | os.PathLike) -> list[dict[str, Any]]:
    """Read all metric records from a metrics JSONL file.

    Blank and unparseable lines are ignored; a missing file yields an empty
    list. Records are returned as dicts in file order.
    """
    path = os.path.join(directory, METRICS_FILENAME)
    if not os.path.exists(path):
        return []
    records: list[dict[str, Any]] = []
    with open(path) as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                parsed = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            records.append(parsed)
    return records
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,211 @@ | ||
| import json | ||
| import math | ||
| import os | ||
|
|
||
| import pytest | ||
|
|
||
| from fme.core.disk_metric_logger import METRICS_FILENAME, DiskMetricLogger, read_metrics | ||
|
|
||
|
|
||
@pytest.fixture
def log_dir(tmp_path):
    """Path of a not-yet-created metrics directory under pytest's tmp dir."""
    return os.path.join(str(tmp_path), "metrics")
|
|
||
|
|
||
def test_creates_directory(log_dir):
    """The logger creates its target directory on construction."""
    assert not os.path.exists(log_dir)
    metric_logger = DiskMetricLogger(log_dir)
    assert os.path.isdir(log_dir)
    metric_logger.close()
|
|
||
|
|
||
def test_log_writes_jsonl(log_dir):
    """Each log call appends one JSON record with the step included."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"loss": 0.5, "lr": 1e-3}, step=0)
    metric_logger.log({"loss": 0.3, "lr": 1e-4}, step=1)
    metric_logger.close()

    expected = [
        {"step": 0, "loss": 0.5, "lr": 1e-3},
        {"step": 1, "loss": 0.3, "lr": 1e-4},
    ]
    assert read_metrics(log_dir) == expected
|
|
||
|
|
||
def test_log_flushes_each_line(log_dir):
    """Records are readable before close because each write is flushed."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"loss": 0.5}, step=0)
    assert len(read_metrics(log_dir)) == 1
    metric_logger.close()
|
|
||
|
|
||
def test_resume_skips_steps_at_or_below_high_water_mark(log_dir):
    """On resume, re-logged old steps are ignored and new steps appended."""
    first_run = DiskMetricLogger(log_dir)
    for step, loss in enumerate([0.5, 0.4, 0.3]):
        first_run.log({"loss": loss}, step=step)
    first_run.close()

    # Resume: steps 1 and 2 are at/below the high-water mark and must be
    # dropped; step 3 is new and must be written.
    resumed = DiskMetricLogger(log_dir)
    resumed.log({"loss": 0.45}, step=1)
    resumed.log({"loss": 0.35}, step=2)
    resumed.log({"loss": 0.2}, step=3)
    resumed.close()

    records = read_metrics(log_dir)
    assert [r["step"] for r in records] == [0, 1, 2, 3]
    # Original values for steps 1 and 2 are preserved, not overwritten.
    assert records[1]["loss"] == 0.4
    assert records[2]["loss"] == 0.3
    assert records[3] == {"step": 3, "loss": 0.2}
|
|
||
|
|
||
def test_resume_catches_up_then_continues(log_dir):
    """Steps exactly at the high-water mark are skipped; one above is written."""
    before = DiskMetricLogger(log_dir)
    before.log({"a": 1}, step=5)
    before.close()

    after = DiskMetricLogger(log_dir)
    after.log({"a": 2}, step=5)  # equal to the mark -> dropped
    after.log({"a": 3}, step=6)  # above the mark -> kept
    after.close()

    assert read_metrics(log_dir) == [{"step": 5, "a": 1}, {"step": 6, "a": 3}]
|
|
||
|
|
||
def test_non_scalar_values_are_skipped(log_dir):
    """Non-serializable values are dropped; serializable siblings survive."""

    class NotSerializable:
        pass

    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log(
        {"loss": 0.5, "image": NotSerializable(), "count": 10},
        step=0,
    )
    metric_logger.close()

    assert read_metrics(log_dir) == [{"step": 0, "loss": 0.5, "count": 10}]
|
|
||
|
|
||
def test_all_non_scalar_skips_entire_line(log_dir):
    """If every value is non-serializable, no record is written at all."""

    class NotSerializable:
        pass

    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"image": NotSerializable()}, step=0)
    metric_logger.close()

    assert read_metrics(log_dir) == []
|
|
||
|
|
||
def test_empty_data_skips_line(log_dir):
    """Logging an empty dict writes nothing."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({}, step=0)
    metric_logger.close()

    assert read_metrics(log_dir) == []
|
|
||
|
|
||
def test_string_values_are_logged(log_dir):
    """String metric values round-trip through the JSONL file."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"phase": "train", "loss": 0.5}, step=0)
    metric_logger.close()

    first_record = read_metrics(log_dir)[0]
    assert first_record == {"step": 0, "phase": "train", "loss": 0.5}
|
|
||
|
|
||
def test_bool_values_are_logged(log_dir):
    """Boolean metric values round-trip through the JSONL file."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"converged": True}, step=0)
    metric_logger.close()

    first_record = read_metrics(log_dir)[0]
    assert first_record == {"step": 0, "converged": True}
|
|
||
|
|
||
def test_no_existing_file(log_dir):
    """Logger works when the directory exists but no metrics file does."""
    os.makedirs(log_dir, exist_ok=True)
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"x": 1}, step=0)
    metric_logger.close()

    assert len(read_metrics(log_dir)) == 1
|
|
||
|
|
||
def test_corrupt_line_is_skipped(log_dir):
    """A corrupt line in the file doesn't prevent reading or resuming."""
    os.makedirs(log_dir, exist_ok=True)
    lines = [
        json.dumps({"step": 0, "loss": 0.5}),
        "NOT VALID JSON",
        json.dumps({"step": 2, "loss": 0.3}),
    ]
    with open(os.path.join(log_dir, METRICS_FILENAME), "w") as f:
        f.write("\n".join(lines) + "\n")

    # High-water mark should be 2 despite the corrupt middle line.
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"loss": 0.1}, step=2)  # at the mark -> dropped
    metric_logger.log({"loss": 0.05}, step=3)  # above the mark -> kept
    metric_logger.close()

    # read_metrics also skips the corrupt line.
    assert [r["step"] for r in read_metrics(log_dir)] == [0, 2, 3]
|
|
||
|
|
||
def test_read_metrics_empty_directory(tmp_path):
    """read_metrics returns an empty list when no metrics file exists."""
    assert read_metrics(str(tmp_path)) == []
|
|
||
|
|
||
def test_multiple_logs_same_step(log_dir):
    """Each log call at the same step produces its own line."""
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"loss": 0.5}, step=0)
    metric_logger.log({"lr": 1e-3}, step=0)
    metric_logger.close()

    assert read_metrics(log_dir) == [
        {"step": 0, "loss": 0.5},
        {"step": 0, "lr": 1e-3},
    ]
|
|
||
|
|
||
def test_nan_and_inf_values_are_logged(log_dir):
    """NaN and Inf are valid floats and get logged.

    Note: Python's json.dumps produces non-standard NaN/Infinity tokens.
    Use read_metrics (which uses json.loads) to round-trip these values
    rather than strict JSON parsers like jq.
    """
    metric_logger = DiskMetricLogger(log_dir)
    metric_logger.log({"nan_val": float("nan"), "inf_val": float("inf")}, step=0)
    metric_logger.close()

    (record,) = read_metrics(log_dir)
    assert math.isnan(record["nan_val"])
    assert record["inf_val"] == math.inf
Reviewer comment: WandB logging to a past step issues a warning. I think it would be good to do the same here for visibility.

Author reply: Done — added a logging.warning that includes the skipped step and the high-water mark values.