Skip to content

Commit 33c369f

Browse files
iluiseclessigtjhunterSavvasMelsophie-xhonneux
authored
[issue 1123] restore probabilistic scores (#1128)
* rebase * add ensemble * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * Fix spoofing and refactor handling of multiple source files (#1118) * Cleaning up spoofing and related code on data preprocessing for model * Fixed typo * Updated comments * Removed merge cells and implemented necessary adjustments * Fixed forecasting * Fixed missing handling of NaNs in coordinates and channel data * Minor clean up * Fix to removing/renaming variables * Changed funtion name to improve readability * Fixed bug with incorrect handling of multiple input datasources. * Addressed reviewer comments * resolve conflict * [1131] fixes circular dependencies (#1134) * fixes dependencies * cleanup * make the type checker not fail * cleanup * cleanup of type issues * Give option to plot only prediction maps (#1139) * add plot_preds_only feature * minor changes after comments * Tell FSDP2 about embedding engine forward functions (#1133) * Tell FSDP2 about embedding engine forward functions Note DO NOT add print functions in forward functions of the model, it will break with FSDP2 * Add comment * recover 'all' option (#1146) * Fixed problem in inferecne (#1145) * implement vrmse (#1147) * [1144] Extra fixes (#1148) * Fixed problem in inferecne * more fixes * fixes * lint * lint --------- Co-authored-by: Christian Lessig <[email protected]> * Jk/log grad norms/log grad norms (#1068) * Log gradient norms * Prototype for recording grad norms * Address review changes + hide behind feature flag * Final fixes including backward compatibility * Ruff * More ruff stuff * forecast config with small decoder * fixed uv.lock * test gradient logging on mutli gpus * update uv.lock to latest develop version * revert to default confit * add comment on FSDP2 specifics * move plot grad script to private repo * rm seaborn from pyproject * updating terminal and metrics loggin, add get_tensor_item fct * check for DTensor instead of world size * revert forecast fct, fix in separate PR * rename grad_norm log names to exclude from MLFlow * add log_grad_norms to default config --------- Co-authored-by: sophiex <[email protected]> * Add forecast and observation activity (#1126) * Add calculation methods for forecast and observation activity metrics in Scores class * Add new calculation methods for forecast activity metrics in Scores class * ruff * fix func name * Rename observation activity calculation method to target activity in Scores class * typo * refactor to common calc_act function for activity * fix cases * have calc_tact and calc_fact that use _calc_act for maintainability * fix small thing in style --------- Co-authored-by: iluise <[email protected]> * hotfix: use correct methot `create` instead of `construct` (#1090) * restore develop * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * update utils * packages/evaluate/src/weathergen/evaluate/score.py * lint * removing duplication --------- Co-authored-by: Christian Lessig <[email protected]> Co-authored-by: Timothy Hunter <[email protected]> Co-authored-by: Savvas Melidonis <[email protected]> Co-authored-by: Sophie X <[email protected]> Co-authored-by: Julius Polz <[email protected]> Co-authored-by: Julian Kuehnert <[email protected]> Co-authored-by: Simon Grasse <[email protected]>
1 parent 06188d9 commit 33c369f

File tree

4 files changed

+344
-252
lines changed

4 files changed

+344
-252
lines changed

packages/evaluate/src/weathergen/evaluate/io_reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ def check_availability(
164164
fsteps = requested_data.fsteps
165165
samples = requested_data.samples
166166
ensemble = requested_data.ensemble
167-
168167
requested = {
169168
"channel": set(channels) if channels is not None else None,
170169
"fstep": set(fsteps) if fsteps is not None else None,
@@ -478,6 +477,13 @@ def get_data(
478477
_logger.debug(f"Selecting ensemble members {ensemble}.")
479478
pred = pred.sel(ens=ensemble)
480479

480+
if ensemble == ["mean"]:
481+
_logger.debug("Averaging over ensemble members.")
482+
pred = pred.mean("ens", keepdims=True)
483+
else:
484+
_logger.debug(f"Selecting ensemble members {ensemble}.")
485+
pred = pred.sel(ens=ensemble)
486+
481487
da_tars_fs.append(target.squeeze())
482488
da_preds_fs.append(pred.squeeze())
483489
pps.append(npoints)

packages/evaluate/src/weathergen/evaluate/plotter.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray:
138138
-------
139139
xarray DataArray with selected data.
140140
"""
141-
142141
for key, value in selection.items():
143142
if key in da.coords and key not in da.dims:
144143
# Coordinate like 'sample' aligned to another dim
@@ -710,6 +709,77 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None:
710709
f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting."
711710
)
712711

712+
def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None:
713+
"""
714+
Plot ensemble spread for a data array.
715+
716+
Parameters
717+
----------
718+
data: xr.xArray
719+
DataArray to be plotted
720+
x_dim: str
721+
Dimension to be used for the x-axis.
722+
label: str
723+
Label for the dataset
724+
Returns
725+
-------
726+
None
727+
"""
728+
averaged = data.mean(dim=[dim for dim in data.dims if dim != x_dim], skipna=True).sortby(
729+
x_dim
730+
)
731+
732+
lines = plt.plot(
733+
averaged[x_dim],
734+
averaged.values,
735+
label=label,
736+
marker="o",
737+
linestyle="-",
738+
)
739+
line = lines[0]
740+
color = line.get_color()
741+
742+
ens = data.mean(
743+
dim=[dim for dim in data.dims if dim not in [x_dim, "ens"]], skipna=True
744+
).sortby(x_dim)
745+
746+
if self.plot_ensemble == "std":
747+
std_dev = ens.std(dim="ens", skipna=True).sortby(x_dim)
748+
plt.fill_between(
749+
averaged[x_dim],
750+
(averaged - std_dev).values,
751+
(averaged + std_dev).values,
752+
label=f"{label} - std dev",
753+
color=color,
754+
alpha=0.2,
755+
)
756+
757+
elif self.plot_ensemble == "minmax":
758+
ens_min = ens.min(dim="ens", skipna=True).sortby(x_dim)
759+
ens_max = ens.max(dim="ens", skipna=True).sortby(x_dim)
760+
761+
plt.fill_between(
762+
averaged[x_dim],
763+
ens_min.values,
764+
ens_max.values,
765+
label=f"{label} - min max",
766+
color=color,
767+
alpha=0.2,
768+
)
769+
770+
elif self.plot_ensemble == "members":
771+
for j in range(ens.ens.size):
772+
plt.plot(
773+
ens[x_dim],
774+
ens.isel(ens=j).values,
775+
color=color,
776+
alpha=0.2,
777+
)
778+
else:
779+
_logger.warning(
780+
f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting."
781+
)
782+
713783
def plot(
714784
self,
715785
data: xr.DataArray | list,
@@ -737,7 +807,6 @@ def plot(
737807
Name of the dimension to be used for the y-axis.
738808
print_summary:
739809
If True, print a summary of the values from the graph.
740-
741810
Returns
742811
-------
743812
None

0 commit comments

Comments
 (0)