diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 66eafe741..cedb2dfc9 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -8,11 +8,13 @@ # nor does it submit to any jurisdiction. import logging +import re from dataclasses import dataclass from pathlib import Path import numpy as np import omegaconf as oc +import pandas as pd import xarray as xr from tqdm import tqdm @@ -82,7 +84,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.eval_cfg = eval_cfg self.run_id = run_id self.private_paths = private_paths - self.streams = eval_cfg.streams.keys() # If results_base_dir and model_base_dir are not provided, default paths are used @@ -269,6 +270,89 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili ) +class CsvReader(Reader): + """ + Reader class to read evaluation data from CSV files and convert to xarray DataArray. + """ + + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + """ + Initialize the CsvReader. + + Parameters + ---------- + eval_cfg : dir + config with plotting and evaluation options for that run id + run_id : str + run id of the model + private_paths: lists + list of private paths for the supported HPC + """ + + super().__init__(eval_cfg, run_id, private_paths) + self.csv_path = eval_cfg.get("csv_path") + assert self.csv_path is not None, "CSV path must be provided in the config." + + self.data = pd.read_csv(self.csv_path, index_col=0) + + self.data = self.rename_channels() + self.metrics_base_dir = Path(self.csv_path).parent + # for backward compatibility allow metric_dir to be specified in the run config + self.metrics_dir = Path( + self.eval_cfg.get( + "metrics_dir", self.metrics_base_dir / self.run_id / "evaluation" + ) + ) + + assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream." + self.stream = list(eval_cfg.streams.keys())[0] + self.channels = self.data.index.tolist() + self.samples = [0] + self.forecast_steps = [int(col.split()[0]) for col in self.data.columns] + self.npoints_per_sample = [0] + self.epoch = eval_cfg.get("epoch", 0) + self.metric = eval_cfg.get("metric") + self.region = eval_cfg.get("region") + + def rename_channels(self) -> str: + """ + Rename channel names to include underscore between letters and digits. + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' + + Parameters + ---------- + name : str + Original channel name. + + Returns + ------- + str + Renamed channel name. + """ + for name in list(self.data.index): + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged + if re.match(r"^\d", name): + continue + + # Otherwise, insert underscore between letters and digits + self.data = self.data.rename( + index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)} + ) + + return self.data + + def get_samples(self) -> set[int]: + return set(self.samples) # Placeholder implementation + + def get_forecast_steps(self) -> set[int]: + return set(self.forecast_steps) # Placeholder implementation + + # TODO: get this from config + def get_channels(self, stream: str | None = None) -> list[str]: + assert stream == self.stream, "streams do not match in CSVReader." + return list(self.channels) # Placeholder implementation + + class WeatherGenReader(Reader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index a500f96db..7a81cb487 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -17,13 +17,13 @@ from omegaconf import DictConfig, OmegaConf from weathergen.common.config import _REPO_ROOT -from weathergen.evaluate.io_reader import WeatherGenReader +from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader from weathergen.evaluate.utils import ( calc_scores_per_stream, metric_list_to_json, plot_data, plot_summary, - retrieve_metric_from_json, + retrieve_metric_from_file, ) _logger = logging.getLogger(__name__) @@ -77,7 +77,15 @@ def evaluate_from_config(cfg): for run_id, run in runs.items(): _logger.info(f"RUN {run_id}: Getting data...") - reader = WeatherGenReader(run, run_id, private_paths) + type = run.get("type", "zarr") + if type == "zarr": + reader = WeatherGenReader(run, run_id, private_paths) + elif type == "csv": + reader = CsvReader(run, run_id, private_paths) + else: + raise ValueError( + f"Unknown run type {type} for run {run_id}. Supported: zarr, csv." + ) for stream in reader.streams: _logger.info(f"RUN {run_id}: Processing stream {stream}...") @@ -96,7 +104,7 @@ def evaluate_from_config(cfg): for metric in metrics: try: - metric_data = retrieve_metric_from_json( + metric_data = retrieve_metric_from_file( reader, stream, region, diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 0ae0a1c69..fb300f193 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -52,30 +52,9 @@ def calc_scores_per_stream( ) available_data = reader.check_availability(stream, mode="evaluation") - - output_data = reader.get_data( - stream, - region=region, - fsteps=available_data.fsteps, - samples=available_data.samples, - channels=available_data.channels, - return_counts=True, - ) - - da_preds = output_data.prediction - da_tars = output_data.target - points_per_sample = output_data.points_per_sample - - # get coordinate information from retrieved data - fsteps = [int(k) for k in da_tars.keys()] - - first_da = list(da_preds.values())[0] - - # TODO: improve the way we handle samples. - samples = list(np.atleast_1d(np.unique(first_da.sample.values))) - channels = list(np.atleast_1d(first_da.channel.values)) - - metric_list = [] + channels = available_data.channels + samples = available_data.samples + fsteps = available_data.fsteps metric_stream = xr.DataArray( np.full( @@ -90,6 +69,22 @@ def calc_scores_per_stream( }, ) + output_data = reader.get_data( + stream, + region=region, + fsteps=fsteps, + samples=samples, + channels=channels, + return_counts=True, + ) + + da_preds = output_data.prediction + da_tars = output_data.target + points_per_sample = output_data.points_per_sample + + # TODO: improve the way we handle samples. + metric_list = [] + for (fstep, tars), (_, preds) in zip( da_tars.items(), da_preds.items(), strict=False ): @@ -337,9 +332,9 @@ def metric_list_to_json( ) -def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str): +def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str): """ - Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file. + Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv). Parameters ---------- @@ -357,18 +352,53 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: xr.DataArray The metric DataArray. """ - score_path = ( - Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") + if hasattr(reader, "data") and reader.data is not None: + available_data = reader.check_availability(stream, mode="evaluation") + + # empty DataArray with NaNs + data = np.full( + ( + len(available_data.samples), + len(available_data.fsteps), + len(available_data.channels), + 1, + ), + np.nan, + ) + # fill it only for matching metric + if ( + metric == reader.metric + and region == reader.region + and stream == reader.stream + ): + data = reader.data.values[np.newaxis, :, :, np.newaxis].T + + da = xr.DataArray( + data.astype(np.float32), + dims=("sample", "forecast_step", "channel", "metric"), + coords={ + "sample": available_data.samples, + "forecast_step": available_data.fsteps, + "channel": available_data.channels, + "metric": [metric], + }, + attrs={"npoints_per_sample": reader.npoints_per_sample}, + ) - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - return xr.DataArray.from_dict(data_dict) + return da else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") + score_path = ( + Path(reader.metrics_dir) + / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + return xr.DataArray.from_dict(data_dict) + else: + raise FileNotFoundError(f"File {score_path} not found in the archive.") def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):