-
Notifications
You must be signed in to change notification settings - Fork 38
[930][evaluation] implement CSVReader #932
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,13 @@ | |
# nor does it submit to any jurisdiction. | ||
|
||
import logging | ||
import re | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import omegaconf as oc | ||
import pandas as pd | ||
import xarray as xr | ||
from tqdm import tqdm | ||
|
||
|
@@ -82,7 +84,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non | |
self.eval_cfg = eval_cfg | ||
self.run_id = run_id | ||
self.private_paths = private_paths | ||
|
||
self.streams = eval_cfg.streams.keys() | ||
|
||
# If results_base_dir and model_base_dir are not provided, default paths are used | ||
|
@@ -269,6 +270,89 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili | |
) | ||
|
||
|
||
class CsvReader(Reader): | ||
""" | ||
Reader class to read evaluation data from CSV files and convert to xarray DataArray. | ||
""" | ||
|
||
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): | ||
""" | ||
Initialize the CsvReader. | ||
|
||
Parameters | ||
---------- | ||
eval_cfg : dir | ||
config with plotting and evaluation options for that run id | ||
run_id : str | ||
run id of the model | ||
private_paths: lists | ||
list of private paths for the supported HPC | ||
""" | ||
|
||
super().__init__(eval_cfg, run_id, private_paths) | ||
self.csv_path = eval_cfg.get("csv_path") | ||
assert self.csv_path is not None, "CSV path must be provided in the config." | ||
|
||
self.data = pd.read_csv(self.csv_path, index_col=0) | ||
|
||
self.data = self.rename_channels() | ||
self.metrics_base_dir = Path(self.csv_path).parent | ||
# for backward compatibility allow metric_dir to be specified in the run config | ||
self.metrics_dir = Path( | ||
self.eval_cfg.get( | ||
"metrics_dir", self.metrics_base_dir / self.run_id / "evaluation" | ||
) | ||
) | ||
|
||
assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream." | ||
self.stream = list(eval_cfg.streams.keys())[0] | ||
self.channels = self.data.index.tolist() | ||
self.samples = [0] | ||
self.forecast_steps = [int(col.split()[0]) for col in self.data.columns] | ||
self.npoints_per_sample = [0] | ||
self.epoch = eval_cfg.get("epoch", 0) | ||
self.metric = eval_cfg.get("metric") | ||
self.region = eval_cfg.get("region") | ||
|
||
def rename_channels(self) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you return a pd.DataFrame There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, personally, I would write this as a little helper function: class CSVReader:
...
pd_data = pd.read_csv(self.csv_path, index_col=0)
self.data = _rename_channels(pd_data)
def _rename_channels(data) -> pd.DataFrame:
# No need for self.data here |
||
""" | ||
Rename channel names to include underscore between letters and digits. | ||
E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, why do we need to do that renaming? I trust you it was necessary, just put a line in the docstring |
||
|
||
Parameters | ||
---------- | ||
name : str | ||
Original channel name. | ||
|
||
Returns | ||
------- | ||
str | ||
Renamed channel name. | ||
""" | ||
for name in list(self.data.index): | ||
# If it starts with digits (surface vars like 2t, 10ff) → leave unchanged | ||
if re.match(r"^\d", name): | ||
continue | ||
|
||
# Otherwise, insert underscore between letters and digits | ||
self.data = self.data.rename( | ||
index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)} | ||
) | ||
|
||
return self.data | ||
|
||
def get_samples(self) -> set[int]: | ||
return set(self.samples) # Placeholder implementation | ||
|
||
def get_forecast_steps(self) -> set[int]: | ||
return set(self.forecast_steps) # Placeholder implementation | ||
|
||
# TODO: get this from config | ||
def get_channels(self, stream: str | None = None) -> list[str]: | ||
assert stream == self.stream, "streams do not match in CSVReader." | ||
return list(self.channels) # Placeholder implementation | ||
|
||
|
||
class WeatherGenReader(Reader): | ||
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): | ||
"""Data reader class for WeatherGenerator model outputs stored in Zarr format.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,13 +17,13 @@ | |
from omegaconf import DictConfig, OmegaConf | ||
|
||
from weathergen.common.config import _REPO_ROOT | ||
from weathergen.evaluate.io_reader import WeatherGenReader | ||
from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader | ||
from weathergen.evaluate.utils import ( | ||
calc_scores_per_stream, | ||
metric_list_to_json, | ||
plot_data, | ||
plot_summary, | ||
retrieve_metric_from_json, | ||
retrieve_metric_from_file, | ||
) | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
@@ -77,7 +77,15 @@ def evaluate_from_config(cfg): | |
for run_id, run in runs.items(): | ||
_logger.info(f"RUN {run_id}: Getting data...") | ||
|
||
reader = WeatherGenReader(run, run_id, private_paths) | ||
type = run.get("type", "zarr") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thank you for putting a sensible default! |
||
if type == "zarr": | ||
reader = WeatherGenReader(run, run_id, private_paths) | ||
elif type == "csv": | ||
reader = CsvReader(run, run_id, private_paths) | ||
else: | ||
raise ValueError( | ||
f"Unknown run type {type} for run {run_id}. Supported: zarr, csv." | ||
) | ||
|
||
for stream in reader.streams: | ||
_logger.info(f"RUN {run_id}: Processing stream {stream}...") | ||
|
@@ -96,7 +104,7 @@ def evaluate_from_config(cfg): | |
|
||
for metric in metrics: | ||
try: | ||
metric_data = retrieve_metric_from_json( | ||
metric_data = retrieve_metric_from_file( | ||
reader, | ||
stream, | ||
region, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,30 +52,9 @@ def calc_scores_per_stream( | |
) | ||
|
||
available_data = reader.check_availability(stream, mode="evaluation") | ||
|
||
output_data = reader.get_data( | ||
stream, | ||
region=region, | ||
fsteps=available_data.fsteps, | ||
samples=available_data.samples, | ||
channels=available_data.channels, | ||
return_counts=True, | ||
) | ||
|
||
da_preds = output_data.prediction | ||
da_tars = output_data.target | ||
points_per_sample = output_data.points_per_sample | ||
|
||
# get coordinate information from retrieved data | ||
fsteps = [int(k) for k in da_tars.keys()] | ||
|
||
first_da = list(da_preds.values())[0] | ||
|
||
# TODO: improve the way we handle samples. | ||
samples = list(np.atleast_1d(np.unique(first_da.sample.values))) | ||
channels = list(np.atleast_1d(first_da.channel.values)) | ||
|
||
metric_list = [] | ||
channels = available_data.channels | ||
samples = available_data.samples | ||
fsteps = available_data.fsteps | ||
|
||
metric_stream = xr.DataArray( | ||
np.full( | ||
|
@@ -90,6 +69,22 @@ def calc_scores_per_stream( | |
}, | ||
) | ||
|
||
output_data = reader.get_data( | ||
stream, | ||
region=region, | ||
fsteps=fsteps, | ||
samples=samples, | ||
channels=channels, | ||
return_counts=True, | ||
) | ||
|
||
da_preds = output_data.prediction | ||
da_tars = output_data.target | ||
points_per_sample = output_data.points_per_sample | ||
|
||
# TODO: improve the way we handle samples. | ||
metric_list = [] | ||
|
||
for (fstep, tars), (_, preds) in zip( | ||
da_tars.items(), da_preds.items(), strict=False | ||
): | ||
|
@@ -337,9 +332,9 @@ def metric_list_to_json( | |
) | ||
|
||
|
||
def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str): | ||
def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str): | ||
""" | ||
Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file. | ||
Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv). | ||
|
||
Parameters | ||
---------- | ||
|
@@ -357,18 +352,53 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: | |
xr.DataArray | ||
The metric DataArray. | ||
""" | ||
score_path = ( | ||
Path(reader.metrics_dir) | ||
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" | ||
) | ||
_logger.debug(f"Looking for: {score_path}") | ||
if hasattr(reader, "data") and reader.data is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
class Reader:
data: pd.DataFrame | None # Data attributes (if specified)
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
...
self.data = None
...
# No change to WG reader or CSVReader
# Now you can directly use:
if reader.data is not None: |
||
available_data = reader.check_availability(stream, mode="evaluation") | ||
|
||
# empty DataArray with NaNs | ||
data = np.full( | ||
( | ||
len(available_data.samples), | ||
len(available_data.fsteps), | ||
len(available_data.channels), | ||
1, | ||
), | ||
np.nan, | ||
) | ||
# fill it only for matching metric | ||
if ( | ||
metric == reader.metric | ||
and region == reader.region | ||
and stream == reader.stream | ||
): | ||
data = reader.data.values[np.newaxis, :, :, np.newaxis].T | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style note: you colud have else:
data = np.full(
(
len(available_data.samples),
len(available_data.fsteps),
len(available_data.channels),
1,
),
np.nan,
) |
||
|
||
da = xr.DataArray( | ||
data.astype(np.float32), | ||
dims=("sample", "forecast_step", "channel", "metric"), | ||
coords={ | ||
"sample": available_data.samples, | ||
"forecast_step": available_data.fsteps, | ||
"channel": available_data.channels, | ||
"metric": [metric], | ||
}, | ||
attrs={"npoints_per_sample": reader.npoints_per_sample}, | ||
) | ||
|
||
if score_path.exists(): | ||
with open(score_path) as f: | ||
data_dict = json.load(f) | ||
return xr.DataArray.from_dict(data_dict) | ||
return da | ||
else: | ||
raise FileNotFoundError(f"File {score_path} not found in the archive.") | ||
score_path = ( | ||
Path(reader.metrics_dir) | ||
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" | ||
) | ||
_logger.debug(f"Looking for: {score_path}") | ||
|
||
if score_path.exists(): | ||
with open(score_path) as f: | ||
data_dict = json.load(f) | ||
return xr.DataArray.from_dict(data_dict) | ||
else: | ||
raise FileNotFoundError(f"File {score_path} not found in the archive.") | ||
|
||
|
||
def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing I would do is cast all the values to np.float32 (or float). pandas tries to be very clever and would for example use int32 if the data allows. I am not sure if xarray can deal with that later.