Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions neural_lam/metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Standard library
from typing import Optional

# Third-party
import torch


def get_metric(metric_name):
def get_metric(metric_name: str):
"""
Get a defined metric with given name

Expand All @@ -18,7 +21,12 @@ def get_metric(metric_name):
return DEFINED_METRICS[metric_name_lower]


def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
def mask_and_reduce_metric(
metric_entry_vals: torch.Tensor,
mask: Optional[torch.Tensor],
average_grid: bool,
sum_vars: bool,
) -> torch.Tensor:
"""
Masks and (optionally) reduces entry-wise metric values

Expand Down Expand Up @@ -53,7 +61,14 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
return metric_entry_vals


def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmse(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Weighted Mean Squared Error

Expand Down Expand Up @@ -84,7 +99,14 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mse(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Unweighted) Mean Squared Error

Expand All @@ -108,7 +130,14 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmae(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Weighted Mean Absolute Error

Expand Down Expand Up @@ -139,7 +168,14 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mae(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Unweighted) Mean Absolute Error

Expand All @@ -163,7 +199,14 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def nll(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Negative Log Likelihood loss, for isotropic Gaussian likelihood

Expand Down Expand Up @@ -191,8 +234,13 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):


def crps_gauss(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
):
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Negative) Continuous Ranked Probability Score (CRPS)
Closed-form expression based on Gaussian predictive distribution
Expand Down