Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/core/benchmark/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_filename(name, extension) -> pathlib.Path:
entity=entity,
name=wandb_name,
)
wandb.log(wandb_logs, commit=True)
wandb.log(wandb_logs, step=0, commit=True)
return 0


Expand Down
114 changes: 114 additions & 0 deletions fme/core/disk_metric_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import io
import json
import logging
import os
from typing import Any

METRICS_FILENAME = "metrics.jsonl"


class DiskMetricLogger:
    """Logs scalar metrics to a JSONL file on disk.

    Each line in the file is a JSON object with a "step" key and scalar metric
    key-value pairs. On construction, any existing file is read to determine the
    high-water mark (maximum step already logged). Subsequent calls to ``log``
    with a step at or below that mark are skipped (with a warning), which makes
    this logger safe for job resumption.

    Non-JSON-serializable values (e.g. images, tensors) are silently dropped.

    Note: NaN/Inf float values are written using Python's non-standard
    ``NaN``/``Infinity`` JSON tokens, which strict parsers (e.g. jq) may
    reject; prefer ``read_metrics`` for reading the file back.

    May be used as a context manager; the file is closed on exit.
    """

    def __init__(self, directory: str | os.PathLike):
        """Create *directory* if needed and open the metrics file for append.

        Args:
            directory: Directory in which the metrics JSONL file is stored.
        """
        os.makedirs(directory, exist_ok=True)
        self._path = os.path.join(directory, METRICS_FILENAME)
        self._high_water_mark: int | None = None
        self._file: io.TextIOWrapper | None = None
        # Determine the resume point before opening for append, so that a
        # restarted job does not duplicate steps it already wrote.
        self._read_high_water_mark()
        self._file = open(self._path, "a")

    def _read_high_water_mark(self):
        """Read the existing file (if any) to find the max step logged."""
        if not os.path.exists(self._path):
            return
        with open(self._path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    # A corrupt line (e.g. from a crash mid-write) must not
                    # prevent resumption; skip it.
                    continue
                step = record.get("step")
                if isinstance(step, int):
                    if self._high_water_mark is None or step > self._high_water_mark:
                        self._high_water_mark = step

    def log(self, data: dict[str, Any], step: int) -> None:
        """Log scalar metrics for a given step.

        If ``step`` is at or below the high-water mark from a previous run,
        the call is skipped with a warning. Non-serializable values are
        dropped; if nothing serializable remains, no line is written.

        Args:
            data: Mapping of metric name to value.
            step: Monotonically increasing step index for this record.

        Raises:
            RuntimeError: If called after ``close``.
        """
        if self._file is None:
            # Explicit check rather than `assert`: asserts are stripped
            # under `python -O`.
            raise RuntimeError("DiskMetricLogger: log() called after close()")
        if self._high_water_mark is not None and step <= self._high_water_mark:
            logging.warning(
                "DiskMetricLogger: skipping log for step %d "
                "(at or below high-water mark %d)",
                step,
                self._high_water_mark,
            )
            return
        scalars = _extract_serializable(data)
        if not scalars:
            return
        # "step" is written last so that a metric literally named "step"
        # cannot clobber the true step value, which resumption relies on.
        record = {**scalars, "step": step}
        self._file.write(json.dumps(record) + "\n")
        # Flush per line so records survive an abrupt job termination.
        self._file.flush()

    def close(self):
        """Close the underlying file; subsequent ``log`` calls raise."""
        if self._file is not None:
            self._file.close()
            self._file = None

    def __enter__(self):
        """Enter context: return the logger itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit context: close the file. Exceptions propagate."""
        self.close()


def _extract_serializable(data: dict[str, Any]) -> dict[str, Any]:
"""Return only JSON-serializable entries from *data*."""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, int | float | bool):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From agent review:

In _extract_scalars, isinstance(value, float) matches NaN and Inf, which bypass the json.dumps validation path. Python's json.dumps produces non-standard NaN / Infinity tokens (allowed by default via allow_nan=True), making the JSONL unreadable by strict JSON parsers. ML training metrics can legitimately be NaN (e.g., loss explosion).

Apparently this may lead to parsing errors for certain versions of jq and possibly for some pandas.read_json usages. I don't think we should change the behavior here, but maybe a brief note in the docs for DiskMetricLogger would be helpful to say that these values are possible or maybe just encourage use of the read_metrics util.

I guess NaN / Inf values are also not included in tests, might be worthwhile to add them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for NaN/Inf round-tripping through read_metrics, with a docstring noting the strict-parser caveat.

