Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
# nor does it submit to any jurisdiction.

import logging
import re
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import omegaconf as oc
import pandas as pd
import xarray as xr
from tqdm import tqdm

Expand Down Expand Up @@ -82,7 +84,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
self.eval_cfg = eval_cfg
self.run_id = run_id
self.private_paths = private_paths

self.streams = eval_cfg.streams.keys()

# If results_base_dir and model_base_dir are not provided, default paths are used
Expand Down Expand Up @@ -269,6 +270,89 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
)


class CsvReader(Reader):
"""
Reader class to read evaluation data from CSV files and convert to xarray DataArray.
"""

def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
"""
Initialize the CsvReader from a single CSV file.

The file at ``eval_cfg["csv_path"]`` is loaded with its first column as the
index. Index labels become the channel names and the leading integer of each
column label becomes a forecast step.

Parameters
----------
eval_cfg : dict
config with plotting and evaluation options for that run id
(read both via ``.get()`` and via attribute access, e.g. ``eval_cfg.streams``)
run_id : str
run id of the model
private_paths : dict | None
private paths for the supported HPC
"""

super().__init__(eval_cfg, run_id, private_paths)
self.csv_path = eval_cfg.get("csv_path")
assert self.csv_path is not None, "CSV path must be provided in the config."

# The first CSV column is used as the index; its labels are read back as channels below.
self.data = pd.read_csv(self.csv_path, index_col=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing I would do is cast all the values to np.float32 (or float). pandas tries to be very clever and would for example use int32 if the data allows. I am not sure if xarray can deal with that later.


self.data = self.rename_channels()
# The base directory for metrics is the folder that contains the CSV file.
self.metrics_base_dir = Path(self.csv_path).parent
# for backward compatibility allow metric_dir to be specified in the run config
self.metrics_dir = Path(
self.eval_cfg.get(
"metrics_dir", self.metrics_base_dir / self.run_id / "evaluation"
)
)

assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream."
self.stream = list(eval_cfg.streams.keys())[0]
self.channels = self.data.index.tolist()
self.samples = [0]
# Only the first whitespace-separated token of each column label is parsed, as an int.
self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
self.npoints_per_sample = [0]
self.epoch = eval_cfg.get("epoch", 0)
self.metric = eval_cfg.get("metric")
self.region = eval_cfg.get("region")

def rename_channels(self) -> pd.DataFrame:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you return a pd.DataFrame

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, personally, I would write this as a little helper function:

class CSVReader:
    ...
    pd_data =  pd.read_csv(self.csv_path, index_col=0)
    self.data = _rename_channels(pd_data)

def _rename_channels(data) -> pd.DataFrame:
    # No need for self.data here

"""
Rename channel names to include underscore between letters and digits.
E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, why do we need to do that renaming? I trust you it was necessary, just put a line in the docstring


The method takes no arguments. It works on the index labels of ``self.data``
and reassigns ``self.data`` each time a label is renamed. Labels that start
with a digit are skipped.

NOTE(review): the reason this renaming is required is not stated in this
method - please document it here.

Returns
-------
pd.DataFrame
``self.data`` with the renamed index labels.
"""
for name in list(self.data.index):
# If it starts with digits (surface vars like 2t, 10ff) → leave unchanged
if re.match(r"^\d", name):
continue

# Otherwise, insert underscore between letters and digits
self.data = self.data.rename(
index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}
)

return self.data

def get_samples(self) -> set[int]:
    """Return the entries of ``self.samples`` as a set (duplicates collapse)."""
    # Placeholder implementation
    return {*self.samples}

def get_forecast_steps(self) -> set[int]:
    """Return the entries of ``self.forecast_steps`` as a set (duplicates collapse)."""
    # Placeholder implementation
    unique_steps = set()
    unique_steps.update(self.forecast_steps)
    return unique_steps

# TODO: get this from config
def get_channels(self, stream: str | None = None) -> list[str]:
assert stream == self.stream, "streams do not match in CSVReader."
return list(self.channels) # Placeholder implementation


class WeatherGenReader(Reader):
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
"""Data reader class for WeatherGenerator model outputs stored in Zarr format."""
Expand Down
16 changes: 12 additions & 4 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from omegaconf import DictConfig, OmegaConf

