Skip to content

Commit 029928b

Browse files
authored
Merge pull request #85 from sdpython/dev
Fixes #70, implements DecisionTreeLogisticRegression
2 parents 777bbb8 + f2a3cfd commit 029928b

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

_unittests/ut_mlmodel/test_decision_tree_logistic_regression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ def test_decision_path(self):
146146
leaves = predict_leaves(dtlr, X_test)
147147
self.assertEqual(leaves.shape[0], X_test.shape[0])
148148

149+
def test_classifier_strat(self):
150+
X = numpy.array([[0.1, 0.2], [0.2, 0.3], [-0.2, -0.3], [0.4, 0.3]])
151+
Y = numpy.array([0, 1, 0, 1])
152+
dtlr = DecisionTreeLogisticRegression(
153+
fit_improve_algo=None, strategy='')
154+
self.assertRaise(lambda: dtlr.fit(X, Y), ValueError)
155+
149156

150157
if __name__ == "__main__":
151158
unittest.main()

mlinsights/mlmodel/decision_tree_logreg.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ class DecisionTreeLogisticRegression(BaseEstimator, ClassifierMixin):
315315
where *p* is the proportion of samples falling in the first
316316
fold.
317317
:param verbose: prints out information about the training
318+
:param strategy: `'parallel'` or `'perpendicular'`,
319+
see below
318320
319321
Fitted attributes:
320322
@@ -323,6 +325,14 @@ class DecisionTreeLogisticRegression(BaseEstimator, ClassifierMixin):
323325
or a list of arrays of class labels (multi-output problem).
324326
* `tree_`: Tree
325327
The underlying Tree object.
328+
329+
The class implements two strategies to build the tree.
330+
The first one `'parallel'` splits the feature space using
331+
the hyperplan defined by a logistic regression, the second
332+
strategy `'perpendicular'` splis the feature space based on
333+
a hyperplan perpendicular to a logistic regression. By doing
334+
this, two logistic regression fit on both sub parts must
335+
necessary decreases the training error.
326336
"""
327337

328338
_fit_improve_algo_values = (
@@ -332,7 +342,7 @@ def __init__(self, estimator=None,
332342
max_depth=20, min_samples_split=2,
333343
min_samples_leaf=2, min_weight_fraction_leaf=0.0,
334344
fit_improve_algo='auto', p1p2=0.09,
335-
gamma=1., verbose=0):
345+
gamma=1., verbose=0, strategy='parallel'):
336346
"constructor"
337347
ClassifierMixin.__init__(self)
338348
BaseEstimator.__init__(self)
@@ -354,6 +364,7 @@ def __init__(self, estimator=None,
354364
self.p1p2 = p1p2
355365
self.gamma = gamma
356366
self.verbose = verbose
367+
self.strategy = strategy
357368

358369
if self.fit_improve_algo not in DecisionTreeLogisticRegression._fit_improve_algo_values:
359370
raise ValueError(
@@ -392,13 +403,27 @@ def fit(self, X, y, sample_weight=None):
392403
raise RuntimeError(
393404
"The model only supports binary classification but labels are "
394405
"{}.".format(self.classes_))
406+
407+
if self.strategy == 'parallel':
408+
return self._fit_parallel(X, y, sample_weight)
409+
if self.strategy == 'perpendicular':
410+
return self._fit_perpendicular(X, y, sample_weight)
411+
raise ValueError(
412+
"Unknown strategy '{}'.".format(self.strategy))
413+
414+
def _fit_parallel(self, X, y, sample_weight):
415+
"Implements the parallel strategy."
395416
cls = (y == self.classes_[1]).astype(numpy.int32)
396417
estimator = clone(self.estimator)
397418
self.tree_ = _DecisionTreeLogisticRegressionNode(estimator, 0.5)
398419
self.n_nodes_ = self.tree_.fit(
399420
X, cls, sample_weight, self, X.shape[0]) + 1
400421
return self
401422

423+
def _fit_perpendicular(self, X, y, sample_weight):
424+
"Implements the perpendicular strategy."
425+
raise NotImplementedError()
426+
402427
def predict(self, X):
403428
"""
404429
Runs the predictions.

mlinsights/mlmodel/kmeans_l1.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sklearn.utils.extmath import stable_cumsum
1919
try:
2020
from sklearn.cluster._kmeans import _check_sample_weight
21-
except ImportError:
21+
except ImportError: # pragma: no cover
2222
from sklearn.cluster._kmeans import (
2323
_check_normalize_sample_weight as _check_sample_weight)
2424
from ._kmeans_022 import (
@@ -144,7 +144,7 @@ def _init_centroids(norm, X, k, init, random_state=None,
144144
X = X[init_indices]
145145
n_samples = X.shape[0]
146146
elif n_samples < k:
147-
raise ValueError(
147+
raise ValueError( # pragma: no cover
148148
"n_samples=%d should be larger than k=%d" % (n_samples, k))
149149

150150
if isinstance(init, str) and init == 'k-means++':
@@ -160,21 +160,22 @@ def _init_centroids(norm, X, k, init, random_state=None,
160160
centers = init(norm, X, k, random_state=random_state)
161161
centers = numpy.asarray(centers, dtype=X.dtype)
162162
else:
163-
raise ValueError("the init parameter for the k-means should "
164-
"be 'k-means++' or 'random' or an ndarray, "
165-
"'%s' (type '%s') was passed." % (init, type(init)))
163+
raise ValueError( # pragma: no cover
164+
"init parameter for the k-means should "
165+
"be 'k-means++' or 'random' or an ndarray, "
166+
"'%s' (type '%s') was passed." % (init, type(init)))
166167

167168
if issparse(centers):
168169
centers = centers.toarray()
169170

170171
def _validate_center_shape(X, k, centers):
171172
"""Check if centers is compatible with X and n_clusters"""
172173
if centers.shape[0] != k:
173-
raise ValueError(
174+
raise ValueError( # pragma: no cover
174175
f"The shape of the initial centers {centers.shape} does not "
175176
f"match the number of clusters {k}.")
176177
if centers.shape[1] != X.shape[1]:
177-
raise ValueError(
178+
raise ValueError( # pragma: no cover
178179
f"The shape of the initial centers {centers.shape} does not "
179180
f"match the number of features of the data {X.shape[1]}.")
180181

@@ -598,7 +599,7 @@ def _fit_l1(self, X, y=None, sample_weight=None):
598599
X, init)
599600

600601
if n_init != 1:
601-
warnings.warn(
602+
warnings.warn( # pragma: no cover
602603
'Explicit initial center position passed: '
603604
'performing only one init in k-means instead of n_init=%d'
604605
% n_init, RuntimeWarning, stacklevel=2)

mlinsights/mlmodel/sklearn_testing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _get_test_instance():
9393
try:
9494
from pyquickhelper.pycode import ExtTestCase # pylint: disable=C0415
9595
cls = ExtTestCase
96-
except ImportError:
96+
except ImportError: # pragma: no cover
9797

9898
class _ExtTestCase(TestCase):
9999
"simple test classe with a more methods"
@@ -185,7 +185,7 @@ def _assert_dict_equal(a, b, ext):
185185
if key not in b:
186186
rows.append("** Removed key '{0}' in a".format(key))
187187
if len(rows) > 0:
188-
raise AssertionError(
188+
raise AssertionError( # pragma: no cover
189189
"Dictionaries are different\n{0}".format('\n'.join(rows)))
190190

191191

@@ -290,7 +290,7 @@ def adjust(obj1, obj2):
290290
if hasattr(obj2, k):
291291
v1 = getattr(obj1, k)
292292
if callable(v1):
293-
raise RuntimeError(
293+
raise RuntimeError( # pragma: no cover
294294
"Cannot migrate trained parameters for {}.".format(obj1))
295295
elif isinstance(v1, BaseEstimator):
296296
v1 = getattr(obj1, k)
@@ -302,7 +302,7 @@ def adjust(obj1, obj2):
302302
v1 = getattr(obj1, k)
303303
setattr(obj2, k, clone_with_fitted_parameters(v1))
304304
else:
305-
raise RuntimeError(
305+
raise RuntimeError( # pragma: no cover
306306
"Cloned object is missing '{0}' in {1}.".format(k, obj2))
307307

308308
if isinstance(est, BaseEstimator):

0 commit comments

Comments
 (0)