Skip to content

Commit bd246f7

Browse files
committed
refactor: cleanup implementation and make test more strict
1 parent eade22d commit bd246f7

File tree

2 files changed

+105
-105
lines changed

2 files changed

+105
-105
lines changed

nx/lib/nx/lin_alg/block_eigh.ex

Lines changed: 101 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -12,65 +12,8 @@ defmodule Nx.LinAlg.BlockEigh do
1212

1313
import Nx.Defn
1414

15-
defn calc_rot(tl, tr, br) do
16-
complex? = tl |> Nx.type() |> Nx.Type.complex?()
17-
br = Nx.take_diagonal(br) |> Nx.real()
18-
tr = Nx.take_diagonal(tr)
19-
tl = Nx.take_diagonal(tl) |> Nx.real()
20-
21-
{tr, w} =
22-
if complex? do
23-
abs_tr = Nx.abs(tr)
24-
{abs_tr, Nx.select(abs_tr == 0, 1, Nx.conjugate(tr) / abs_tr)}
25-
else
26-
{tr, 1}
27-
end
28-
29-
z_tr = Nx.equal(tr, 0)
30-
s_tr = Nx.select(z_tr, 1, tr)
31-
tau = Nx.select(z_tr, 0, (br - tl) / (2 * s_tr))
32-
33-
t = Nx.sqrt(1 + tau ** 2)
34-
35-
t = 1 / (tau + Nx.select(tau >= 0, t, -t))
36-
37-
pred = Nx.abs(tr) <= 1.0e-5 * Nx.min(Nx.abs(br), Nx.abs(tl))
38-
t = Nx.select(pred, Nx.tensor(0, type: tl.type), t)
39-
40-
c = 1.0 / Nx.sqrt(1.0 + t ** 2)
41-
s = if complex?, do: Nx.complex(t * c, 0) * w, else: t * c
42-
43-
rt1 = tl - t * tr
44-
rt2 = br + t * tr
45-
{rt1, rt2, c, s}
46-
end
47-
48-
defn sq_norm(tl, tr, bl, br) do
49-
Nx.sum(Nx.abs(tl) ** 2 + Nx.abs(tr) ** 2 + Nx.abs(bl) ** 2 + Nx.abs(br) ** 2)
50-
end
51-
52-
defn off_norm(tl, tr, bl, br) do
53-
{n, _} = Nx.shape(tl)
54-
diag = Nx.broadcast(0, {n})
55-
o_tl = Nx.put_diagonal(tl, diag)
56-
o_br = Nx.put_diagonal(br, diag)
57-
58-
sq_norm(o_tl, tr, bl, o_br)
59-
end
60-
61-
@doc """
62-
Calculates the Frobenius norm and the norm of the off-diagonals from
63-
the submatrices. Used to calculate convergeance.
64-
"""
65-
defn norms(tl, tr, bl, br) do
66-
frob = sq_norm(tl, tr, bl, br)
67-
off = off_norm(tl, tr, bl, br)
68-
69-
{frob, off}
70-
end
71-
7215
defn eigh(matrix, opts \\ []) do
73-
opts = keyword!(opts, eps: 1.0e-6, max_iter: 15)
16+
opts = keyword!(opts, eps: 1.0e-6, max_iter: 100)
7417

7518
matrix
7619
|> Nx.revectorize([collapsed_axes: :auto],
@@ -80,17 +23,6 @@ defmodule Nx.LinAlg.BlockEigh do
8023
|> revectorize_result(matrix)
8124
end
8225

83-
deftransformp revectorize_result({eigenvals, eigenvecs}, a) do
84-
shape = Nx.shape(a)
85-
86-
{
87-
Nx.revectorize(eigenvals, a.vectorized_axes,
88-
target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1)
89-
),
90-
Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape)
91-
}
92-
end
93-
9426
defnp decompose(matrix, opts) do
9527
{n, _} = Nx.shape(matrix)
9628

@@ -105,31 +37,30 @@ defmodule Nx.LinAlg.BlockEigh do
10537
eps = opts[:eps]
10638
max_iter = opts[:max_iter]
10739

