-
Notifications
You must be signed in to change notification settings - Fork 256
Add type hints to all functions in metrics.py #446 #447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,8 +1,11 @@ | ||||||
| # Standard library | ||||||
| from typing import Optional | ||||||
|
|
||||||
| # Third-party | ||||||
| import torch | ||||||
|
|
||||||
|
|
||||||
| def get_metric(metric_name): | ||||||
| def get_metric(metric_name: str): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
You'd also need |
||||||
| """ | ||||||
| Get a defined metric with given name | ||||||
|
|
||||||
|
|
@@ -18,7 +21,12 @@ def get_metric(metric_name): | |||||
| return DEFINED_METRICS[metric_name_lower] | ||||||
|
|
||||||
|
|
||||||
| def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): | ||||||
| def mask_and_reduce_metric( | ||||||
| metric_entry_vals: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor], | ||||||
| average_grid: bool, | ||||||
| sum_vars: bool, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| Masks and (optionally) reduces entry-wise metric values | ||||||
|
|
||||||
|
|
@@ -53,7 +61,14 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): | |||||
| return metric_entry_vals | ||||||
|
|
||||||
|
|
||||||
| def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | ||||||
| def wmse( | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| Weighted Mean Squared Error | ||||||
|
|
||||||
|
|
@@ -84,7 +99,14 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | ||||||
| def mse( | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| (Unweighted) Mean Squared Error | ||||||
|
|
||||||
|
|
@@ -108,7 +130,14 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | ||||||
| def wmae( | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| Weighted Mean Absolute Error | ||||||
|
|
||||||
|
|
@@ -139,7 +168,14 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | ||||||
| def mae( | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| (Unweighted) Mean Absolute Error | ||||||
|
|
||||||
|
|
@@ -163,7 +199,14 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | ||||||
| def nll( | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| Negative Log Likelihood loss, for isotropic Gaussian likelihood | ||||||
|
|
||||||
|
|
@@ -191,8 +234,13 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): | |||||
|
|
||||||
|
|
||||||
| def crps_gauss( | ||||||
| pred, target, pred_std, mask=None, average_grid=True, sum_vars=True | ||||||
| ): | ||||||
| pred: torch.Tensor, | ||||||
| target: torch.Tensor, | ||||||
| pred_std: torch.Tensor, | ||||||
| mask: Optional[torch.Tensor] = None, | ||||||
| average_grid: bool = True, | ||||||
| sum_vars: bool = True, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| (Negative) Continuous Ranked Probability Score (CRPS) | ||||||
| Closed-form expression based on Gaussian predictive distribution | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Project requires Python >= 3.10, so
Optionalfromtypingis a legacy alias.torch.Tensor | Noneis preferred. If you update the usages, this import can be dropped entirely.(if you also fix the
get_metricreturn type above)