result[key] = value
elif isinstance(value, str):
result[key] = value
else:
try:
json.dumps(value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From agent review:

The fallback json.dumps(value) path accepts lists, dicts, None, etc. -- not just scalars. Consider renaming to _extract_serializable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, renamed to _extract_serializable.

except (TypeError, ValueError, OverflowError):
logging.debug(
f"DiskMetricLogger: skipping non-serializable key '{key}'"
)
else:
result[key] = value
return result


def read_metrics(directory: str | os.PathLike) -> list[dict[str, Any]]:
    """Read all metric records from a metrics JSONL file.

    Blank and unparseable lines are skipped; a missing file yields an empty
    list. Returns a list of dicts, one per logged line, in file order.
    """
    path = os.path.join(directory, METRICS_FILENAME)
    if not os.path.exists(path):
        return []
    parsed: list[dict[str, Any]] = []
    with open(path) as f:
        for raw in f:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                parsed.append(json.loads(stripped))
            except json.JSONDecodeError:
                continue
    return parsed
8 changes: 7 additions & 1 deletion fme/core/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class LoggingConfig:
log_to_screen: Whether to log to the screen.
log_to_file: Whether to log to a file.
log_to_wandb: Whether to log to Weights & Biases.
metrics_log_dir: Directory to write scalar metrics to disk as JSONL.
If None, disk metric logging is disabled.
log_format: Format of the log messages.
level: Sets the logging level.
wandb_dir_in_experiment_dir: Whether to create the wandb_dir in the
Expand All @@ -49,6 +51,7 @@ class LoggingConfig:
log_to_screen: bool = True
log_to_file: bool = True
log_to_wandb: bool = True
metrics_log_dir: str | None = None
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
level: str | int = logging.INFO
wandb_dir_in_experiment_dir: bool = False
Expand Down Expand Up @@ -143,7 +146,10 @@ def _configure_wandb(

# must ensure wandb.configure is called before wandb.init
wandb = WandB.get_instance()
wandb.configure(log_to_wandb=self.log_to_wandb)
wandb.configure(
log_to_wandb=self.log_to_wandb,
metrics_log_dir=self.metrics_log_dir,
)
notes = _get_wandb_notes(_get_beaker_id())
wandb.init(
config=config_copy,
Expand Down
211 changes: 211 additions & 0 deletions fme/core/test_disk_metric_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import json
import math
import os

import pytest

from fme.core.disk_metric_logger import METRICS_FILENAME, DiskMetricLogger, read_metrics


@pytest.fixture
def log_dir(tmp_path):
    """Return the path of a metrics directory that does not exist yet."""
    return os.path.join(str(tmp_path), "metrics")


def test_creates_directory(log_dir):
    """Constructing the logger creates the target directory."""
    assert not os.path.exists(log_dir)
    writer = DiskMetricLogger(log_dir)
    try:
        assert os.path.isdir(log_dir)
    finally:
        writer.close()


def test_log_writes_jsonl(log_dir):
    """Each log call produces one JSON record with step and metrics."""
    writer = DiskMetricLogger(log_dir)
    for step, (loss, lr) in enumerate([(0.5, 1e-3), (0.3, 1e-4)]):
        writer.log({"loss": loss, "lr": lr}, step=step)
    writer.close()

    assert read_metrics(log_dir) == [
        {"step": 0, "loss": 0.5, "lr": 1e-3},
        {"step": 1, "loss": 0.3, "lr": 1e-4},
    ]


def test_log_flushes_each_line(log_dir):
    """Data is flushed to disk on every log call, before close."""
    writer = DiskMetricLogger(log_dir)
    writer.log({"loss": 0.5}, step=0)
    # Record must already be readable on disk, prior to close().
    assert len(read_metrics(log_dir)) == 1
    writer.close()


def test_resume_skips_steps_at_or_below_high_water_mark(log_dir):
    """On resume, re-logged old steps are ignored and new steps appended."""
    first = DiskMetricLogger(log_dir)
    for step, loss in [(0, 0.5), (1, 0.4), (2, 0.3)]:
        first.log({"loss": loss}, step=step)
    first.close()

    # Simulate resume: replay steps 1 and 2 (skipped), then advance to 3.
    second = DiskMetricLogger(log_dir)
    second.log({"loss": 0.45}, step=1)
    second.log({"loss": 0.35}, step=2)
    second.log({"loss": 0.2}, step=3)
    second.close()

    records = read_metrics(log_dir)
    assert [r["step"] for r in records] == [0, 1, 2, 3]
    # Replayed steps retain their original values.
    assert records[1]["loss"] == 0.4
    assert records[2]["loss"] == 0.3
    assert records[3] == {"step": 3, "loss": 0.2}


def test_resume_catches_up_then_continues(log_dir):
    """Steps exactly at the high-water mark are skipped; one above is written."""
    before = DiskMetricLogger(log_dir)
    before.log({"a": 1}, step=5)
    before.close()

    after = DiskMetricLogger(log_dir)
    after.log({"a": 2}, step=5)  # equal to high-water mark -> dropped
    after.log({"a": 3}, step=6)  # above it -> kept
    after.close()

    assert read_metrics(log_dir) == [{"step": 5, "a": 1}, {"step": 6, "a": 3}]


def test_non_scalar_values_are_skipped(log_dir):
    """Non-serializable values are dropped; serializable ones are kept."""

    class Opaque:
        pass

    writer = DiskMetricLogger(log_dir)
    writer.log({"loss": 0.5, "image": Opaque(), "count": 10}, step=0)
    writer.close()

    assert read_metrics(log_dir) == [{"step": 0, "loss": 0.5, "count": 10}]


def test_all_non_scalar_skips_entire_line(log_dir):
    """If all values are non-serializable, no line is written."""

    class Opaque:
        pass

    writer = DiskMetricLogger(log_dir)
    writer.log({"image": Opaque()}, step=0)
    writer.close()

    assert read_metrics(log_dir) == []


def test_empty_data_skips_line(log_dir):
    """Logging an empty dict writes nothing to the file."""
    writer = DiskMetricLogger(log_dir)
    writer.log({}, step=0)
    writer.close()

    assert read_metrics(log_dir) == []


def test_string_values_are_logged(log_dir):
    """String metric values round-trip through the JSONL file."""
    writer = DiskMetricLogger(log_dir)
    writer.log({"phase": "train", "loss": 0.5}, step=0)
    writer.close()

    (record,) = read_metrics(log_dir)
    assert record == {"step": 0, "phase": "train", "loss": 0.5}


def test_bool_values_are_logged(log_dir):
    """Boolean metric values round-trip through the JSONL file."""
    writer = DiskMetricLogger(log_dir)
    writer.log({"converged": True}, step=0)
    writer.close()

    (record,) = read_metrics(log_dir)
    assert record == {"step": 0, "converged": True}


def test_no_existing_file(log_dir):
    """Logger works when the directory exists but no metrics file does."""
    os.makedirs(log_dir, exist_ok=True)

    writer = DiskMetricLogger(log_dir)
    writer.log({"x": 1}, step=0)
    writer.close()

    assert len(read_metrics(log_dir)) == 1


def test_corrupt_line_is_skipped(log_dir):
    """A corrupt line in the file doesn't prevent reading or resuming."""
    os.makedirs(log_dir, exist_ok=True)
    lines = [
        json.dumps({"step": 0, "loss": 0.5}),
        "NOT VALID JSON",
        json.dumps({"step": 2, "loss": 0.3}),
    ]
    with open(os.path.join(log_dir, METRICS_FILENAME), "w") as f:
        f.write("\n".join(lines) + "\n")

    # High-water mark should still be 2 despite the corrupt middle line.
    writer = DiskMetricLogger(log_dir)
    writer.log({"loss": 0.1}, step=2)  # at high-water mark -> dropped
    writer.log({"loss": 0.05}, step=3)  # above it -> kept
    writer.close()

    # read_metrics likewise skips the corrupt line.
    assert [r["step"] for r in read_metrics(log_dir)] == [0, 2, 3]


def test_read_metrics_empty_directory(tmp_path):
    """read_metrics yields an empty list when no metrics file exists."""
    records = read_metrics(str(tmp_path))
    assert records == []


def test_multiple_logs_same_step(log_dir):
    """Multiple log calls at the same step each produce a line."""
    writer = DiskMetricLogger(log_dir)
    writer.log({"loss": 0.5}, step=0)
    writer.log({"lr": 1e-3}, step=0)
    writer.close()

    assert read_metrics(log_dir) == [
        {"step": 0, "loss": 0.5},
        {"step": 0, "lr": 1e-3},
    ]


def test_nan_and_inf_values_are_logged(log_dir):
    """NaN and Inf are valid floats and get logged.

    Note: Python's json.dumps produces non-standard NaN/Infinity tokens.
    Use read_metrics (which uses json.loads) to round-trip these values
    rather than strict JSON parsers like jq.
    """
    writer = DiskMetricLogger(log_dir)
    writer.log({"nan_val": float("nan"), "inf_val": float("inf")}, step=0)
    writer.close()

    (record,) = read_metrics(log_dir)
    assert math.isnan(record["nan_val"])
    assert math.isinf(record["inf_val"]) and record["inf_val"] > 0
Loading
Loading