Skip to content

Commit 6da965e

Browse files
prasunanand authored and chauhankaranraj committed
Annotate metrics (dask#630)
* Add annotations for metrics module
1 parent 49a550b commit 6da965e

11 files changed

+125
-24
lines changed

ci/code_checks.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ MSG='Checking isort... ' ; echo $MSG
1313
isort --recursive --check-only .
1414
RET=$(($RET + $?)) ; echo $MSG "DONE"
1515

16+
MSG='Checking mypy... ' ; echo $MSG
17+
mypy dask_ml/metrics
18+
RET=$(($RET + $?)) ; echo $MSG "DONE"
19+
1620
exit $RET

ci/environment-3.6.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- isort
1313
- msgpack-python ==0.6.2
1414
- multipledispatch
15+
- mypy
1516
- numba
1617
- numpy ==1.17.3
1718
- numpydoc

ci/environment-3.7.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- flake8
1313
- isort
1414
- multipledispatch >=0.4.9
15+
- mypy
1516
- numba
1617
- numpy >=1.16.3
1718
- numpydoc

ci/environment-docs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- ipython
1313
- isort
1414
- multipledispatch
15+
- mypy
1516
- nbsphinx
1617
- nomkl
1718
- nose

ci/windows.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,7 @@ jobs:
4242
echo "[codecov]"
4343
codecov
4444
45+
echo "[mypy]"
46+
mypy dask_ml/metrics
47+
4548
displayName: "Lint"

dask_ml/_typing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import TypeVar
2+
3+
import numpy as np
4+
from dask.array import Array
5+
from pandas import Index, Series
6+
7+
# array-like
8+
9+
AnyArrayLike = TypeVar("AnyArrayLike", Index, Series, Array, np.ndarray)
10+
ArrayLike = TypeVar("ArrayLike", Array, np.ndarray)

dask_ml/metrics/classification.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1+
from typing import Optional
2+
13
import dask
24
import dask.array as da
35
import numpy as np
46
import sklearn.metrics
57
import sklearn.utils.multiclass
68

9+
from .._typing import ArrayLike
10+
711

8-
def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None, compute=True):
12+
def accuracy_score(
13+
y_true: ArrayLike,
14+
y_pred: ArrayLike,
15+
normalize: bool = True,
16+
sample_weight: Optional[ArrayLike] = None,
17+
compute: bool = True,
18+
) -> ArrayLike:
919
"""Accuracy classification score.
1020
1121
In multilabel classification, this function computes subset accuracy:
@@ -84,7 +94,9 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None, compute=T
8494
return score
8595

8696

87-
def _log_loss_inner(x, y, sample_weight, **kwargs):
97+
def _log_loss_inner(
98+
x: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike], **kwargs
99+
):
88100
# da.map_blocks wasn't able to concatenate together the results
89101
# when we reduce down to a scalar per block. So we make an
90102
# array with 1 element.
@@ -110,7 +122,7 @@ def log_loss(
110122

111123
if y_pred.ndim > 1 and y_true.ndim == 1:
112124
y_true = y_true.reshape(-1, 1)
113-
drop_axis = 1
125+
drop_axis: Optional[int] = 1
114126
if sample_weight is not None:
115127
sample_weight = sample_weight.reshape(-1, 1)
116128
else:

dask_ml/metrics/pairwise.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Daskified versions of sklearn.metrics.pairwise
33
"""
44
import warnings
5+
from typing import Any, Callable, Dict, Optional, Tuple, Union
56

67
import dask.array as da
78
import numpy as np
@@ -10,11 +11,17 @@
1011
from sklearn import metrics
1112
from sklearn.metrics.pairwise import KERNEL_PARAMS
1213

14+
from .._typing import ArrayLike
1315
from ..utils import row_norms
1416

1517