from weathergen.common.config import _REPO_ROOT
from weathergen.evaluate.io_reader import WeatherGenReader
from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader
from weathergen.evaluate.utils import (
calc_scores_per_stream,
metric_list_to_json,
plot_data,
plot_summary,
retrieve_metric_from_json,
retrieve_metric_from_file,
)

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,7 +77,15 @@ def evaluate_from_config(cfg):
for run_id, run in runs.items():
_logger.info(f"RUN {run_id}: Getting data...")

reader = WeatherGenReader(run, run_id, private_paths)
type = run.get("type", "zarr")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for putting a sensible default!

if type == "zarr":
reader = WeatherGenReader(run, run_id, private_paths)
elif type == "csv":
reader = CsvReader(run, run_id, private_paths)
else:
raise ValueError(
f"Unknown run type {type} for run {run_id}. Supported: zarr, csv."
)

for stream in reader.streams:
_logger.info(f"RUN {run_id}: Processing stream {stream}...")
Expand All @@ -96,7 +104,7 @@ def evaluate_from_config(cfg):

for metric in metrics:
try:
metric_data = retrieve_metric_from_json(
metric_data = retrieve_metric_from_file(
reader,
stream,
region,
Expand Down
102 changes: 66 additions & 36 deletions packages/evaluate/src/weathergen/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,9 @@ def calc_scores_per_stream(
)

available_data = reader.check_availability(stream, mode="evaluation")

output_data = reader.get_data(
stream,
region=region,
fsteps=available_data.fsteps,
samples=available_data.samples,
channels=available_data.channels,
return_counts=True,
)

da_preds = output_data.prediction
da_tars = output_data.target
points_per_sample = output_data.points_per_sample

# get coordinate information from retrieved data
fsteps = [int(k) for k in da_tars.keys()]

first_da = list(da_preds.values())[0]

# TODO: improve the way we handle samples.
samples = list(np.atleast_1d(np.unique(first_da.sample.values)))
channels = list(np.atleast_1d(first_da.channel.values))

metric_list = []
channels = available_data.channels
samples = available_data.samples
fsteps = available_data.fsteps

metric_stream = xr.DataArray(
np.full(
Expand All @@ -90,6 +69,22 @@ def calc_scores_per_stream(
},
)

output_data = reader.get_data(
stream,
region=region,
fsteps=fsteps,
samples=samples,
channels=channels,
return_counts=True,
)

da_preds = output_data.prediction
da_tars = output_data.target
points_per_sample = output_data.points_per_sample

# TODO: improve the way we handle samples.
metric_list = []

for (fstep, tars), (_, preds) in zip(
da_tars.items(), da_preds.items(), strict=False
):
Expand Down Expand Up @@ -337,9 +332,9 @@ def metric_list_to_json(
)


def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str):
"""
Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.
Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv).

Parameters
----------
Expand All @@ -357,18 +352,53 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric:
xr.DataArray
The metric DataArray.
"""
score_path = (
Path(reader.metrics_dir)
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")
if hasattr(reader, "data") and reader.data is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasattr is not a good habit because it is really hard for humans and type checkers to figure out if a Python object has an attribute. Here is what you can do, which then vscode can rename/check for you:

class Reader:

    data: pd.DataFrame | None  # Data attributes (if specified)

    def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
        ...
        self.data = None
        ...

# No change to WG reader or CSVReader

# Now you can directly use:
   if reader.data is not None:

available_data = reader.check_availability(stream, mode="evaluation")

# empty DataArray with NaNs
data = np.full(
(
len(available_data.samples),
len(available_data.fsteps),
len(available_data.channels),
1,
),
np.nan,
)
# fill it only for matching metric
if (
metric == reader.metric
and region == reader.region
and stream == reader.stream
):
data = reader.data.values[np.newaxis, :, :, np.newaxis].T
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style note: you could have

else:
        data = np.full(
            (
                len(available_data.samples),
                len(available_data.fsteps),
                len(available_data.channels),
                1,
            ),
            np.nan,
        )


da = xr.DataArray(
data.astype(np.float32),
dims=("sample", "forecast_step", "channel", "metric"),
coords={
"sample": available_data.samples,
"forecast_step": available_data.fsteps,
"channel": available_data.channels,
"metric": [metric],
},
attrs={"npoints_per_sample": reader.npoints_per_sample},
)

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
return da
else:
raise FileNotFoundError(f"File {score_path} not found in the archive.")
score_path = (
Path(reader.metrics_dir)
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
else:
raise FileNotFoundError(f"File {score_path} not found in the archive.")


def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
Expand Down