@@ -12,65 +12,8 @@ defmodule Nx.LinAlg.BlockEigh do
12
12
13
13
import Nx.Defn
14
14
15
- defn calc_rot ( tl , tr , br ) do
16
- complex? = tl |> Nx . type ( ) |> Nx.Type . complex? ( )
17
- br = Nx . take_diagonal ( br ) |> Nx . real ( )
18
- tr = Nx . take_diagonal ( tr )
19
- tl = Nx . take_diagonal ( tl ) |> Nx . real ( )
20
-
21
- { tr , w } =
22
- if complex? do
23
- abs_tr = Nx . abs ( tr )
24
- { abs_tr , Nx . select ( abs_tr == 0 , 1 , Nx . conjugate ( tr ) / abs_tr ) }
25
- else
26
- { tr , 1 }
27
- end
28
-
29
- z_tr = Nx . equal ( tr , 0 )
30
- s_tr = Nx . select ( z_tr , 1 , tr )
31
- tau = Nx . select ( z_tr , 0 , ( br - tl ) / ( 2 * s_tr ) )
32
-
33
- t = Nx . sqrt ( 1 + tau ** 2 )
34
-
35
- t = 1 / ( tau + Nx . select ( tau >= 0 , t , - t ) )
36
-
37
- pred = Nx . abs ( tr ) <= 1.0e-5 * Nx . min ( Nx . abs ( br ) , Nx . abs ( tl ) )
38
- t = Nx . select ( pred , Nx . tensor ( 0 , type: tl . type ) , t )
39
-
40
- c = 1.0 / Nx . sqrt ( 1.0 + t ** 2 )
41
- s = if complex? , do: Nx . complex ( t * c , 0 ) * w , else: t * c
42
-
43
- rt1 = tl - t * tr
44
- rt2 = br + t * tr
45
- { rt1 , rt2 , c , s }
46
- end
47
-
48
- defn sq_norm ( tl , tr , bl , br ) do
49
- Nx . sum ( Nx . abs ( tl ) ** 2 + Nx . abs ( tr ) ** 2 + Nx . abs ( bl ) ** 2 + Nx . abs ( br ) ** 2 )
50
- end
51
-
52
- defn off_norm ( tl , tr , bl , br ) do
53
- { n , _ } = Nx . shape ( tl )
54
- diag = Nx . broadcast ( 0 , { n } )
55
- o_tl = Nx . put_diagonal ( tl , diag )
56
- o_br = Nx . put_diagonal ( br , diag )
57
-
58
- sq_norm ( o_tl , tr , bl , o_br )
59
- end
60
-
61
- @ doc """
62
- Calculates the Frobenius norm and the norm of the off-diagonals from
63
- the submatrices. Used to calculate convergeance.
64
- """
65
- defn norms ( tl , tr , bl , br ) do
66
- frob = sq_norm ( tl , tr , bl , br )
67
- off = off_norm ( tl , tr , bl , br )
68
-
69
- { frob , off }
70
- end
71
-
72
15
defn eigh ( matrix , opts \\ [ ] ) do
73
- opts = keyword! ( opts , eps: 1.0e-6 , max_iter: 15 )
16
+ opts = keyword! ( opts , eps: 1.0e-6 , max_iter: 100 )
74
17
75
18
matrix
76
19
|> Nx . revectorize ( [ collapsed_axes: :auto ] ,
@@ -80,17 +23,6 @@ defmodule Nx.LinAlg.BlockEigh do
80
23
|> revectorize_result ( matrix )
81
24
end
82
25
83
- deftransformp revectorize_result ( { eigenvals , eigenvecs } , a ) do
84
- shape = Nx . shape ( a )
85
-
86
- {
87
- Nx . revectorize ( eigenvals , a . vectorized_axes ,
88
- target_shape: Tuple . delete_at ( shape , tuple_size ( shape ) - 1 )
89
- ) ,
90
- Nx . revectorize ( eigenvecs , a . vectorized_axes , target_shape: shape )
91
- }
92
- end
93
-
94
26
defnp decompose ( matrix , opts ) do
95
27
{ n , _ } = Nx . shape ( matrix )
96
28
@@ -105,31 +37,30 @@ defmodule Nx.LinAlg.BlockEigh do
105
37
eps = opts [ :eps ]
106
38
max_iter = opts [ :max_iter ]
107
39
108
- out_type = Nx.Type . to_floating ( Nx . type ( matrix ) )
109
- matrix = Nx . as_type ( matrix , out_type )
40
+ type = Nx.Type . to_floating ( Nx . type ( matrix ) )
41
+ matrix = Nx . as_type ( matrix , type )
110
42
{ n , _ } = Nx . shape ( matrix )
111
43
i_n = n - 1
112
- # TO-DO: use a deftransform to calculate this without slicing
113
- { mid , _ } = Nx . shape ( matrix [ [ 0 .. i_n // 2 , 0 .. i_n // 2 ] ] )
44
+ mid = calculate_mid ( i_n )
114
45
i_mid = mid - 1
115
46
116
- { tl , tr , bl , br } =
117
- { matrix [ [ 0 .. i_mid , 0 .. i_mid ] ] , matrix [ [ 0 .. i_mid , mid .. i_n ] ] , matrix [ [ mid .. i_n , 0 .. i_mid ] ] ,
118
- matrix [ [ mid .. i_n , mid .. i_n ] ] }
47
+ tl = matrix [ [ 0 .. i_mid , 0 .. i_mid ] ]
48
+ tr = matrix [ [ 0 .. i_mid , mid .. i_n ] ]
49
+ bl = matrix [ [ mid .. i_n , 0 .. i_mid ] ]
50
+ br = matrix [ [ mid .. i_n , mid .. i_n ] ]
119
51
120
52
# Pad if not even
121
- { tl , tr , bl , br } =
53
+ { tr , bl , br } =
122
54
if Nx . remainder ( n , 2 ) == 1 do
123
55
tr = Nx . pad ( tr , 0 , [ { 0 , 0 , 0 } , { 0 , 1 , 0 } ] )
124
56
bl = Nx . pad ( bl , 0 , [ { 0 , 1 , 0 } , { 0 , 0 , 0 } ] )
125
57
br = Nx . pad ( br , 0 , [ { 0 , 1 , 0 } , { 0 , 1 , 0 } ] )
126
- { tl , tr , bl , br }
58
+ { tr , bl , br }
127
59
else
128
- { tl , tr , bl , br }
60
+ { tr , bl , br }
129
61
end
130
62
131
63
# Initialze tensors to hold eigenvectors
132
- type = tl |> Nx . type ( ) |> Nx.Type . to_floating ( )
133
64
v_tl = v_br = Nx . eye ( mid , type: type )
134
65
v_tr = v_bl = Nx . broadcast ( Nx . tensor ( 0 , type: type ) , { mid , mid } )
135
66
@@ -145,7 +76,7 @@ defmodule Nx.LinAlg.BlockEigh do
145
76
# all sub matrices to share the needed values.
146
77
{ { tl , br , v_tl , v_tr , v_bl , v_br } , _ } =
147
78
while { { tl , br , v_tl , v_tr , v_bl , v_br } , { frob_norm , off_norm , tr , bl , i = 0 } } ,
148
- off_norm > Nx . pow ( eps , 2 ) * frob_norm and i < max_iter do
79
+ off_norm > eps ** 2 * frob_norm and i < max_iter do
149
80
{ tl , tr , bl , br , v_tl , v_tr , v_bl , v_br } =
150
81
perform_sweeps ( tl , tr , bl , br , v_tl , v_tr , v_bl , v_br , mid , i_n )
151
82
@@ -180,57 +111,126 @@ defmodule Nx.LinAlg.BlockEigh do
180
111
{ w , v }
181
112
end
182
113
114
+ deftransformp calculate_mid ( i_n ) do
115
+ Range . size ( 0 .. i_n // 2 )
116
+ end
117
+
118
+ defnp calc_rot ( tl , tr , br ) do
119
+ complex? = tl |> Nx . type ( ) |> Nx.Type . complex? ( )
120
+ br = Nx . take_diagonal ( br ) |> Nx . real ( )
121
+ tr = Nx . take_diagonal ( tr )
122
+ tl = Nx . take_diagonal ( tl ) |> Nx . real ( )
123
+
124
+ { tr , w } =
125
+ if complex? do
126
+ abs_tr = Nx . abs ( tr )
127
+ { abs_tr , Nx . select ( abs_tr == 0 , 1 , Nx . conjugate ( tr ) / abs_tr ) }
128
+ else
129
+ { tr , 1 }
130
+ end
131
+
132
+ z_tr = Nx . equal ( tr , 0 )
133
+ s_tr = Nx . select ( z_tr , 1 , tr )
134
+ tau = Nx . select ( z_tr , 0 , ( br - tl ) / ( 2 * s_tr ) )
135
+
136
+ t = Nx . sqrt ( 1 + tau ** 2 )
137
+
138
+ t = 1 / ( tau + Nx . select ( tau >= 0 , t , - t ) )
139
+
140
+ pred = Nx . abs ( tr ) <= 1.0e-5 * Nx . min ( Nx . abs ( br ) , Nx . abs ( tl ) )
141
+ t = Nx . select ( pred , Nx . tensor ( 0 , type: tl . type ) , t )
142
+
143
+ c = 1.0 / Nx . sqrt ( 1.0 + t ** 2 )
144
+ s = if complex? , do: Nx . complex ( t * c , 0 ) * w , else: t * c
145
+
146
+ rt1 = tl - t * tr
147
+ rt2 = br + t * tr
148
+ { rt1 , rt2 , c , s }
149
+ end
150
+
151
+ defnp sq_norm ( tl , tr , bl , br ) do
152
+ Nx . sum ( Nx . abs ( tl ) ** 2 + Nx . abs ( tr ) ** 2 + Nx . abs ( bl ) ** 2 + Nx . abs ( br ) ** 2 )
153
+ end
154
+
155
+ defnp off_norm ( tl , tr , bl , br ) do
156
+ { n , _ } = Nx . shape ( tl )
157
+ diag = Nx . broadcast ( 0 , { n } )
158
+ o_tl = Nx . put_diagonal ( tl , diag )
159
+ o_br = Nx . put_diagonal ( br , diag )
160
+
161
+ sq_norm ( o_tl , tr , bl , o_br )
162
+ end
163
+
164
+ # Calculates the Frobenius norm and the norm of the off-diagonals from
165
+ # the submatrices. Used to calculate convergeance.
166
+ defnp norms ( tl , tr , bl , br ) do
167
+ frob = sq_norm ( tl , tr , bl , br )
168
+ off = off_norm ( tl , tr , bl , br )
169
+
170
+ { frob , off }
171
+ end
172
+
173
+ deftransformp revectorize_result ( { eigenvals , eigenvecs } , a ) do
174
+ shape = Nx . shape ( a )
175
+
176
+ {
177
+ Nx . revectorize ( eigenvals , a . vectorized_axes ,
178
+ target_shape: Tuple . delete_at ( shape , tuple_size ( shape ) - 1 )
179
+ ) ,
180
+ Nx . revectorize ( eigenvecs , a . vectorized_axes , target_shape: shape )
181
+ }
182
+ end
183
+
183
184
defnp perform_sweeps ( tl , tr , bl , br , v_tl , v_tr , v_bl , v_br , mid , i_n ) do
184
185
while { tl , tr , bl , br , v_tl , v_tr , v_bl , v_br } , _n <- 0 .. i_n do
185
186
{ rt1 , rt2 , c , s } = calc_rot ( tl , tr , br )
186
187
# build row and column vectors for parrelelized rotations
187
- c_v = Nx . reshape ( c , { mid , 1 } )
188
- s_v = Nx . reshape ( s , { mid , 1 } )
189
- c_h = Nx . reshape ( c , { 1 , mid } )
190
- s_h = Nx . reshape ( s , { 1 , mid } )
188
+ c_v = Nx . new_axis ( c , 1 )
189
+ s_v = Nx . new_axis ( s , 1 )
190
+ c_h = Nx . new_axis ( c , 0 )
191
+ s_h = Nx . new_axis ( s , 0 )
191
192
192
- s_conj =
193
+ s_v_conj =
193
194
if Nx . type ( s ) |> Nx.Type . complex? ( ) do
194
195
Nx . conjugate ( s_v )
195
196
else
196
197
s_v
197
198
end
198
199
200
+ s_h_conj = Nx . transpose ( s_v_conj )
201
+
202
+ # Each rotation group below is performed based on the same
203
+ # tl, bl, tr, br values, so we must do single-expr
204
+ # assignments (i.e. {tl, tr, bl, br} = ...)
205
+
199
206
# Rotate rows
200
207
{ tl , tr , bl , br } = {
201
- tl * c_v - bl * s_conj ,
202
- tr * c_v - br * s_conj ,
208
+ tl * c_v - bl * s_v_conj ,
209
+ tr * c_v - br * s_v_conj ,
203
210
tl * s_v + bl * c_v ,
204
211
tr * s_v + br * c_v
205
212
}
206
213
207
- s_conj =
208
- if Nx . type ( s ) |> Nx.Type . complex? ( ) do
209
- Nx . conjugate ( s_h )
210
- else
211
- s_h
212
- end
213
-
214
214
# Rotate cols
215
215
{ tl , tr , bl , br } = {
216
216
tl * c_h - tr * s_h ,
217
- tl * s_conj + tr * c_h ,
217
+ tl * s_h_conj + tr * c_h ,
218
218
bl * c_h - br * s_h ,
219
- bl * s_conj + br * c_h
219
+ bl * s_h_conj + br * c_h
220
220
}
221
221
222
222
# Store results and permute values across sub matrices
223
+ zero_diag = Nx . broadcast ( 0 , { mid } )
223
224
tl = Nx . put_diagonal ( tl , rt1 )
224
- tr = Nx . put_diagonal ( tr , Nx . broadcast ( 0 , { mid } ) )
225
- bl = Nx . put_diagonal ( bl , Nx . broadcast ( 0 , { mid } ) )
225
+ tr = Nx . put_diagonal ( tr , zero_diag )
226
+ bl = Nx . put_diagonal ( bl , zero_diag )
226
227
br = Nx . put_diagonal ( br , rt2 )
227
228
228
229
{ tl , tr } = permute_cols_in_row ( tl , tr )
229
230
{ bl , br } = permute_cols_in_row ( bl , br )
230
231
{ tl , bl } = permute_rows_in_col ( tl , bl )
231
232
{ tr , br } = permute_rows_in_col ( tr , br )
232
233
233
- s_v_conj = if Nx . type ( s_v ) |> Nx.Type . complex? ( ) , do: Nx . conjugate ( s_v ) , else: s_v
234
234
# Rotate to calc vectors
235
235
{ v_tl , v_tr , v_bl , v_br } = {
236
236
v_tl * c_v - v_bl * s_v_conj ,
@@ -282,7 +282,7 @@ defmodule Nx.LinAlg.BlockEigh do
282
282
{ top_out , bottom_out }
283
283
end
284
284
285
- defn permute_cols_in_row ( left , right ) do
285
+ defnp permute_cols_in_row ( left , right ) do
286
286
{ k , _ } = Nx . shape ( left )
287
287
288
288
{ left_out , right_out } =
0 commit comments