1618
def pairwise_distances_argmin_min(
17-
X, Y, axis=1, metric="euclidean", batch_size=None, metric_kwargs=None
19+
X: ArrayLike,
20+
Y: ArrayLike,
21+
axis: int = 1,
22+
metric: Union[str, Callable[[ArrayLike, ArrayLike], float]] = "euclidean",
23+
batch_size: Optional[int] = None,
24+
metric_kwargs: Optional[Dict[str, Any]] = None,
1825
):
1926
if batch_size is not None:
2027
msg = "'batch_size' is deprecated. Use sklearn.config_context instead.'"
@@ -39,7 +46,13 @@ def pairwise_distances_argmin_min(
3946
return argmins, mins
4047

4148

42-
def pairwise_distances(X, Y, metric="euclidean", n_jobs=None, **kwargs):
49+
def pairwise_distances(
50+
X: ArrayLike,
51+
Y: ArrayLike,
52+
metric: Union[str, Callable[[ArrayLike, ArrayLike], float]] = "euclidean",
53+
n_jobs: Optional[int] = None,
54+
**kwargs: Any
55+
):
4356
if isinstance(Y, da.Array):
4457
raise TypeError("`Y` must be a numpy array")
4558
chunks = (X.chunks[0], (len(Y),))
@@ -54,8 +67,12 @@ def pairwise_distances(X, Y, metric="euclidean", n_jobs=None, **kwargs):
5467

5568

5669
def euclidean_distances(
57-
X, Y=None, Y_norm_squared=None, squared=False, X_norm_squared=None
58-
):
70+
X: ArrayLike,
71+
Y: Optional[ArrayLike] = None,
72+
Y_norm_squared: Optional[ArrayLike] = None,
73+
squared: bool = False,
74+
X_norm_squared: Optional[ArrayLike] = None,
75+
) -> ArrayLike:
5976
if Y is None:
6077
Y = X
6178

@@ -87,7 +104,9 @@ def euclidean_distances(
87104
return distances if squared else da.sqrt(distances)
88105

89106

90-
def check_pairwise_arrays(X, Y, precomputed=False):
107+
def check_pairwise_arrays(
108+
X: ArrayLike, Y: ArrayLike, precomputed: bool = False
109+
) -> Tuple[ArrayLike, ArrayLike]:
91110
# XXX
92111
if Y is None:
93112
Y = X
@@ -113,13 +132,15 @@ def check_pairwise_arrays(X, Y, precomputed=False):
113132

114133

115134
@derived_from(metrics.pairwise)
116-
def linear_kernel(X, Y=None):
135+
def linear_kernel(X: ArrayLike, Y: Optional[ArrayLike] = None) -> ArrayLike:
117136
X, Y = check_pairwise_arrays(X, Y)
118137
return da.dot(X, Y.T)
119138

120139

121140
@derived_from(metrics.pairwise)
122-
def rbf_kernel(X, Y=None, gamma=None):
141+
def rbf_kernel(
142+
X: ArrayLike, Y: Optional[ArrayLike] = None, gamma: Optional[float] = None
143+
) -> ArrayLike:
123144
X, Y = check_pairwise_arrays(X, Y)
124145
if gamma is None:
125146
gamma = 1.0 / X.shape[1]
@@ -130,7 +151,13 @@ def rbf_kernel(X, Y=None, gamma=None):
130151

131152

132153
@derived_from(metrics.pairwise)
133-
def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1):
154+
def polynomial_kernel(
155+
X: ArrayLike,
156+
Y: Optional[ArrayLike] = None,
157+
degree: int = 3,
158+
gamma: Optional[float] = None,
159+
coef0: float = 1,
160+
) -> ArrayLike:
134161
X, Y = check_pairwise_arrays(X, Y)
135162
if gamma is None:
136163
gamma = 1.0 / X.shape[1]
@@ -140,7 +167,12 @@ def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1):
140167

141168

142169
@derived_from(metrics.pairwise)
143-
def sigmoid_kernel(X, Y=None, gamma=None, coef0=1):
170+
def sigmoid_kernel(
171+
X: ArrayLike,
172+
Y: Optional[ArrayLike] = None,
173+
gamma: Optional[float] = None,
174+
coef0: float = 1,
175+
) -> ArrayLike:
144176
X, Y = check_pairwise_arrays(X, Y)
145177
if gamma is None:
146178
gamma = 1.0 / X.shape[1]
@@ -165,7 +197,14 @@ def sigmoid_kernel(X, Y=None, gamma=None, coef0=1):
165197
}
166198

167199

168-
def pairwise_kernels(X, Y=None, metric="linear", filter_params=False, n_jobs=1, **kwds):
200+
def pairwise_kernels(
201+
X: ArrayLike,
202+
Y: Optional[ArrayLike] = None,
203+
metric: Union[str, Callable[[ArrayLike, ArrayLike], float]] = "linear",
204+
filter_params: bool = False,
205+
n_jobs: Optional[int] = 1,
206+
**kwds
207+
):
169208
from sklearn.gaussian_process.kernels import Kernel as GPKernel
170209

