Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 182 additions & 1 deletion fme/core/histogram.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import logging
import math
from collections import namedtuple
from collections.abc import Mapping
from typing import Literal
Expand All @@ -25,6 +26,37 @@ def _add_trailing_slash(s):
return s + "/"


def _decimal_places_in_percentile(reference_percentile: float) -> int:
"""Fractional digit count from ``repr`` of the configured percentile."""
s = repr(reference_percentile)
if "." not in s:
return 0
return len(s.split(".")[1])


def _format_percentile_for_metric_key(p: float, reference_percentile: float) -> str:
r"""
Format *p* (the percentile passed to tail routines) for wandb keys.

When *p* matches the configured reference (upper tail), use ``str(reference)``
so keys match e.g. ``99.0th-percentile``. For complementary tails
(``p = 100 - reference``), round using the reference's fractional width
(e.g. ref ``99.9999`` -> ``0.0001``).
"""
if math.isclose(
p,
reference_percentile,
rel_tol=0.0,
abs_tol=1e-9 * max(1.0, abs(reference_percentile)),
):
return str(reference_percentile)
decimals = _decimal_places_in_percentile(reference_percentile)
rounded = round(p, decimals)
if decimals == 0:
return str(int(rounded))
return f"{rounded:.{decimals}f}"


def trim_zero_bins(
counts: np.ndarray, bin_edges: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -100,13 +132,19 @@ def _abs_norm_tail_bias(
target_counts: np.ndarray,
predict_bin_edges: np.ndarray,
target_bin_edges: np.ndarray,
tail: Literal["upper", "lower"] = "upper",
):
pred_counts_rebinned = _rebin_counts(
bin_edges=predict_bin_edges, counts=predict_counts, new_edges=target_bin_edges
)
bin_centers = 0.5 * (target_bin_edges[:-1] + target_bin_edges[1:])
threshold = quantile(target_bin_edges, target_counts, percentile / 100.0)
tail_mask = bin_centers > threshold
if tail == "upper":
tail_mask = bin_centers > threshold
elif tail == "lower":
tail_mask = bin_centers < threshold
else:
raise ValueError(f"Invalid tail arg: {tail}. Must be 'upper' or 'lower'.")

pred_density = (pred_counts_rebinned / np.sum(pred_counts_rebinned))[tail_mask]
target_density = (target_counts / np.sum(target_counts))[tail_mask]
Expand Down Expand Up @@ -512,6 +550,149 @@ def get_dataset(self) -> xr.Dataset:
return ds


class ComparedDynamicTailsHistograms(ComparedDynamicHistograms):
    """ComparedDynamicHistograms with support for upper-tailed, lower-tailed,
    and two-tailed distribution metrics.

    Args:
        n_bins: Number of histogram bins.
        percentiles: Reference (upper-tail) percentiles, e.g. ``[99.9999]``.
        compute_percentile_frac: If True, also log the prediction/target
            percentile ratio for each reported percentile.
        two_tailed_variables: Variables for which both tails are compared.
        left_tailed_variables: Variables for which only the lower tail is
            compared.
    """

    def __init__(
        self,
        n_bins: int,
        percentiles: list[float] | None = None,
        compute_percentile_frac: bool = False,
        two_tailed_variables: list[str] | None = None,
        left_tailed_variables: list[str] | None = None,
    ) -> None:
        super().__init__(
            n_bins=n_bins,
            percentiles=percentiles,
            compute_percentile_frac=compute_percentile_frac,
        )
        self._two_tailed_variables = two_tailed_variables or []
        self._left_tailed_variables = left_tailed_variables or []

    def _variable_distribution_tail(
        self, field_name: str
    ) -> Literal["upper", "lower", "both"]:
        """Which distribution tail(s) are of interest for this variable.

        Two-tailed membership wins over left-tailed; everything else
        defaults to the upper tail.
        """
        if field_name in self._two_tailed_variables:
            return "both"
        if field_name in self._left_tailed_variables:
            return "lower"
        return "upper"

    def _percentile_entries_for_field(self, field_name: str) -> list[tuple[float, str]]:
        """(quantile_percentile, stable_key_fragment) pairs for wandb metric names.

        Lower-tail entries (``100 - ref``) come first for two-tailed
        variables, matching the original emission order.
        """
        tail = self._variable_distribution_tail(field_name)
        entries: list[tuple[float, str]] = []
        if tail in ("lower", "both"):
            entries.extend(
                (100.0 - ref, _format_percentile_for_metric_key(100.0 - ref, ref))
                for ref in self.percentiles
            )
        if tail in ("upper", "both"):
            entries.extend(
                (ref, _format_percentile_for_metric_key(ref, ref))
                for ref in self.percentiles
            )
        return entries

    def _get_abs_norm_tail_biases_beyond_percentile(
        self, field_name: str, prediction: _Histogram, target: _Histogram
    ) -> dict[str, float]:
        """Absolute normalized density bias beyond each configured percentile.

        For two-tailed variables the upper- and lower-tail biases are
        averaged into a single metric keyed on the reference percentile.
        """
        tail = self._variable_distribution_tail(field_name)
        return_dict: dict[str, float] = {}

        def _bias(percentile: float, which: Literal["upper", "lower"]) -> float:
            # Single call site for the histogram bias; only the percentile
            # and tail direction vary between branches.
            return _abs_norm_tail_bias(
                percentile=percentile,
                predict_counts=prediction.counts,
                target_counts=target.counts,
                predict_bin_edges=prediction.bin_edges,
                target_bin_edges=target.bin_edges,
                tail=which,
            )

        for ref in self.percentiles:
            if tail == "upper":
                p = ref
                value = _bias(p, "upper")
            elif tail == "lower":
                p = 100.0 - ref
                value = _bias(p, "lower")
            else:
                # NOTE(review): the lower tail "beyond" the reference must use
                # the complementary percentile; a lower tail at ``ref`` itself
                # (e.g. everything below the 99.9999th percentile) would span
                # nearly the whole distribution rather than a tail.
                p = ref
                value = (_bias(ref, "upper") + _bias(100.0 - ref, "lower")) / 2.0
            p_key = _format_percentile_for_metric_key(p, ref)
            return_dict[
                f"abs_norm_tail_bias_beyond_percentile/{p_key}/{field_name}"
            ] = value
        return return_dict

    def get_wandb(self) -> dict[str, matplotlib.figure.Figure | float]:
        """Histogram figures plus per-percentile scalar diagnostics.

        Returns one figure per field plus target/prediction percentile
        values, optional prediction/target percentile ratios, and tail-bias
        metrics (when a target histogram is available).
        """
        return_dict: dict[str, matplotlib.figure.Figure | float] = {}

        for field_name, histograms in self._get_histograms().items():
            target = histograms.get("target")
            prediction = histograms.get("prediction")
            fig = self._plot_histogram(target, prediction)
            return_dict[field_name] = fig
            plt.close(fig)
            percentile_entries = self._percentile_entries_for_field(field_name)
            if target is not None:
                for p, p_key in percentile_entries:
                    return_dict[f"target/{p_key}th-percentile/{field_name}"] = quantile(
                        target.bin_edges, target.counts, p / 100.0
                    )
            if prediction is not None:
                for p, p_key in percentile_entries:
                    return_dict[f"prediction/{p_key}th-percentile/{field_name}"] = (
                        quantile(prediction.bin_edges, prediction.counts, p / 100.0)
                    )
                    # Ratio is computed per percentile, inside the loop, so
                    # every entry (not just the last) gets a frac metric.
                    if self._compute_percentile_frac and target is not None:
                        return_dict[
                            f"prediction_frac_of_target/{p_key}th-percentile/{field_name}"
                        ] = (
                            return_dict[f"prediction/{p_key}th-percentile/{field_name}"]
                            / return_dict[f"target/{p_key}th-percentile/{field_name}"]
                        )

            if target is not None:
                return_dict.update(
                    self._get_abs_norm_tail_biases_beyond_percentile(
                        field_name, prediction, target
                    )
                )

        return return_dict


def _normalize_histogram(counts, bin_edges):
"""
Normalize histogram counts so that the integral is 1.
Expand Down
13 changes: 9 additions & 4 deletions fme/downscaling/aggregators/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torch
import xarray as xr

from fme.core.histogram import ComparedDynamicHistograms, DynamicHistogramAggregator
from fme.core.histogram import (
ComparedDynamicTailsHistograms,
DynamicHistogramAggregator,
)
from fme.core.typing_ import TensorMapping


Expand Down Expand Up @@ -47,7 +50,7 @@ def get_dataset(self) -> dict[str, Any]:

class DynamicHistogramsAdapter(_HistogramsAdapter):
"""
Adapter to use DynamicHistogramAggregator with the naming and prefix
Adapter to use DynamicHistogramAggregator with the naming and prefix
scheme used by downscaling aggregators.

Args:
Expand All @@ -73,8 +76,10 @@ def record_batch(
self._histograms.record_batch(prediction)


class ComparedDynamicHistogramsAdapter(_HistogramsAdapter):
def __init__(self, histograms: ComparedDynamicHistograms, name: str = "") -> None:
class ComparedDynamicTailsHistogramsAdapter(_HistogramsAdapter):
def __init__(
self, histograms: ComparedDynamicTailsHistograms, name: str = ""
) -> None:
super().__init__(histograms=histograms, name=name)

@torch.no_grad()
Expand Down
15 changes: 11 additions & 4 deletions fme/downscaling/aggregators/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.histogram import ComparedDynamicHistograms
from fme.core.histogram import ComparedDynamicTailsHistograms
from fme.core.typing_ import TensorMapping
from fme.core.wandb import WandB
from fme.downscaling.aggregators.adapters import ComparedDynamicHistogramsAdapter
from fme.downscaling.aggregators.adapters import ComparedDynamicTailsHistogramsAdapter
from fme.downscaling.data import PairedBatchData

from ..metrics_and_maths import (
Expand Down Expand Up @@ -758,6 +758,11 @@ def __init__(
ssim_kwargs: Mapping[str, Any] | None = None,
variable_metadata: Mapping[str, VariableMetadata] | None = None,
include_positional_comparisons: bool = True,
two_tailed_variables: list[str] | None = [
"eastward_wind_at_ten_meters",
"northward_wind_at_ten_meters",
],
left_tailed_variables: list[str] | None = ["PRMSL"],
) -> None:
self.downscale_factor = downscale_factor

Expand All @@ -768,11 +773,13 @@ def __init__(
self._comparisons: list[_ComparisonAggregator] = [
MeanComparison(metrics.root_mean_squared_error, name="metrics/rmse"),
SnapshotAggregator(dims, variable_metadata, name="snapshot"),
ComparedDynamicHistogramsAdapter(
histograms=ComparedDynamicHistograms(
ComparedDynamicTailsHistogramsAdapter(
histograms=ComparedDynamicTailsHistograms(
n_bins=n_histogram_bins,
percentiles=percentiles,
compute_percentile_frac=True,
two_tailed_variables=two_tailed_variables,
left_tailed_variables=left_tailed_variables,
),
name="histogram",
),
Expand Down
Loading