@@ -34,9 +34,12 @@ def __init__(self, dim=None, k=7, metric='weighted'):
3434 '''
3535 if metric not in ('weighted' , 'orthonormalized' , 'plain' ):
3636 raise ValueError ('Invalid metric: %r' % metric )
37- self .dim = dim
38- self .metric = metric
39- self .k = k
37+
38+ self .params = {
39+ 'dim' : dim ,
40+ 'metric' : metric ,
41+ 'k' : k ,
42+ }
4043
4144 def transformer (self ):
4245 return self ._transformer
@@ -48,12 +51,12 @@ def _process_inputs(self, X, Y):
4851 unique_classes , Y = np .unique (Y , return_inverse = True )
4952 num_classes = len (unique_classes )
5053
51- if self .dim is None :
52- self .dim = d
53- elif not 0 < self .dim <= d :
54+ if self .params [ ' dim' ] is None :
55+ self .params [ ' dim' ] = d
56+ elif not 0 < self .params [ ' dim' ] <= d :
5457 raise ValueError ('Invalid embedding dimension, must be in [1,%d]' % d )
5558
56- if not 0 < self .k < d :
59+ if not 0 < self .params [ 'k' ] < d :
5760 raise ValueError ('Invalid k, must be in [0,%d]' % (d - 1 ))
5861
5962 return X , Y , num_classes , n , d
@@ -74,7 +77,7 @@ def fit(self, X, Y):
7477 # classwise affinity matrix
7578 dist = pairwise_distances (Xc , metric = 'l2' , squared = True )
7679 # distances to k-th nearest neighbor
77- k = min (self .k , nc - 1 )
80+ k = min (self .params [ 'k' ] , nc - 1 )
7881 sigma = np .sqrt (np .partition (dist , k , axis = 0 )[:,k ])
7982
8083 local_scale = np .outer (sigma , sigma )
@@ -94,21 +97,22 @@ def fit(self, X, Y):
9497 tSw += tSw .T
9598 tSw /= 2
9699
97- if self .dim == d :
100+ if self .params [ ' dim' ] == d :
98101 vals , vecs = scipy .linalg .eigh (tSb , tSw )
99102 else :
100- vals , vecs = scipy .sparse .linalg .eigsh (tSb , k = self .dim , M = tSw , which = 'LA' )
103+ vals , vecs = scipy .sparse .linalg .eigsh (tSb , k = self .params [ ' dim' ] , M = tSw , which = 'LA' )
101104
102- order = np .argsort (- vals )[:self .dim ]
105+ order = np .argsort (- vals )[:self .params [ ' dim' ] ]
103106 vals = vals [order ]
104107 vecs = vecs [:,order ]
105108
106- if self .metric == 'weighted' :
109+ if self .params [ ' metric' ] == 'weighted' :
107110 vecs *= np .sqrt (vals )
108- elif self .metric == 'orthonormalized' :
111+ elif self .params [ ' metric' ] == 'orthonormalized' :
109112 vecs , _ = np .linalg .qr (vecs )
110113
111114 self ._transformer = vecs .T
115+ return self
112116
113117
114118def _sum_outer (x ):
0 commit comments