Skip to content

Commit

Permalink
[BACKPORT] Implements mars.learn.wrappers.ParallelPostFit (#2425) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi authored Sep 5, 2021
1 parent 3b8d61d commit ec2be5f
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 3 deletions.
16 changes: 16 additions & 0 deletions docs/source/reference/learn/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,22 @@ Utilities
utils.validation.check_is_fitted
utils.validation.column_or_1d

.. _learn_misc_ref:

Misc
====

.. automodule:: mars.learn.wrappers
:no-members:
:no-inherited-members:

.. currentmodule:: mars.learn

.. autosummary::
:toctree: generated/

wrappers.ParallelPostFit

.. _lightgbm_ref:

LightGBM Integration
Expand Down
2 changes: 0 additions & 2 deletions mars/core/entity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def recursive_tile(tileable: TileableType, *tileables: TileableType) -> \
tileable = raw[0]
tileables = raw[1:]

inputs_set = set(tileable.op.inputs)
to_tile = [tileable] + list(tileables)
q = [t for t in to_tile if t.is_coarse()]
while q:
Expand All @@ -79,7 +78,6 @@ def recursive_tile(tileable: TileableType, *tileables: TileableType) -> \
for inp in t.op.inputs:
if has_unknown_shape(inp):
to_update_inputs.append(inp)
if inp not in inputs_set:
chunks.extend(inp.chunks)
if obj is None:
yield chunks + to_update_inputs
Expand Down
1 change: 1 addition & 0 deletions mars/learn/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from ._classification import accuracy_score, log_loss
from ._ranking import roc_curve, auc
from ._regresssion import r2_score
from ._scorer import get_scorer
59 changes: 59 additions & 0 deletions mars/learn/metrics/_scorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union

from sklearn.metrics import make_scorer

from . import accuracy_score, log_loss, r2_score


# Wrap the metric functions imported above as scorer objects.
# NOTE: the first two assignments rebind the imported metric names
# (``accuracy_score``, ``r2_score``) to their scorer wrappers.
accuracy_score = make_scorer(accuracy_score)
r2_score = make_scorer(r2_score)
neg_log_loss_scorer = make_scorer(
    log_loss, greater_is_better=False, needs_proba=True)


# Built-in scorers keyed by their scoring name.
SCORERS = {
    'r2': r2_score,
    'accuracy': accuracy_score,
    'neg_log_loss': neg_log_loss_scorer,
}


def get_scorer(score_func: Union[str, Callable], **kwargs) -> Callable:
    """
    Get a scorer from a string, or build one from a callable.

    Parameters
    ----------
    score_func : str | callable
        Either the name of a built-in scorer (a key of ``SCORERS``) or a
        callable, which is wrapped with ``make_scorer``.
    **kwargs
        Keyword arguments forwarded to ``make_scorer`` when ``score_func``
        is a callable. They are not used when ``score_func`` is a string.

    Returns
    -------
    scorer : callable
        The scorer.

    Raises
    ------
    ValueError
        If ``score_func`` is a string that is not a key of ``SCORERS``.
    """
    if not isinstance(score_func, str):
        return make_scorer(score_func, **kwargs)

    try:
        return SCORERS[score_func]
    except KeyError:
        # The KeyError carries no information beyond the message below,
        # so suppress it from the traceback.
        raise ValueError(
            "{} is not a valid scoring value. "
            "Valid options are {}".format(score_func, sorted(SCORERS))
        ) from None
26 changes: 26 additions & 0 deletions mars/learn/metrics/tests/test_scorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from sklearn.metrics import r2_score

from .. import get_scorer


def test_get_scorer():
    """Check ``get_scorer`` with an unknown name, a known name and a callable."""
    # an unregistered scoring name is rejected
    with pytest.raises(ValueError):
        get_scorer('unknown')

    # a registered name and a plain metric function both produce a scorer
    for spec in ('r2', r2_score):
        assert get_scorer(spec) is not None
13 changes: 13 additions & 0 deletions mars/learn/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
102 changes: 102 additions & 0 deletions mars/learn/tests/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression

from ... import tensor as mt
from ..wrappers import ParallelPostFit


# Shared classification data set: raw arrays plus Mars tensors built from them
# with a chunk size of 100.
raw_x, raw_y = make_classification(n_samples=1000)
X = mt.tensor(raw_x, chunk_size=100)
y = mt.tensor(raw_y, chunk_size=100)


def test_parallel_post_fit_basic(setup):
    """Smoke-test ``ParallelPostFit`` around a classifier and a regressor."""
    wrapped_clf = ParallelPostFit(GradientBoostingClassifier())
    wrapped_clf.fit(X, y)

    # both prediction methods hand back Mars tensors
    for method in ('predict', 'predict_proba'):
        assert isinstance(getattr(wrapped_clf, method)(X), mt.Tensor)

    # the fetched wrapper score equals the wrapped estimator's own score
    wrapper_score = wrapped_clf.score(X, y)
    estimator_score = wrapped_clf.estimator.score(X, y)
    assert wrapper_score.fetch() == estimator_score

    # a wrapped regressor must refuse ``predict_proba``
    wrapped_reg = ParallelPostFit(LinearRegression())
    wrapped_reg.fit(X, y)
    expected_msg = "The wrapped estimator (.|\n)* 'predict_proba' method."
    with pytest.raises(AttributeError, match=expected_msg):
        wrapped_reg.predict_proba(X)


def test_parallel_post_fit_predict(setup):
    """Compare wrapped and plain ``LogisticRegression`` prediction outputs."""
    base = LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs")
    wrap = ParallelPostFit(
        LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs"))

    base.fit(X, y)
    wrap.fit(X, y)

    # every prediction method of the wrapper must agree with the plain estimator
    for method in ('predict', 'predict_proba', 'predict_log_proba'):
        result = getattr(wrap, method)(X)
        expected = getattr(base, method)(X)
        np.testing.assert_allclose(result, expected)


def test_parallel_post_fit_transform(setup):
    """Compare ``transform`` of a wrapped ``PCA`` against a plain ``PCA``."""
    base = PCA(random_state=0)
    wrap = ParallelPostFit(PCA(random_state=0))

    base.fit(raw_x, raw_y)
    wrap.fit(X, y)

    # ``result`` is the wrapper output and ``expected`` the plain estimator
    # output, matching the other tests in this module, so that
    # ``assert_allclose`` reports actual/desired the right way round.
    result = wrap.transform(X)
    expected = base.transform(X)
    np.testing.assert_allclose(result, expected, atol=.1)


def test_parallel_post_fit_multiclass(setup):
    """Compare wrapper and wrapped-estimator predictions on 3-class data."""
    # local three-class data set, chunked by 50
    multi_raw_x, multi_raw_y = make_classification(n_classes=3, n_informative=4)
    multi_x = mt.tensor(multi_raw_x, chunk_size=50)
    multi_y = mt.tensor(multi_raw_y, chunk_size=50)

    estimator = LogisticRegression(
        random_state=0, n_jobs=1, solver="lbfgs", multi_class="auto")
    clf = ParallelPostFit(estimator)
    clf.fit(multi_x, multi_y)

    # each prediction method of the wrapper must match the wrapped estimator
    for method in ('predict', 'predict_proba', 'predict_log_proba'):
        result = getattr(clf, method)(multi_x)
        expected = getattr(clf.estimator, method)(multi_x)
        np.testing.assert_allclose(result, expected)
2 changes: 1 addition & 1 deletion mars/learn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .collect_ports import collect_ports
from .core import convert_to_tensor_or_dataframe, \
concat_chunks
concat_chunks, copy_learned_attributes
from .validation import check_array, assert_all_finite, \
check_consistent_length, column_or_1d, check_X_y
from .shuffle import shuffle
8 changes: 8 additions & 0 deletions mars/learn/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pandas as pd
from sklearn.base import BaseEstimator

from ...tensor import tensor as astensor
from ...dataframe import DataFrame, Series
Expand All @@ -32,3 +33,10 @@ def convert_to_tensor_or_dataframe(item):
def concat_chunks(chunks):
    """
    Concatenate ``chunks`` and return the first chunk of the result.

    A tileable is created from ``chunks`` through the op of the first chunk,
    its chunks are concatenated through the tileable's own op, and the first
    chunk of the concatenated result is returned.
    """
    first_op = chunks[0].op
    tileable = first_op.create_tileable_from_chunks(chunks)
    concatenated = tileable.op.concat_tileable_chunks(tileable)
    return concatenated.chunks[0]


def copy_learned_attributes(from_estimator: BaseEstimator,
                            to_estimator: BaseEstimator):
    """
    Copy instance attributes whose names end with ``_`` between estimators.

    Parameters
    ----------
    from_estimator : BaseEstimator
        Estimator whose instance attributes are read.
    to_estimator : BaseEstimator
        Estimator on which the selected attributes are set.
    """
    # snapshot the matching attributes before assigning anything
    learned = [
        (name, value)
        for name, value in vars(from_estimator).items()
        if name.endswith('_')
    ]
    for name, value in learned:
        setattr(to_estimator, name, value)
Loading

0 comments on commit ec2be5f

Please sign in to comment.