108-
out_type = Nx.Type.to_floating(Nx.type(matrix))
109-
matrix = Nx.as_type(matrix, out_type)
40+
type = Nx.Type.to_floating(Nx.type(matrix))
41+
matrix = Nx.as_type(matrix, type)
11042
{n, _} = Nx.shape(matrix)
11143
i_n = n - 1
112-
# TO-DO: use a deftransform to calculate this without slicing
113-
{mid, _} = Nx.shape(matrix[[0..i_n//2, 0..i_n//2]])
44+
mid = calculate_mid(i_n)
11445
i_mid = mid - 1
11546

116-
{tl, tr, bl, br} =
117-
{matrix[[0..i_mid, 0..i_mid]], matrix[[0..i_mid, mid..i_n]], matrix[[mid..i_n, 0..i_mid]],
118-
matrix[[mid..i_n, mid..i_n]]}
47+
tl = matrix[[0..i_mid, 0..i_mid]]
48+
tr = matrix[[0..i_mid, mid..i_n]]
49+
bl = matrix[[mid..i_n, 0..i_mid]]
50+
br = matrix[[mid..i_n, mid..i_n]]
11951

12052
# Pad if not even
121-
{tl, tr, bl, br} =
53+
{tr, bl, br} =
12254
if Nx.remainder(n, 2) == 1 do
12355
tr = Nx.pad(tr, 0, [{0, 0, 0}, {0, 1, 0}])
12456
bl = Nx.pad(bl, 0, [{0, 1, 0}, {0, 0, 0}])
12557
br = Nx.pad(br, 0, [{0, 1, 0}, {0, 1, 0}])
126-
{tl, tr, bl, br}
58+
{tr, bl, br}
12759
else
128-
{tl, tr, bl, br}
60+
{tr, bl, br}
12961
end
13062

13163
# Initialze tensors to hold eigenvectors
132-
type = tl |> Nx.type() |> Nx.Type.to_floating()
13364
v_tl = v_br = Nx.eye(mid, type: type)
13465
v_tr = v_bl = Nx.broadcast(Nx.tensor(0, type: type), {mid, mid})
13566

@@ -145,7 +76,7 @@ defmodule Nx.LinAlg.BlockEigh do
14576
# all sub matrices to share the needed values.
14677
{{tl, br, v_tl, v_tr, v_bl, v_br}, _} =
14778
while {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i = 0}},
148-
off_norm > Nx.pow(eps, 2) * frob_norm and i < max_iter do
79+
off_norm > eps ** 2 * frob_norm and i < max_iter do
14980
{tl, tr, bl, br, v_tl, v_tr, v_bl, v_br} =
15081
perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n)
15182

@@ -180,57 +111,126 @@ defmodule Nx.LinAlg.BlockEigh do
180111
{w, v}
181112
end
182113

114+
deftransformp calculate_mid(i_n) do
115+
Range.size(0..i_n//2)
116+
end
117+
118+
defnp calc_rot(tl, tr, br) do
119+
complex? = tl |> Nx.type() |> Nx.Type.complex?()
120+
br = Nx.take_diagonal(br) |> Nx.real()
121+
tr = Nx.take_diagonal(tr)
122+
tl = Nx.take_diagonal(tl) |> Nx.real()
123+
124+
{tr, w} =
125+
if complex? do
126+
abs_tr = Nx.abs(tr)
127+
{abs_tr, Nx.select(abs_tr == 0, 1, Nx.conjugate(tr) / abs_tr)}
128+
else
129+
{tr, 1}
130+
end
131+
132+
z_tr = Nx.equal(tr, 0)
133+
s_tr = Nx.select(z_tr, 1, tr)
134+
tau = Nx.select(z_tr, 0, (br - tl) / (2 * s_tr))
135+
136+
t = Nx.sqrt(1 + tau ** 2)
137+
138+
t = 1 / (tau + Nx.select(tau >= 0, t, -t))
139+
140+
pred = Nx.abs(tr) <= 1.0e-5 * Nx.min(Nx.abs(br), Nx.abs(tl))
141+
t = Nx.select(pred, Nx.tensor(0, type: tl.type), t)
142+
143+
c = 1.0 / Nx.sqrt(1.0 + t ** 2)
144+
s = if complex?, do: Nx.complex(t * c, 0) * w, else: t * c
145+
146+
rt1 = tl - t * tr
147+
rt2 = br + t * tr
148+
{rt1, rt2, c, s}
149+
end
150+
151+
defnp sq_norm(tl, tr, bl, br) do
152+
Nx.sum(Nx.abs(tl) ** 2 + Nx.abs(tr) ** 2 + Nx.abs(bl) ** 2 + Nx.abs(br) ** 2)
153+
end
154+
155+
defnp off_norm(tl, tr, bl, br) do
156+
{n, _} = Nx.shape(tl)
157+
diag = Nx.broadcast(0, {n})
158+
o_tl = Nx.put_diagonal(tl, diag)
159+
o_br = Nx.put_diagonal(br, diag)
160+
161+
sq_norm(o_tl, tr, bl, o_br)
162+
end
163+
164+
# Calculates the Frobenius norm and the norm of the off-diagonals from
165+
# the submatrices. Used to calculate convergeance.
166+
defnp norms(tl, tr, bl, br) do
167+
frob = sq_norm(tl, tr, bl, br)
168+
off = off_norm(tl, tr, bl, br)
169+
170+
{frob, off}
171+
end
172+
173+
deftransformp revectorize_result({eigenvals, eigenvecs}, a) do
174+
shape = Nx.shape(a)
175+
176+
{
177+
Nx.revectorize(eigenvals, a.vectorized_axes,
178+
target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1)
179+
),
180+
Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape)
181+
}
182+
end
183+
183184
defnp perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n) do
184185
while {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}, _n <- 0..i_n do
185186
{rt1, rt2, c, s} = calc_rot(tl, tr, br)
186187
# build row and column vectors for parrelelized rotations
187-
c_v = Nx.reshape(c, {mid, 1})
188-
s_v = Nx.reshape(s, {mid, 1})
189-
c_h = Nx.reshape(c, {1, mid})
190-
s_h = Nx.reshape(s, {1, mid})
188+
c_v = Nx.new_axis(c, 1)
189+
s_v = Nx.new_axis(s, 1)
190+
c_h = Nx.new_axis(c, 0)
191+
s_h = Nx.new_axis(s, 0)
191192

