|
| 1 | +import time |
| 2 | + |
| 3 | +import sparse |
| 4 | + |
| 5 | +import networkx as nx |
| 6 | +from networkx.algorithms.link_analysis.pagerank_alg import _pagerank_scipy |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import scipy.sparse as sp |
| 10 | + |
| 11 | + |
def pagerank(G, alpha=0.85, max_iter=100, tol=1e-6) -> dict:
    """Compute PageRank scores for the nodes of ``G`` by power iteration.

    The adjacency matrix of ``G`` is scaled so that every row with a non-zero
    sum is divided by that sum. The rank vector is then updated repeatedly
    until the L1 difference between two successive iterates drops below
    ``N * tol``, where ``N`` is ``len(G)``. The start vector, the
    personalization vector and the dangling-node weights are all uniform
    (``1 / N``).

    Parameters
    ----------
    G
        Graph to rank; it is passed to ``nx.to_scipy_sparse_array``.
    alpha
        Weight of the link-following term; ``1 - alpha`` weights the
        personalization vector.
    max_iter
        Maximum number of power-iteration steps.
    tol
        Per-node error tolerance; convergence is declared when the total L1
        error is below ``N * tol``.

    Returns
    -------
    dict
        Mapping from each node (in ``list(G)`` order) to its score as a
        ``float``. An empty dict is returned when ``G`` has no nodes.

    Raises
    ------
    nx.PowerIterationFailedConvergence
        If the tolerance is not reached within ``max_iter`` iterations.
    """
    N = len(G)
    if N == 0:
        return {}

    alpha = sparse.asarray(alpha)
    zero = sparse.asarray(0.0)
    nodelist = list(G)
    A = nx.to_scipy_sparse_array(G, dtype=float, format="csc")
    A = sparse.asarray(A)

    # Row sums, inverted wherever they are non-zero (zero rows stay zero).
    S = sparse.sum(A, axis=1)
    S = sparse.where(zero != S, sparse.asarray(1.0) / S, S)

    # TODO: spdiags https://github.com/willow-ahrens/Finch.jl/issues/499
    Q = sparse.asarray(sp.csc_array(sp.spdiags(S.todense(), 0, *A.shape)))
    A = Q @ A

    # initial vector
    x = sparse.full((1, N), fill_value=1.0 / N)

    # personalization vector
    p = sparse.full((1, N), fill_value=1.0 / N)

    # Dangling nodes
    dangling_weights = p

    # Loop invariants: neither S, alpha nor p changes during the iteration,
    # so the dangling-row mask and the personalization term are built once.
    is_dangling = S[None, :] == zero
    teleport = (sparse.asarray(1) - alpha) * p

    # power iteration: make up to max_iter iterations
    for _ in range(max_iter):
        xlast = x
        # Rank currently held by dangling nodes (rows whose sum is zero).
        x_dangling = sparse.where(is_dangling, x, zero)
        x = (
            alpha * (x @ A + sparse.asarray(sparse.sum(x_dangling)) * dangling_weights)
            + teleport
        )
        # check convergence, l1 norm
        err = sparse.sum(sparse.abs(x - xlast))
        if err < N * tol:
            return dict(zip(nodelist, map(float, x[0, :]), strict=False))

    raise nx.PowerIterationFailedConvergence(max_iter)
| 51 | + |
| 52 | + |
if __name__ == "__main__":
    # Benchmark script: time `pagerank` against networkx's `_pagerank_scipy`
    # on a small directed path graph and check that both give the same scores.
    G = nx.DiGraph(nx.path_graph(4))
    ITERS = 3

    # compile: one untimed call before the measured runs
    pagerank(G)
    print("compiled")

    # finch
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring durations; time.time() can jump with wall-clock adjustments.
    start = time.perf_counter()
    for i in range(ITERS):
        print(f"finch iter: {i}")
        pr = pagerank(G)
    elapsed = time.perf_counter() - start
    print(f"Finch took {elapsed / ITERS} s.")

    # scipy
    start = time.perf_counter()
    for i in range(ITERS):
        print(f"scipy iter: {i}")
        scipy_pr = _pagerank_scipy(G)
    elapsed = time.perf_counter() - start
    print(f"SciPy took {elapsed / ITERS} s.")

    # Both dicts are compared by value order, not by key.
    np.testing.assert_almost_equal(list(pr.values()), list(scipy_pr.values()))
    print(f"finch: {pr}")
    print(f"scipy: {scipy_pr}")
0 commit comments