Skip to content

Commit cc5b577

Browse files
author
Xuye (Chris) Qin
authored
[BACKPORT] Add _binary_roc_auc_score method (#2403) (#2477)
1 parent 0ac9abd commit cc5b577

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

mars/learn/metrics/_ranking.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -185,6 +185,43 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None,
185185
return ret.execute(session=session, **(run_kwargs or dict()))
186186

187187

188+
def _binary_roc_auc_score(y_true, y_score, sample_weight=None,
                          max_fpr=None, session=None, run_kwargs=None):
    """Binary roc auc score.

    Parameters
    ----------
    y_true
        Labels, forwarded to ``roc_curve``. Must contain exactly two
        distinct values.
    y_score
        Scores, forwarded to ``roc_curve``.
    sample_weight : optional
        Sample weights, forwarded to ``roc_curve``.
    max_fpr : float, optional
        When ``None`` or equal to ``1``, the area under the whole ROC curve
        is returned. Otherwise it must lie in ``(0, 1]``; the area under the
        curve restricted to ``[0, max_fpr]`` is computed and then
        standardized with the McClish correction.
    session : optional
        Session handed to every ``execute`` / ``fetch`` call made here.
    run_kwargs : dict, optional
        Extra keyword arguments handed to every ``execute`` call made here.

    Returns
    -------
    The fetched area under the curve when ``max_fpr`` is ``None`` or ``1``,
    otherwise the standardized partial area.

    Raises
    ------
    ValueError
        If ``y_true`` does not contain exactly two distinct values, or if
        ``max_fpr`` is outside ``(0, 1]``.
    """
    from numpy import interp

    execute_kwargs = run_kwargs or dict()

    # Run the class-count check on the caller's session as well, so the whole
    # function consistently uses the same session and run options.
    unique_labels = mt.unique(y_true).execute(session=session, **execute_kwargs)
    if len(unique_labels) != 2:
        raise ValueError("Only one class present in y_true. ROC AUC score "
                         "is not defined in that case.")

    fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight,
                            session=session, run_kwargs=run_kwargs)
    # Pull both curves to the client; the interpolation below works on
    # concrete values.
    fpr, tpr = mt.ExecutableTuple([fpr, tpr]).fetch(session=session)

    if max_fpr is None or max_fpr == 1:
        return auc(fpr, tpr, session=session, run_kwargs=run_kwargs).fetch(session=session)
    if max_fpr <= 0 or max_fpr > 1:
        raise ValueError(f"Expected max_fpr in range (0, 1], got: {max_fpr}")

    # Add a single point at max_fpr by linear interpolation
    stop = mt.searchsorted(fpr, max_fpr, 'right').execute(
        session=session, **execute_kwargs).fetch(session=session)
    x_interp = [fpr[stop - 1], fpr[stop]]
    y_interp = [tpr[stop - 1], tpr[stop]]
    tpr = list(tpr[:stop])
    tpr.append(interp(max_fpr, x_interp, y_interp))
    fpr = list(fpr[:stop])
    fpr.append(max_fpr)
    partial_auc = auc(fpr, tpr, session=session, run_kwargs=run_kwargs)

    # McClish correction: standardize result to be 0.5 if non-discriminant
    # and 1 if maximal
    min_area = 0.5 * max_fpr**2
    max_area = max_fpr
    return 0.5 * (1 + (partial_auc.fetch(session=session) - min_area) / (max_area - min_area))
223+
224+
188225
def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
189226
drop_intermediate=True, session=None, run_kwargs=None):
190227
"""Compute Receiver operating characteristic (ROC)

mars/learn/metrics/tests/test_ranking.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,13 +25,16 @@
2525
from sklearn.exceptions import UndefinedMetricWarning
2626
from sklearn.utils import check_random_state
2727
from sklearn.utils._testing import assert_warns
28+
from sklearn.metrics._ranking import _binary_roc_auc_score as sk_binary_roc_auc_score
2829
except ImportError: # pragma: no cover
2930
sklearn = None
3031
import pytest
3132

33+
3234
from .... import dataframe as md
3335
from .... import tensor as mt
3436
from .. import roc_curve, auc, accuracy_score
37+
from .._ranking import _binary_roc_auc_score
3538

3639

3740
def test_roc_curve(setup):
@@ -149,6 +152,30 @@ def test_roc_curve_one_label(setup):
149152
assert fpr.shape == thresholds.shape
150153

151154

155+
def test_binary_roc_auc_score(setup):
    """Check ``_binary_roc_auc_score`` against scikit-learn's implementation.

    Covers the full-AUC path (``max_fpr=None``), the partial-AUC path
    (``0 < max_fpr < 1``) and both ``ValueError`` branches.
    """
    # Test the area is equal under binary roc_auc_score
    rs = np.random.RandomState(0)
    raw_X = rs.randint(0, 2, size=10)
    raw_Y = rs.rand(10).astype('float32')

    X = mt.tensor(raw_X)
    Y = mt.tensor(raw_Y)

    # Draw max_fpr from the seeded generator (not the global numpy RNG) so
    # that every run exercises the same partial-AUC cut-off and any failure
    # is reproducible.
    for max_fpr in (rs.rand(), None):
        # Calculate the score using both frameworks
        score = _binary_roc_auc_score(X, Y, max_fpr=max_fpr)
        expected_score = sk_binary_roc_auc_score(raw_X, raw_Y, max_fpr=max_fpr)

        # Both the scores should be equal
        np.testing.assert_almost_equal(score, expected_score, decimal=6)

    # A single class in y_true leaves the ROC AUC undefined.
    with pytest.raises(ValueError):
        _binary_roc_auc_score(mt.tensor([0]), Y)

    # max_fpr must lie in (0, 1].
    with pytest.raises(ValueError):
        _binary_roc_auc_score(X, Y, max_fpr=0)
177+
178+
152179
def test_roc_curve_drop_intermediate(setup):
153180
# Test that drop_intermediate drops the correct thresholds
154181
y_true = [0, 0, 0, 0, 1, 1]

0 commit comments

Comments
 (0)