192-
s_conj =
193+
s_v_conj =
193194
if Nx.type(s) |> Nx.Type.complex?() do
194195
Nx.conjugate(s_v)
195196
else
196197
s_v
197198
end
198199

200+
s_h_conj = Nx.transpose(s_v_conj)
201+
202+
# Each rotation group below is performed based on the same
203+
# tl, bl, tr, br values, so we must do single-expr
204+
# assignments (i.e. {tl, tr, bl, br} = ...)
205+
199206
# Rotate rows
200207
{tl, tr, bl, br} = {
201-
tl * c_v - bl * s_conj,
202-
tr * c_v - br * s_conj,
208+
tl * c_v - bl * s_v_conj,
209+
tr * c_v - br * s_v_conj,
203210
tl * s_v + bl * c_v,
204211
tr * s_v + br * c_v
205212
}
206213

207-
s_conj =
208-
if Nx.type(s) |> Nx.Type.complex?() do
209-
Nx.conjugate(s_h)
210-
else
211-
s_h
212-
end
213-
214214
# Rotate cols
215215
{tl, tr, bl, br} = {
216216
tl * c_h - tr * s_h,
217-
tl * s_conj + tr * c_h,
217+
tl * s_h_conj + tr * c_h,
218218
bl * c_h - br * s_h,
219-
bl * s_conj + br * c_h
219+
bl * s_h_conj + br * c_h
220220
}
221221

222222
# Store results and permute values across sub matrices
223+
zero_diag = Nx.broadcast(0, {mid})
223224
tl = Nx.put_diagonal(tl, rt1)
224-
tr = Nx.put_diagonal(tr, Nx.broadcast(0, {mid}))
225-
bl = Nx.put_diagonal(bl, Nx.broadcast(0, {mid}))
225+
tr = Nx.put_diagonal(tr, zero_diag)
226+
bl = Nx.put_diagonal(bl, zero_diag)
226227
br = Nx.put_diagonal(br, rt2)
227228

228229
{tl, tr} = permute_cols_in_row(tl, tr)
229230
{bl, br} = permute_cols_in_row(bl, br)
230231
{tl, bl} = permute_rows_in_col(tl, bl)
231232
{tr, br} = permute_rows_in_col(tr, br)
232233

233-
s_v_conj = if Nx.type(s_v) |> Nx.Type.complex?(), do: Nx.conjugate(s_v), else: s_v
234234
# Rotate to calc vectors
235235
{v_tl, v_tr, v_bl, v_br} = {
236236
v_tl * c_v - v_bl * s_v_conj,
@@ -282,7 +282,7 @@ defmodule Nx.LinAlg.BlockEigh do
282282
{top_out, bottom_out}
283283
end
284284

285-
defn permute_cols_in_row(left, right) do
285+
defnp permute_cols_in_row(left, right) do
286286
{k, _} = Nx.shape(left)
287287

288288
{left_out, right_out} =

nx/test/nx/lin_alg_test.exs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ defmodule Nx.LinAlgTest do
647647
rand = :rand.uniform() * magnitude * 0.1 + magnitude
648648
rand * sign
649649
end)
650-
|> Nx.tensor(type: :f64)
650+
|> Nx.tensor(type: type)
651651

652652
evals_test_diag =
653653
evals_test
@@ -664,10 +664,10 @@ defmodule Nx.LinAlgTest do
664664
|> Nx.dot([2], [0], q, [1], [0])
665665

666666
# Eigenvalues and eigenvectors
667-
assert {evals, evecs} = Nx.LinAlg.eigh(a, max_iter: 100_000, eps: 1.0e-8)
667+
assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8)
668668

669669
assert_all_close(evals_test, evals[0], atol: 1.0e-1)
670-
# assert_all_close(evals_test, evals[1], atol: 1.0e-1)
670+
assert_all_close(evals_test, evals[1], atol: 1.0e-1)
671671

672672
evals =
673673
evals
@@ -679,7 +679,7 @@ defmodule Nx.LinAlgTest do
679679
evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0])
680680
a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0])
681681

682-
assert_all_close(a, a_evecs, atol: 1.0e-1)
682+
assert_all_close(a, a_evecs, atol: 1.0e-8)
683683
key
684684
end
685685
end

0 commit comments

Comments
 (0)