diff --git a/fme/core/benchmark/run.py b/fme/core/benchmark/run.py index f1fd85b4b..4e2328aed 100644 --- a/fme/core/benchmark/run.py +++ b/fme/core/benchmark/run.py @@ -144,7 +144,7 @@ def get_filename(name, extension) -> pathlib.Path: entity=entity, name=wandb_name, ) - wandb.log(wandb_logs, commit=True) + wandb.log(wandb_logs, step=0, commit=True) return 0 diff --git a/fme/core/disk_metric_logger.py b/fme/core/disk_metric_logger.py new file mode 100644 index 000000000..6a22ae508 --- /dev/null +++ b/fme/core/disk_metric_logger.py @@ -0,0 +1,114 @@ +import io +import json +import logging +import os +from typing import Any + +METRICS_FILENAME = "metrics.jsonl" + + +class DiskMetricLogger: + """Logs scalar metrics to a JSONL file on disk. + + Each line in the file is a JSON object with a "step" key and scalar metric + key-value pairs. On construction, any existing file is read to determine the + high-water mark (maximum step already logged). Subsequent calls to ``log`` + with a step at or below that mark are silently skipped, which makes this + logger safe for job resumption. + + Non-JSON-serializable values (e.g. images, tensors) are silently dropped. + """ + + def __init__(self, directory: str | os.PathLike): + os.makedirs(directory, exist_ok=True) + self._path = os.path.join(directory, METRICS_FILENAME) + self._high_water_mark: int | None = None + self._file: io.TextIOWrapper | None = None + self._read_high_water_mark() + self._file = open(self._path, "a") + + def _read_high_water_mark(self): + """Read the existing file (if any) to find the max step logged.""" + if not os.path.exists(self._path): + return + with open(self._path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + step = record.get("step") + if isinstance(step, int): + if self._high_water_mark is None or step > self._high_water_mark: + self._high_water_mark = step + + def log(self, data: dict[str, Any], step: int) -> None: + """Log scalar metrics for a given step. + + If ``step`` is at or below the high-water mark from a previous run, + the call is silently skipped. Non-serializable values are dropped. + """ + if self._high_water_mark is not None and step <= self._high_water_mark: + logging.warning( + "DiskMetricLogger: skipping log for step %d " + "(at or below high-water mark %d)", + step, + self._high_water_mark, + ) + return + scalars = _extract_serializable(data) + if not scalars: + return + record = {"step": step, **scalars} + assert self._file is not None + self._file.write(json.dumps(record) + "\n") + self._file.flush() + + def close(self): + if self._file is not None: + self._file.close() + self._file = None + + +def _extract_serializable(data: dict[str, Any]) -> dict[str, Any]: + """Return only JSON-serializable entries from *data*.""" + result: dict[str, Any] = {} + for key, value in data.items(): + if isinstance(value, int | float | bool): + result[key] = value + elif isinstance(value, str): + result[key] = value + else: + try: + json.dumps(value) + except (TypeError, ValueError, OverflowError): + logging.debug( + f"DiskMetricLogger: skipping non-serializable key '{key}'" + ) + else: + result[key] = value + return result + + +def read_metrics(directory: str | os.PathLike) -> list[dict[str, Any]]: + """Read all metric records from a metrics JSONL file. + + Returns a list of dicts, one per logged line, in file order. + """ + path = os.path.join(directory, METRICS_FILENAME) + records: list[dict[str, Any]] = [] + if not os.path.exists(path): + return records + with open(path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + records.append(json.loads(line)) + except json.JSONDecodeError: + continue + return records diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py index cedba8511..20dcc0d73 100644 --- a/fme/core/logging_utils.py +++ b/fme/core/logging_utils.py @@ -38,6 +38,8 @@ class LoggingConfig: log_to_screen: Whether to log to the screen. log_to_file: Whether to log to a file. log_to_wandb: Whether to log to Weights & Biases. + metrics_log_dir: Directory to write scalar metrics to disk as JSONL. + If None, disk metric logging is disabled. log_format: Format of the log messages. level: Sets the logging level. wandb_dir_in_experiment_dir: Whether to create the wandb_dir in the @@ -49,6 +51,7 @@ class LoggingConfig: log_to_screen: bool = True log_to_file: bool = True log_to_wandb: bool = True + metrics_log_dir: str | None = None log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" level: str | int = logging.INFO wandb_dir_in_experiment_dir: bool = False @@ -143,7 +146,10 @@ def _configure_wandb( # must ensure wandb.configure is called before wandb.init wandb = WandB.get_instance() - wandb.configure(log_to_wandb=self.log_to_wandb) + wandb.configure( + log_to_wandb=self.log_to_wandb, + metrics_log_dir=self.metrics_log_dir, + ) notes = _get_wandb_notes(_get_beaker_id()) wandb.init( config=config_copy, diff --git a/fme/core/test_disk_metric_logger.py b/fme/core/test_disk_metric_logger.py new file mode 100644 index 000000000..9ff34c12b --- /dev/null +++ b/fme/core/test_disk_metric_logger.py @@ -0,0 +1,211 @@ +import json +import math +import os + +import pytest + +from fme.core.disk_metric_logger import METRICS_FILENAME, DiskMetricLogger, read_metrics + + +@pytest.fixture +def log_dir(tmp_path): + return str(tmp_path / "metrics") + + +def test_creates_directory(log_dir): + assert not os.path.exists(log_dir) + logger = DiskMetricLogger(log_dir) + assert os.path.isdir(log_dir) + logger.close() + + +def test_log_writes_jsonl(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({"loss": 0.5, "lr": 1e-3}, step=0) + logger.log({"loss": 0.3, "lr": 1e-4}, step=1) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 2 + assert records[0] == {"step": 0, "loss": 0.5, "lr": 1e-3} + assert records[1] == {"step": 1, "loss": 0.3, "lr": 1e-4} + + +def test_log_flushes_each_line(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({"loss": 0.5}, step=0) + # Read before close — data should be on disk due to flush + records = read_metrics(log_dir) + assert len(records) == 1 + logger.close() + + +def test_resume_skips_steps_at_or_below_high_water_mark(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({"loss": 0.5}, step=0) + logger.log({"loss": 0.4}, step=1) + logger.log({"loss": 0.3}, step=2) + logger.close() + + # Simulate resume: re-create logger, re-log steps 1 and 2, + # then continue with step 3 + logger = DiskMetricLogger(log_dir) + logger.log({"loss": 0.45}, step=1) # skipped + logger.log({"loss": 0.35}, step=2) # skipped + logger.log({"loss": 0.2}, step=3) # written + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 4 + # Original steps preserved + assert records[0]["step"] == 0 + assert records[1]["step"] == 1 + assert records[1]["loss"] == 0.4 # original, not overwritten + assert records[2]["step"] == 2 + assert records[2]["loss"] == 0.3 # original, not overwritten + # New step appended + assert records[3] == {"step": 3, "loss": 0.2} + + +def test_resume_catches_up_then_continues(log_dir): + """Steps exactly at the high-water mark are skipped; one above is written.""" + logger = DiskMetricLogger(log_dir) + logger.log({"a": 1}, step=5) + logger.close() + + logger = DiskMetricLogger(log_dir) + logger.log({"a": 2}, step=5) # skipped (== high water mark) + logger.log({"a": 3}, step=6) # written + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 2 + assert records[0] == {"step": 5, "a": 1} + assert records[1] == {"step": 6, "a": 3} + + +def test_non_scalar_values_are_skipped(log_dir): + logger = DiskMetricLogger(log_dir) + + class NotSerializable: + pass + + logger.log( + {"loss": 0.5, "image": NotSerializable(), "count": 10}, + step=0, + ) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 1 + assert records[0] == {"step": 0, "loss": 0.5, "count": 10} + + +def test_all_non_scalar_skips_entire_line(log_dir): + """If all values are non-serializable, no line is written.""" + + class NotSerializable: + pass + + logger = DiskMetricLogger(log_dir) + logger.log({"image": NotSerializable()}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 0 + + +def test_empty_data_skips_line(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 0 + + +def test_string_values_are_logged(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({"phase": "train", "loss": 0.5}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert records[0] == {"step": 0, "phase": "train", "loss": 0.5} + + +def test_bool_values_are_logged(log_dir): + logger = DiskMetricLogger(log_dir) + logger.log({"converged": True}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert records[0] == {"step": 0, "converged": True} + + +def test_no_existing_file(log_dir): + """Logger works when directory exists but no metrics file.""" + os.makedirs(log_dir, exist_ok=True) + logger = DiskMetricLogger(log_dir) + logger.log({"x": 1}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 1 + + +def test_corrupt_line_is_skipped(log_dir): + """A corrupt line in the file doesn't prevent reading or resuming.""" + os.makedirs(log_dir, exist_ok=True) + path = os.path.join(log_dir, METRICS_FILENAME) + with open(path, "w") as f: + f.write(json.dumps({"step": 0, "loss": 0.5}) + "\n") + f.write("NOT VALID JSON\n") + f.write(json.dumps({"step": 2, "loss": 0.3}) + "\n") + + logger = DiskMetricLogger(log_dir) + # High water mark should be 2 despite the corrupt line + logger.log({"loss": 0.1}, step=2) # skipped + logger.log({"loss": 0.05}, step=3) # written + logger.close() + + records = read_metrics(log_dir) + # read_metrics also skips corrupt lines + assert len(records) == 3 + assert records[0]["step"] == 0 + assert records[1]["step"] == 2 + assert records[2]["step"] == 3 + + +def test_read_metrics_empty_directory(tmp_path): + """read_metrics returns empty list when no file exists.""" + assert read_metrics(str(tmp_path)) == [] + + +def test_multiple_logs_same_step(log_dir): + """Multiple log calls at the same step each produce a line.""" + logger = DiskMetricLogger(log_dir) + logger.log({"loss": 0.5}, step=0) + logger.log({"lr": 1e-3}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 2 + assert records[0] == {"step": 0, "loss": 0.5} + assert records[1] == {"step": 0, "lr": 1e-3} + + +def test_nan_and_inf_values_are_logged(log_dir): + """NaN and Inf are valid floats and get logged. + + Note: Python's json.dumps produces non-standard NaN/Infinity tokens. + Use read_metrics (which uses json.loads) to round-trip these values + rather than strict JSON parsers like jq. + """ + logger = DiskMetricLogger(log_dir) + logger.log({"nan_val": float("nan"), "inf_val": float("inf")}, step=0) + logger.close() + + records = read_metrics(log_dir) + assert len(records) == 1 + assert math.isnan(records[0]["nan_val"]) + assert records[0]["inf_val"] == float("inf") diff --git a/fme/core/test_wandb.py b/fme/core/test_wandb.py index 6de42d911..d25238fd4 100644 --- a/fme/core/test_wandb.py +++ b/fme/core/test_wandb.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from fme.core.disk_metric_logger import read_metrics +from fme.core.testing.wandb import mock_wandb from fme.core.wandb import DirectInitializationError, Image, WandB @@ -13,3 +15,54 @@ def test_image_is_image_instance(): def test_wandb_direct_initialization_raises(): with pytest.raises(DirectInitializationError): Image(np.zeros((10, 10))) + + +class TestDiskLoggingIntegration: + def test_metrics_written_to_disk_via_mock_wandb(self, tmp_path): + log_dir = str(tmp_path / "metrics") + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True, metrics_log_dir=log_dir) + wandb.log({"loss": 0.5, "lr": 1e-3}, step=0) + wandb.log({"loss": 0.3, "lr": 1e-4}, step=1) + + records = read_metrics(log_dir) + assert len(records) == 2 + assert records[0] == {"step": 0, "loss": 0.5, "lr": 1e-3} + assert records[1] == {"step": 1, "loss": 0.3, "lr": 1e-4} + + def test_no_disk_logging_when_dir_is_none(self, tmp_path): + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True, metrics_log_dir=None) + wandb.log({"loss": 0.5}, step=0) + assert wandb._disk_logger is None + + def test_disk_logging_resume_skips_old_steps(self, tmp_path): + log_dir = str(tmp_path / "metrics") + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True, metrics_log_dir=log_dir) + wandb.log({"loss": 0.5}, step=0) + wandb.log({"loss": 0.3}, step=1) + + # Simulate resume + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True, metrics_log_dir=log_dir) + wandb.log({"loss": 0.45}, step=0) # skipped + wandb.log({"loss": 0.35}, step=1) # skipped + wandb.log({"loss": 0.2}, step=2) # written + + records = read_metrics(log_dir) + assert len(records) == 3 + assert records[0]["loss"] == 0.5 # original preserved + assert records[1]["loss"] == 0.3 # original preserved + assert records[2] == {"step": 2, "loss": 0.2} + + def test_disk_logging_independent_of_wandb_enabled(self, tmp_path): + """Disk logging works even when log_to_wandb is False.""" + log_dir = str(tmp_path / "metrics") + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=False, metrics_log_dir=log_dir) + wandb.log({"loss": 0.5}, step=0) + + records = read_metrics(log_dir) + assert len(records) == 1 + assert records[0] == {"step": 0, "loss": 0.5} diff --git a/fme/core/testing/wandb.py b/fme/core/testing/wandb.py index 3e4a21340..b42ff5783 100644 --- a/fme/core/testing/wandb.py +++ b/fme/core/testing/wandb.py @@ -6,6 +6,7 @@ from typing import Any, Literal from fme.core import wandb +from fme.core.disk_metric_logger import DiskMetricLogger from fme.core.distributed import Distributed @@ -16,11 +17,14 @@ def __init__(self): self._logs: dict[int, dict[str, Any]] = collections.defaultdict(dict) self._last_step = 0 self._id: str | None = None + self._disk_logger: DiskMetricLogger | None = None - def configure(self, log_to_wandb: bool): + def configure(self, log_to_wandb: bool, metrics_log_dir: str | None = None): dist = Distributed.get_instance() self._enabled = log_to_wandb and dist.is_root() self._configured = True + if metrics_log_dir is not None and dist.is_root(): + self._disk_logger = DiskMetricLogger(metrics_log_dir) def init( self, @@ -96,6 +100,8 @@ def log(self, data: Mapping[str, Any], step: int, sleep=None): # sleep arg is ignored since we don't want to sleep in tests if self._enabled: self._logs[step].update(data) + if self._disk_logger is not None: + self._disk_logger.log(dict(data), step=step) def get_logs(self) -> list[dict[str, Any]]: if len(self._logs) == 0: @@ -138,10 +144,13 @@ def mock_wandb(): the given fill_value, which can be checked for in tests. """ original = wandb.singleton - wandb.singleton = MockWandB() # type: ignore + mock = MockWandB() + wandb.singleton = mock # type: ignore try: - yield wandb.singleton + yield mock finally: + if mock._disk_logger is not None: + mock._disk_logger.close() wandb.singleton = original diff --git a/fme/core/wandb.py b/fme/core/wandb.py index bed5b6cac..a3169718f 100644 --- a/fme/core/wandb.py +++ b/fme/core/wandb.py @@ -7,6 +7,7 @@ import numpy as np import wandb +from fme.core.disk_metric_logger import DiskMetricLogger from fme.core.distributed import Distributed WANDB_RUN_ID_FILE = "wandb_run_id" @@ -111,11 +112,14 @@ def __init__(self): self._enabled = False self._configured = False self._id = None + self._disk_logger: DiskMetricLogger | None = None - def configure(self, log_to_wandb: bool): + def configure(self, log_to_wandb: bool, metrics_log_dir: str | None = None): dist = Distributed.get_instance() self._enabled = log_to_wandb and dist.is_root() self._configured = True + if metrics_log_dir is not None and dist.is_root(): + self._disk_logger = DiskMetricLogger(metrics_log_dir) def init( self, @@ -161,12 +165,18 @@ def watch(self, modules): wandb.watch(modules) def log( - self, data: Mapping[str, Any], step=None, sleep=None, commit: bool | None = None + self, + data: Mapping[str, Any], + step: int, + sleep: float | None = None, + commit: bool | None = None, ): if self._enabled: wandb.log(dict(data), step=step, commit=commit) if sleep is not None: time.sleep(sleep) + if self._disk_logger is not None: + self._disk_logger.log(dict(data), step=step) dist = Distributed.get_instance() dist.barrier()