Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create blocked Jacobi method for eigen decomposition #1510

Merged
merged 18 commits into from
Jan 13, 2025

Conversation

christianjgreen
Copy link
Contributor

@christianjgreen christianjgreen commented Jun 12, 2024

#1027

The current implementation of eigh using a 10x10 symmetric matrix takes about 450ms for a 20x20 matrix and 154s for a 100x100 matrix, while the new implementation takes 0.5ms and 11ms respectively.

This is a defn version of the method used by XLA: https://github.com/openxla/xla/blob/main/xla/service/eigh_expander.cc
There is still a todo list and code cleanup/drying to do, but I wanted to pitch this before getting to far into the process. While this method has a static submatrix size with no recursion, this approach can be built on to recreate the recursive blocked-eigh used by JAX. This approach had less complexity and seemed like a nice way to make eigh performant without having to exactly copy the JAX method.

The gist of the method is to break the matrix into four submatrices and apply the jacobi rotations across all rows and cols each iteration and then joining the results.

Draft commit to introduce the idea.
Todo:

  • Inherit parent tensor type
  • Pass in eps as argument
  • Handle complex numbers
  • Reject malformed matrices
  • [?] Refactor to be less ugly and cleanup syntax

Current issues:

  • The values returned are not always normalized the same wave the current implementation does, so some tests fail

Please let me know if this is of any use! <3

Draft commit to introduce the idea.
Todo:
* Handle complex numbers
* Reject malformed matrices
@christianjgreen christianjgreen marked this pull request as ready for review June 12, 2024 02:58
@josevalim josevalim requested a review from polvalente June 12, 2024 06:02
Comment on lines 59 to 68
defn eigh(matrix) do
matrix
|> Nx.revectorize([collapsed_axes: :auto],
target_shape: {Nx.axis_size(matrix, -2), Nx.axis_size(matrix, -1)}
)
|> decompose()
|> then(fn {w, v} ->
revectorize_result({w, v}, matrix)
end)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this should be the only defn in this module and the others would be defnp. Or something close to that.

end

# Initialze tensors to hold eigenvectors
v_tl = Nx.eye(mid, type: :f32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why force f32 here? Is this another case where the algorithm just fails on f64?
Perhaps this should be masked underneath the implementation if it's the case.

#
# The inner loop performs "sweep" rounds of n - 1, which is enough permutations to allow
# all sub matrices to share the needed values.
{_, _, tl, _tr, _bl, br, v_tl, v_tr, v_bl, v_br, _} =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use a pattern for organizing the while state that we do quite a lot:

{{tl, br, v_tl, v_tr, v_bl, v_br}, _} where you leave the outputs in a first-position tuple, and the other state in a second position, so pattern matching on the statement is easier, as well as understanding what's output and what's not

Copy link
Contributor

@polvalente polvalente left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good. I'm also leaving a few stylistic suggestions for readability

nx/lib/eigh_block.ex Outdated Show resolved Hide resolved
nx/lib/eigh_block.ex Outdated Show resolved Hide resolved
t = Nx.sqrt(1 + Nx.pow(tau, 2))
t = Nx.select(Nx.greater_equal(tau, 0), 1 / (tau + t), 1 / (tau - t))

pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c)))
pred = Nx.abs(b) <= 1.0e-5 * Nx.min(Nx.abs(a), Nx.abs(c))

nx/lib/eigh_block.ex Outdated Show resolved Hide resolved
nx/lib/eigh_block.ex Outdated Show resolved Hide resolved
nx/lib/eigh_block.ex Outdated Show resolved Hide resolved
@polvalente
Copy link
Contributor

Merging this now! For reference, these are the benchmarks on my desktop, where Default Eigh is the previous implementation, and BlockEigh is the current one. For Host default eigh is the custom call that uses Eigen

Cuda

##### With input 5x5 #####
Name                          ips        average  deviation         median         99th %
BlockEigh (cuda)           3.60 K      277.62 μs     ±6.19%      275.34 μs      351.13 μs
Default Eigh (cuda)        3.20 K      312.15 μs     ±4.13%      308.74 μs      356.48 μs

Comparison: 
BlockEigh (cuda)           3.60 K
Default Eigh (cuda)        3.20 K - 1.12x slower +34.53 μs

##### With input 50x50 #####
Name                          ips        average  deviation         median         99th %
BlockEigh (cuda)           530.17        1.89 ms     ±4.22%        1.85 ms        2.16 ms
Default Eigh (cuda)        309.24        3.23 ms     ±3.10%        3.23 ms        3.50 ms

Comparison: 
BlockEigh (cuda)           530.17
Default Eigh (cuda)        309.24 - 1.71x slower +1.35 ms

##### With input 500x500 #####
Name                          ips        average  deviation         median         99th %
BlockEigh (cuda)            44.91    0.00037 min     ±0.60%    0.00037 min    0.00038 min
Default Eigh (cuda)        0.0762       0.22 min     ±0.00%       0.22 min       0.22 min

Comparison: 
BlockEigh (cuda)            44.91
Default Eigh (cuda)        0.0762 - 589.54x slower +0.22 min

Host

##### With input 5x5 #####
Name                          ips        average  deviation         median         99th %
Default Eigh (host)      117.29 K        8.53 μs   ±137.57%        7.50 μs       25.03 μs
BlockEigh (host)          71.34 K       14.02 μs    ±66.68%       12.74 μs       33.63 μs

Comparison: 
Default Eigh (host)      117.29 K
BlockEigh (host)          71.34 K - 1.64x slower +5.49 μs

##### With input 50x50 #####
Name                          ips        average  deviation         median         99th %
Default Eigh (host)       17.52 K       57.08 μs     ±9.75%       56.15 μs       78.27 μs
BlockEigh (host)           4.59 K      217.99 μs     ±1.92%      217.22 μs      228.60 μs

Comparison: 
Default Eigh (host)       17.52 K
BlockEigh (host)           4.59 K - 3.82x slower +160.91 μs

@polvalente polvalente merged commit 8dc7b29 into elixir-nx:main Jan 13, 2025
8 checks passed
@polvalente polvalente changed the title [Draft] Create blocked Jacobi method for eigen decomposition Create blocked Jacobi method for eigen decomposition Jan 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants