diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index fbb773e8024f0..17f2704e0829b 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1113,7 +1113,7 @@ Model validation
    neighbors.RadiusNeighborsRegressor
    neighbors.NearestCentroid
    neighbors.NearestNeighbors
-   neighbors.NeighborhoodComponentAnalysis
+   neighbors.NeighborhoodComponentsAnalysis
 
 .. autosummary::
    :toctree: generated/
diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py
index 8e211ef9ec448..367928fad5b5a 100644
--- a/sklearn/neighbors/__init__.py
+++ b/sklearn/neighbors/__init__.py
@@ -14,7 +14,7 @@
 from .kde import KernelDensity
 from .approximate import LSHForest
 from .lof import LocalOutlierFactor
-from .nca import NeighborhoodComponentAnalysis
+from .nca import NeighborhoodComponentsAnalysis
 
 __all__ = ['BallTree',
            'DistanceMetric',
@@ -30,4 +30,4 @@
            'KernelDensity',
            'LSHForest',
            'LocalOutlierFactor',
-           'NeighborhoodComponentAnalysis']
+           'NeighborhoodComponentsAnalysis']
diff --git a/sklearn/neighbors/nca.py b/sklearn/neighbors/nca.py
index a3d35e4dd0a2d..4179faa9f6315 100644
--- a/sklearn/neighbors/nca.py
+++ b/sklearn/neighbors/nca.py
@@ -12,6 +12,7 @@
 import time
 from scipy.misc import logsumexp
 from scipy.optimize import minimize
+from ..preprocessing import OneHotEncoder
 
 from ..base import BaseEstimator, TransformerMixin
 from ..preprocessing import LabelEncoder
@@ -22,12 +23,12 @@
 from ..externals.six import integer_types
 
 
-class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
-    """Neighborhood Component Analysis
+class NeighborhoodComponentsAnalysis(BaseEstimator, TransformerMixin):
+    """Neighborhood Components Analysis
 
     Parameters
     ----------
-    n_features_out: int, optional (default=None)
+    n_features_out : int, optional (default=None)
         Preferred dimensionality of the embedding.
 
     init : string or numpy array, optional (default='pca')
@@ -87,10 +88,10 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
     Attributes
     ----------
     transformation_ : array, shape (n_features_out, n_features)
-        The linear transformation learned during fitting.
+        The linear transformation learned during fitting.
 
     n_iter_ : int
-        Counts the number of iterations performed by the optimizer.
+        Counts the number of iterations performed by the optimizer.
 
     opt_result_ : scipy.optimize.OptimizeResult (optional)
         A dictionary of information representing the optimization result.
@@ -98,16 +99,16 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
 
     Examples
     --------
-    >>> from sklearn.neighbors.nca import NeighborhoodComponentAnalysis
+    >>> from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis
     >>> from sklearn.neighbors import KNeighborsClassifier
     >>> from sklearn.datasets import load_iris
     >>> from sklearn.model_selection import train_test_split
     >>> X, y = load_iris(return_X_y=True)
     >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
     ... stratify=y, test_size=0.7, random_state=42)
-    >>> nca = NeighborhoodComponentAnalysis(None,random_state=42)
+    >>> nca = NeighborhoodComponentsAnalysis(random_state=42)
     >>> nca.fit(X_train, y_train) # doctest: +ELLIPSIS
-    NeighborhoodComponentAnalysis(...)
+    NeighborhoodComponentsAnalysis(...)
     >>> knn = KNeighborsClassifier(n_neighbors=3)
     >>> knn.fit(X_train, y_train) # doctest: +ELLIPSIS
     KNeighborsClassifier(...)
@@ -121,15 +122,9 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin): Notes ----- Neighborhood Component Analysis (NCA) is a machine learning algorithm for - metric learning. It learns a linear transformation of the space in a - supervised fashion to improve the classification accuracy of a - stochastic nearest neighbors rule in this new space. - - .. warning:: - - As NCA is optimizing a non-convex objective function, it will - likely end up in a local optimum. Several runs with independent random - init might be necessary to get a good convergence. + metric learning. It learns a linear transformation in a supervised fashion + to improve the classification accuracy of a stochastic nearest neighbors + rule in the transformed space. References ---------- @@ -137,9 +132,13 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin): "Neighbourhood Components Analysis". Advances in Neural Information Processing Systems. 17, 513-520, 2005. http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf + + .. [2] Wikipedia entry on Neighborhood Components Analysis + https://en.wikipedia.org/wiki/Neighbourhood_components_analysis + """ - def __init__(self, n_features_out=None, init='identity', max_iter=50, + def __init__(self, n_features_out=None, init='pca', max_iter=50, tol=1e-5, callback=None, store_opt_result=False, verbose=0, random_state=None): @@ -167,7 +166,7 @@ def fit(self, X, y): Returns ------- self : object - returns a trained NeighborhoodComponentAnalysis model. + returns a trained NeighborhoodComponentsAnalysis model. """ # Verify inputs X and y and NCA parameters, and transform a copy if @@ -182,7 +181,8 @@ def fit(self, X, y): # Compute arrays that stay fixed during optimization: # mask for fast lookup of same-class samples - masks = _make_masks(y_valid) + masks = OneHotEncoder(sparse=False, + dtype=bool).fit_transform(y_valid[:, np.newaxis]) # pairwise differences diffs = X_valid[:, np.newaxis] - X_valid[np.newaxis] @@ -193,7 +193,7 @@ def fit(self, X, y): disp = self.verbose - 2 if self.verbose > 1 else -1 optimizer_params = {'method': 'L-BFGS-B', 'fun': self._loss_grad_lbfgs, - 'args': (X_valid, y_valid, diffs, masks), + 'args': (X_valid, y_valid, diffs, masks, -1.0), 'jac': True, 'x0': transformation, 'tol': self.tol, @@ -401,7 +401,7 @@ def _callback(self, transformation): self.n_iter_ += 1 def _loss_grad_lbfgs(self, transformation, X, y, diffs, - masks): + masks, sign=1.0): """Compute the loss and the loss gradient w.r.t. ``transformation``. Parameters @@ -430,31 +430,58 @@ def _loss_grad_lbfgs(self, transformation, X, y, diffs, The new (flattened) gradient of the loss. 
""" + if self.n_iter_ == 0: + self.n_iter_ += 1 + if self.verbose: + header_fields = ['Iteration', 'Objective Value', 'Time(s)'] + header_fmt = '{:>10} {:>20} {:>10}' + header = header_fmt.format(*header_fields) + cls_name = self.__class__.__name__ + print('[{}]'.format(cls_name)) + print('[{}] {}\n[{}] {}'.format(cls_name, header, + cls_name, '-' * len(header))) + + t_funcall = time.time() + transformation = transformation.reshape(-1, X.shape[1]) loss = 0 gradient = np.zeros(transformation.shape) X_embedded = transformation.dot(X.T).T - # for every sample, compute its contribution to loss and gradient + # for every sample x_i, compute its contribution to loss and gradient for i in range(X.shape[0]): + # compute squared distances to x_i in embedded space diff_embedded = X_embedded[i] - X_embedded - sum_of_squares = np.einsum('ij,ij->i', diff_embedded, - diff_embedded) - sum_of_squares[i] = np.inf - soft = np.exp(-sum_of_squares - logsumexp(-sum_of_squares)) - ci = masks[:, y[i]] - p_i_j = soft[ci] - not_ci = np.logical_not(ci) - diff_ci = diffs[i, ci, :] # n_samples * n_features - diff_not_ci = diffs[i, not_ci, :] + dist_embedded = np.einsum('ij,ij->i', diff_embedded, + diff_embedded) + dist_embedded[i] = np.inf + + # compute exponentiated distances (use the log-sum-exp trick to + # avoid numerical instabilities + exp_dist_embedded = np.exp(-dist_embedded - + logsumexp(-dist_embedded)) + ci = masks[:, y[i]] # samples that are in the same class as x_i + p_i_j = exp_dist_embedded[ci] + diff_ci = diffs[i, ci, :] + diff_not_ci = diffs[i, ~ci, :] sum_ci = diff_ci.T.dot( (p_i_j[:, np.newaxis] * diff_embedded[ci, :])) - sum_not_ci = diff_not_ci.T.dot((soft[not_ci][:, np.newaxis] * - diff_embedded[not_ci, :])) - p_i = np.sum(p_i_j) + sum_not_ci = diff_not_ci.T.dot((exp_dist_embedded[~ci][:, + np.newaxis] * + diff_embedded[~ci, :])) + p_i = np.sum(p_i_j) # probability of x_i to be correctly + # classified gradient += 2 * (p_i * (sum_ci.T + sum_not_ci.T) - sum_ci.T) loss += p_i - return - loss, - gradient.ravel() + + if self.verbose: + t_funcall = time.time() - t_funcall + values_fmt = '[{}] {:>10} {:>20.6e} {:>10.2f}' + print(values_fmt.format(self.__class__.__name__, self.n_iter_, + loss, t_funcall)) + sys.stdout.flush() + + return sign * loss, sign * gradient.ravel() ########################## @@ -502,23 +529,3 @@ def _check_scalar(x, name, target_type, min_val=None, max_val=None): if max_val is not None and x > max_val: raise ValueError('`{}`= {}, must be <= {}.'.format(name, x, max_val)) - - -def _make_masks(y): - """Create one-hot encoding of vector ``y``. - - Parameters - ---------- - y : array, shape (n_samples,) - Data samples labels. - - Returns - ------- - masks: array, shape (n_samples, n_classes) - One-hot encoding of ``y``. 
- """ - - n = y.shape[0] - masks = np.zeros((n, y.max() + 1)) - masks[np.arange(n), y] = [1] - return masks.astype(bool) diff --git a/sklearn/neighbors/tests/test_nca.py b/sklearn/neighbors/tests/test_nca.py index fccc046e51892..7f3edaa9beee2 100644 --- a/sklearn/neighbors/tests/test_nca.py +++ b/sklearn/neighbors/tests/test_nca.py @@ -1,9 +1,10 @@ import numpy as np +from numpy.testing import assert_array_equal +from sklearn.preprocessing import OneHotEncoder from sklearn.utils import check_random_state from sklearn.utils.testing import assert_raises, assert_equal from sklearn.datasets import load_iris, make_classification -from sklearn.model_selection import train_test_split -from sklearn.neighbors.nca import NeighborhoodComponentAnalysis, _make_masks +from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis from sklearn.metrics import pairwise_distances @@ -16,6 +17,24 @@ EPS = np.finfo(float).eps +def test_simple_example(): + """Test on a simple example. + + Puts four points in the input space where the opposite labels points are + next to each other. After transform the same labels points should be next + to each other. + + """ + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + nca = NeighborhoodComponentsAnalysis(n_features_out=2, init='identity', + random_state=42) + nca.fit(X, y) + Xansformed = nca.transform(X) + np.testing.assert_equal(pairwise_distances(Xansformed).argsort()[:, 1], + np.array([2, 3, 0, 1])) + + def test_finite_differences(): r"""Test gradient of loss function @@ -47,19 +66,16 @@ def test_finite_differences(): """ # Initialize `transformation`, `X` and `y` and `NCA` - random_state = check_random_state(0) - n_features = 10 - num_dims = 2 - n_samples = 100 - n_labels = 3 - y = random_state.randint(0, n_labels, (n_samples)) - point = random_state.randn(num_dims, n_features) - X = random_state.randn(n_samples, n_features) - nca = NeighborhoodComponentAnalysis(None, init=point) + X = iris_data + y = iris_target + point = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1]) + nca = NeighborhoodComponentsAnalysis(init=point) X, y, init = nca._validate_params(X, y) - masks = _make_masks(y) + masks = OneHotEncoder(sparse=False, + dtype=bool).fit_transform(y[:, np.newaxis]) diffs = X[:, np.newaxis] - X[np.newaxis] + nca.n_iter_ = 0 point = nca._initialize(X, init) # compute the gradient at `point` @@ -67,7 +83,7 @@ def test_finite_differences(): masks) # create a random direction of norm 1 - random_direction = random_state.randn(*point.shape) + random_direction = rng.randn(*point.shape) random_direction /= np.linalg.norm(random_direction) # computes projected gradient @@ -88,36 +104,17 @@ def test_finite_differences(): np.testing.assert_almost_equal(relative_error, 0.) -def test_simple_example(): - """Test on a simple example. - - Puts four points in the input space where the opposite labels points are - next to each other. After transform the same labels points should be next - to each other. 
- - """ - X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) - y = np.array([1, 0, 1, 0]) - nca = NeighborhoodComponentAnalysis(n_features_out=2, init='identity', - random_state=42) - nca.fit(X, y) - X_transformed = nca.transform(X) - np.testing.assert_equal(pairwise_distances(X_transformed).argsort()[:, 1], - np.array([2, 3, 0, 1])) - - def test_params_validation(): # Test that invalid parameters raise value error X = np.arange(12).reshape(4, 3) y = [1, 1, 2, 2] - NCA = NeighborhoodComponentAnalysis + NCA = NeighborhoodComponentsAnalysis # TypeError assert_raises(TypeError, NCA(max_iter='21').fit, X, y) assert_raises(TypeError, NCA(verbose='true').fit, X, y) assert_raises(TypeError, NCA(tol=1).fit, X, y) - assert_raises(TypeError, NCA(n_features_out='invalid').fit, - X, y) + assert_raises(TypeError, NCA(n_features_out='invalid').fit, X, y) # ValueError assert_raises(ValueError, NCA(init=1).fit, X, y) @@ -135,7 +132,7 @@ def test_transformation_dimensions(): # Fail if transformation input dimension does not match inputs dimensions transformation = np.array([[1, 2], [3, 4]]) assert_raises(ValueError, - NeighborhoodComponentAnalysis(None, init=transformation).fit, + NeighborhoodComponentsAnalysis(init=transformation).fit, X, y) # Fail if transformation output dimension is larger than @@ -143,12 +140,12 @@ def test_transformation_dimensions(): transformation = np.array([[1, 2], [3, 4], [5, 6]]) # len(transformation) > len(transformation[0]) assert_raises(ValueError, - NeighborhoodComponentAnalysis(None, init=transformation).fit, + NeighborhoodComponentsAnalysis(init=transformation).fit, X, y) # Pass otherwise transformation = np.arange(9).reshape(3, 3) - NeighborhoodComponentAnalysis(None, init=transformation).fit(X, y) + NeighborhoodComponentsAnalysis(init=transformation).fit(X, y) def test_n_features_out(): @@ -158,96 +155,133 @@ def test_n_features_out(): transformation = np.array([[1, 2, 3], [4, 5, 6]]) # n_features_out = X.shape[1] != transformation.shape[0] - nca = NeighborhoodComponentAnalysis(n_features_out=3, init=transformation) + nca = NeighborhoodComponentsAnalysis(n_features_out=3, init=transformation) assert_raises(ValueError, nca.fit, X, y) # n_features_out > X.shape[1] - nca = NeighborhoodComponentAnalysis(n_features_out=5, init=transformation) + nca = NeighborhoodComponentsAnalysis(n_features_out=5, init=transformation) assert_raises(ValueError, nca.fit, X, y) # n_features_out < X.shape[1] - nca = NeighborhoodComponentAnalysis(n_features_out=2, init='identity') + nca = NeighborhoodComponentsAnalysis(n_features_out=2, init='identity') nca.fit(X, y) def test_init_transformation(): X, y = make_classification(n_samples=30, n_features=5, n_redundant=0, random_state=0) - X_train, X_test, y_train, y_test = train_test_split(X, y) # Start learning from scratch - nca = NeighborhoodComponentAnalysis(None, init='identity') - nca.fit(X_train, y_train) + nca = NeighborhoodComponentsAnalysis(init='identity') + nca.fit(X, y) # Initialize with random - nca_random = NeighborhoodComponentAnalysis(None, init='random') - nca_random.fit(X_train, y_train) + nca_random = NeighborhoodComponentsAnalysis(init='random') + nca_random.fit(X, y) # Initialize with PCA - nca_pca = NeighborhoodComponentAnalysis(None, init='pca') - nca_pca.fit(X_train, y_train) + nca_pca = NeighborhoodComponentsAnalysis(init='pca') + nca_pca.fit(X, y) init = np.random.rand(X.shape[1], X.shape[1]) - nca = NeighborhoodComponentAnalysis(None, init=init) - nca.fit(X_train, y_train) + nca = NeighborhoodComponentsAnalysis(init=init) + 
+    nca.fit(X, y)
 
     # init.shape[1] must match X.shape[1]
     init = np.random.rand(X.shape[1], X.shape[1] + 1)
-    nca = NeighborhoodComponentAnalysis(None, init=init)
-    assert_raises(ValueError, nca.fit, X_train, y_train)
+    nca = NeighborhoodComponentsAnalysis(init=init)
+    assert_raises(ValueError, nca.fit, X, y)
 
     # init.shape[0] must be <= init.shape[1]
    init = np.random.rand(X.shape[1] + 1, X.shape[1])
-    nca = NeighborhoodComponentAnalysis(None, init=init)
-    assert_raises(ValueError, nca.fit, X_train, y_train)
+    nca = NeighborhoodComponentsAnalysis(init=init)
+    assert_raises(ValueError, nca.fit, X, y)
 
     # init.shape[0] must match n_features_out
     init = np.random.rand(X.shape[1], X.shape[1])
-    nca = NeighborhoodComponentAnalysis(n_features_out=X.shape[1] - 2,
-                                        init=init)
-    assert_raises(ValueError, nca.fit, X_train, y_train)
+    nca = NeighborhoodComponentsAnalysis(n_features_out=X.shape[1] - 2,
+                                         init=init)
+    assert_raises(ValueError, nca.fit, X, y)
 
 
 def test_verbose():
-    nca = NeighborhoodComponentAnalysis(None, verbose=1)
+    nca = NeighborhoodComponentsAnalysis(verbose=1)
     nca.fit(iris_data, iris_target)
+    # TODO: rather assert that some message is printed
 
 
-def test_callable():
+def test_singleton_class():
     X = iris_data
-    y = iris_target
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
+    y = iris_target.copy()
 
-    nca = NeighborhoodComponentAnalysis(None, callback='my_cb')
-    assert_raises(ValueError, nca.fit, X_train, y_train)
+    # one singleton class
+    singleton_class = 1
+    ind_singleton, = np.where(y == singleton_class)
+    y[ind_singleton] = 2
+    y[ind_singleton[0]] = singleton_class
 
-    max_iter = 10
+    nca = NeighborhoodComponentsAnalysis(max_iter=30)
+    nca.fit(X, y)
+
+    # only one non-singleton class remains
+    ind_1, = np.where(y == 1)
+    ind_2, = np.where(y == 2)
+    y[ind_1] = 0
+    y[ind_1[0]] = 1
+    y[ind_2] = 0
+    y[ind_2[0]] = 2
+
+    nca = NeighborhoodComponentsAnalysis(max_iter=30)
+    nca.fit(X, y)
+
+    # only singleton classes
+    ind_0, = np.where(y == 0)
+    ind_1, = np.where(y == 1)
+    ind_2, = np.where(y == 2)
+    X = X[[ind_0[0], ind_1[0], ind_2[0]]]
+    y = y[[ind_0[0], ind_1[0], ind_2[0]]]
+
+    nca = NeighborhoodComponentsAnalysis(init='identity', max_iter=30)
+    nca.fit(X, y)
+    assert_array_equal(X, nca.transform(X))
 
-    def my_cb(transformation, n_iter):
-        rem_iter = max_iter - n_iter
-        print('{} iterations remaining...'.format(rem_iter))
 
-    nca = NeighborhoodComponentAnalysis(None, max_iter=max_iter,
-                                        callback=my_cb, verbose=1)
-    nca.fit(X_train, y_train)
+def test_one_class():
+    X = iris_data[iris_target == 0]
+    y = iris_target[iris_target == 0]
+
+    nca = NeighborhoodComponentsAnalysis(max_iter=30,
+                                         n_features_out=X.shape[1],
+                                         init='identity')
+    nca.fit(X, y)
+    assert_array_equal(X, nca.transform(X))
 
 
-def test_terminate_early():
+def test_callable():
     X = iris_data
     y = iris_target
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
-    nca = NeighborhoodComponentAnalysis(None, max_iter=5)
-    nca.fit(X_train, y_train)
+    nca = NeighborhoodComponentsAnalysis(callback='my_cb')
+    assert_raises(ValueError, nca.fit, X, y)
+
+    max_iter = 10
+
+    def my_cb(transformation, n_iter):
+        rem_iter = max_iter - n_iter
+        print('{} iterations remaining...'.format(rem_iter))
+
+    nca = NeighborhoodComponentsAnalysis(max_iter=max_iter,
+                                         callback=my_cb, verbose=1)
+    nca.fit(X, y)
+    # TODO: rather assert that message is printed
 
 
 def test_store_opt_result():
     X = iris_data
     y = iris_target
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
 
-    nca = NeighborhoodComponentAnalysis(None, max_iter=5,
-                                        store_opt_result=True)
-    nca.fit(X_train, y_train)
+    nca = NeighborhoodComponentsAnalysis(max_iter=5,
+                                         store_opt_result=True)
+    nca.fit(X, y)
 
     transformation = nca.opt_result_.x
     assert_equal(transformation.size, X.shape[1]**2)
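
Two hunks above replace the private `_make_masks` helper with an `OneHotEncoder` call. As a standalone sanity check (not part of the patch), the sketch below shows the two constructions agree once labels are encoded as consecutive integers, which `LabelEncoder` guarantees before the masks are built. It assumes the `OneHotEncoder` signature targeted by this branch; in scikit-learn 1.2+ the `sparse` keyword is named `sparse_output`.

import numpy as np
from sklearn.preprocessing import OneHotEncoder

y = np.array([0, 2, 1, 2, 0])  # already label-encoded: classes 0..2

# the removed helper, reproduced for comparison
n = y.shape[0]
masks = np.zeros((n, y.max() + 1))
masks[np.arange(n), y] = 1
masks = masks.astype(bool)

# the replacement used in fit() and in the tests
enc = OneHotEncoder(sparse=False, dtype=bool)
one_hot = enc.fit_transform(y[:, np.newaxis])

np.testing.assert_array_equal(masks, one_hot)  # identical boolean masks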
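`_loss_grad_lbfgs` computes the softmax probabilities `p_i_j` as `np.exp(-dist_embedded - logsumexp(-dist_embedded))` rather than normalizing the exponentials directly. A minimal standalone illustration of why, with invented distances: for large squared distances the naive ratio underflows to 0/0, while the log-sum-exp form stays well defined.

import numpy as np
from scipy.misc import logsumexp  # moved to scipy.special in later scipy

dist = np.array([1e3, 1e3 + 1.0, 1e3 + 2.0])  # large squared distances

naive = np.exp(-dist) / np.sum(np.exp(-dist))  # exp(-1000) underflows: 0/0
stable = np.exp(-dist - logsumexp(-dist))      # softmax computed in log space

print(naive)   # [nan nan nan]
print(stable)  # [0.66524096 0.24472847 0.09003057]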
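The new `sign` parameter of `_loss_grad_lbfgs`, together with the `-1.0` appended to `args` in the optimizer setup, is the usual trick for maximizing with `scipy.optimize.minimize`: NCA maximizes the expected number of correctly classified samples, so both the loss and its gradient are negated on the way out. A toy, self-contained sketch of the same pattern (the quadratic objective is invented for the example):

import numpy as np
from scipy.optimize import minimize


def f_and_grad(x, sign=1.0):
    # concave toy objective f(x) = -(x - 3)^2, maximized at x = 3
    f = -np.sum((x - 3.0) ** 2)
    grad = -2.0 * (x - 3.0)
    return sign * f, sign * grad


# sign=-1.0 turns the maximization of f into a minimization
res = minimize(f_and_grad, x0=np.zeros(1), args=(-1.0,), jac=True,
               method='L-BFGS-B')
print(res.x)  # ~[3.]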