Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions neural_lam/metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Standard library
from typing import Optional
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Project requires Python >= 3.10, so Optional from typing is a legacy alias. torch.Tensor | None is preferred. If you update the usages, this import can be dropped entirely.

Suggested change
from typing import Optional
from collections.abc import Callable

(if you also fix the get_metric return type above)


# Third-party
import torch


def get_metric(metric_name):
def get_metric(metric_name: str):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_metric is missing a return type. The PR description says "all functions" but this one's incomplete. Since it returns one of the metric callables, something like this works:

Suggested change
def get_metric(metric_name: str):
def get_metric(metric_name: str) -> "Callable[..., torch.Tensor]":

You'd also need from collections.abc import Callable (or from typing import Callable) added to the imports.

"""
Get a defined metric with given name

Expand All @@ -18,7 +21,12 @@ def get_metric(metric_name):
return DEFINED_METRICS[metric_name_lower]


def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
def mask_and_reduce_metric(
metric_entry_vals: torch.Tensor,
mask: Optional[torch.Tensor],
average_grid: bool,
sum_vars: bool,
) -> torch.Tensor:
"""
Masks and (optionally) reduces entry-wise metric values

Expand Down Expand Up @@ -53,7 +61,14 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
return metric_entry_vals


def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmse(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Weighted Mean Squared Error

Expand Down Expand Up @@ -84,7 +99,14 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mse(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Unweighted) Mean Squared Error

Expand All @@ -108,7 +130,14 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmae(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Weighted Mean Absolute Error

Expand Down Expand Up @@ -139,7 +168,14 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mae(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Unweighted) Mean Absolute Error

Expand All @@ -163,7 +199,14 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def nll(
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
Negative Log Likelihood loss, for isotropic Gaussian likelihood

Expand Down Expand Up @@ -191,8 +234,13 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):


def crps_gauss(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
):
pred: torch.Tensor,
target: torch.Tensor,
pred_std: torch.Tensor,
mask: Optional[torch.Tensor] = None,
average_grid: bool = True,
sum_vars: bool = True,
) -> torch.Tensor:
"""
(Negative) Continuous Ranked Probability Score (CRPS)
Closed-form expression based on Gaussian predictive distribution
Expand Down