171210
if metric == "precomputed":
@@ -176,6 +215,7 @@ def pairwise_kernels(X, Y=None, metric="linear", filter_params=False, n_jobs=1,
176215
elif metric in PAIRWISE_KERNEL_FUNCTIONS:
177216
if filter_params:
178217
kwds = dict((k, kwds[k]) for k in kwds if k in KERNEL_PARAMS[metric])
218+
assert isinstance(metric, str)
179219
func = PAIRWISE_KERNEL_FUNCTIONS[metric]
180220
elif callable(metric):
181221
raise NotImplementedError()

dask_ml/metrics/regression.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1+
from typing import Optional
2+
13
import dask.array as da
24
import numpy as np
35
import sklearn.metrics
46
from dask.utils import derived_from
57

8+
from .._typing import ArrayLike
9+
610

7-
def _check_sample_weight(sample_weight):
11+
def _check_sample_weight(sample_weight: Optional[ArrayLike]):
812
if sample_weight is not None:
913
raise ValueError("'sample_weight' is not supported.")
1014

1115

12-
def _check_reg_targets(y_true, y_pred, multioutput):
16+
def _check_reg_targets(
17+
y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str]
18+
):
1319
if multioutput != "uniform_average":
1420
raise NotImplementedError("'multioutput' must be 'uniform_average'")
1521

@@ -24,8 +30,12 @@ def _check_reg_targets(y_true, y_pred, multioutput):
2430

2531
@derived_from(sklearn.metrics)
2632
def mean_squared_error(
27-
y_true, y_pred, sample_weight=None, multioutput="uniform_average", compute=True
28-
):
33+
y_true: ArrayLike,
34+
y_pred: ArrayLike,
35+
sample_weight: Optional[ArrayLike] = None,
36+
multioutput: Optional[str] = "uniform_average",
37+
compute: bool = True,
38+
) -> ArrayLike:
2939
_check_sample_weight(sample_weight)
3040
output_errors = ((y_pred - y_true) ** 2).mean(axis=0)
3141

@@ -45,8 +55,12 @@ def mean_squared_error(
4555

4656
@derived_from(sklearn.metrics)
4757
def mean_absolute_error(
48-
y_true, y_pred, sample_weight=None, multioutput="uniform_average", compute=True
49-
):
58+
y_true: ArrayLike,
59+
y_pred: ArrayLike,
60+
sample_weight: Optional[ArrayLike] = None,
61+
multioutput: Optional[str] = "uniform_average",
62+
compute: bool = True,
63+
) -> ArrayLike:
5064
_check_sample_weight(sample_weight)
5165
output_errors = abs(y_pred - y_true).mean(axis=0)
5266

@@ -66,8 +80,12 @@ def mean_absolute_error(
6680

6781
@derived_from(sklearn.metrics)
6882
def r2_score(
69-
y_true, y_pred, sample_weight=None, multioutput="uniform_average", compute=True
70-
):
83+
y_true: ArrayLike,
84+
y_pred: ArrayLike,
85+
sample_weight: Optional[ArrayLike] = None,
86+
multioutput: Optional[str] = "uniform_average",
87+
compute: bool = True,
88+
) -> ArrayLike:
7189
_check_sample_weight(sample_weight)
7290
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
7391
weight = 1.0

dask_ml/metrics/scorer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from typing import Any, Callable, Tuple, Union
2+
13
from sklearn.metrics import check_scoring as sklearn_check_scoring, make_scorer
24

35
from . import accuracy_score, log_loss, mean_squared_error, r2_score
46

57
# Scorers
6-
accuracy_scorer = (accuracy_score, {})
8+
accuracy_scorer: Tuple[Any, Any] = (accuracy_score, {})
79
neg_mean_squared_error_scorer = (mean_squared_error, dict(greater_is_better=False))
8-
r2_scorer = (r2_score, {})
10+
r2_scorer: Tuple[Any, Any] = (r2_score, {})
911
neg_log_loss_scorer = (log_loss, dict(greater_is_better=False, needs_proba=True))
1012

1113

@@ -17,7 +19,7 @@
1719
)
1820

1921

20-
def get_scorer(scoring, compute=True):
22+
def get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable:
2123
"""Get a scorer from string
2224
2325
Parameters

setup.cfg

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ skip=
2525
[coverage:run]
2626
source=dask_ml
2727

28+
[mypy]
29+
ignore_missing_imports=True
30+
no_implicit_optional=True
31+
check_untyped_defs=True
32+
strict_equality=True
33+
34+
[mypy-dask_ml.metrics]
35+
check_untyped_defs=False
36+
2837
[tool:pytest]
2938
addopts = -rsx -v --durations=10
3039
minversion = 3.2

0 commit comments

Comments (0)