|
| 1 | +import jax |
| 2 | +import jax.numpy as np |
| 3 | + |
| 4 | +from optimism.JaxConfig import if_then_else |
| 5 | +from optimism.QuadratureRule import create_padded_quadrature_rule_1D |
| 6 | + |
@jax.custom_jvp
def sqrtm(A):
    """Return the matrix square root of ``A``.

    The value is the first element of the pair returned by ``sqrtm_dbp``;
    the iteration count it also reports is dropped. The function is wrapped
    in ``jax.custom_jvp`` so a hand-written derivative rule can be attached
    instead of differentiating through the iteration.
    """
    return sqrtm_dbp(A)[0]
| 11 | + |
| 12 | + |
@sqrtm.defjvp
def jvp_sqrtm(primals, tangents):
    """Custom JVP rule for ``sqrtm``.

    With ``S = sqrtm(A)`` and input tangent ``H``, the output tangent ``X``
    is the solution of the Sylvester equation ``S @ X + X @ S = H``. The
    equation is solved by assembling the dense Kronecker-product system
    acting on the column-major vectorization of ``X``.

    Returns:
        The pair ``(S, X)``.
    """
    (A,) = primals
    (H,) = tangents
    root = sqrtm(A)
    n = A.shape[0]
    eye = np.identity(n)
    # TODO(brandon): Use a stable algorithm for solving a Sylvester equation.
    # See https://en.wikipedia.org/wiki/Bartels%E2%80%93Stewart_algorithm
    # The following will only reliably work for small matrices.
    sylvester_op = np.kron(root.T, eye) + np.kron(eye, root)
    # Transposing before/after the row-major ravel/reshape yields the
    # column-major (vec) ordering that matches the Kronecker form above.
    rhs = H.T.ravel()
    solution = np.linalg.solve(sylvester_op, rhs)
    return root, solution.reshape((n, n)).T
| 26 | + |
| 27 | + |
def sqrtm_dbp(A):
    """Matrix square root by product form of Denman-Beavers iteration.

    Translated from the Matrix Function Toolbox
    http://www.ma.man.ac.uk/~higham/mftoolbox
    Nicholas J. Higham, Functions of Matrices: Theory and Computation,
    SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,

    The convergence tolerance, the initial error value and the identity
    matrix used inside the loop all follow the floating-point dtype of
    ``A``. This keeps the loop carry dtype-stable and the tolerance
    meaningful for single- as well as double-precision input. Input that is
    not of a floating/complex dtype is promoted to the default float dtype.

    Args:
        A: square matrix, indexed as ``A.shape[0]`` rows.

    Returns:
        Tuple ``(X, k)`` where ``X`` is the final iterate (the square-root
        approximation) and ``k`` is the number of iterations performed.
        Iteration stops once ``||M - I||_F`` is at or below the tolerance
        or after 32 iterations, whichever comes first.
    """
    A = np.asarray(A)
    if not np.issubdtype(A.dtype, np.inexact):
        A = A.astype(float)

    dim = A.shape[0]
    finfo = np.finfo(A.dtype)
    realDtype = finfo.dtype
    tol = 0.5 * dim**0.5 * float(finfo.eps)
    maxIters = 32
    scaleTol = 0.01
    # Built once, in the working dtype, so X and M never change dtype
    # between loop iterations.
    I = np.identity(dim, dtype=A.dtype)

    def scaling(M):
        # Determinantal scaling factor: |det(M)|^(-1/(2*dim)).
        d = np.abs(np.linalg.det(M))**(1.0/(2.0*dim))
        return 1.0 / d

    def cond_f(loopData):
        _, _, error, k, _ = loopData
        return np.logical_and(k < maxIters, error > tol)

    def body_f(loopData):
        X, M, error, k, diff = loopData
        # Scale only while successive iterates still differ noticeably.
        g = np.where(diff >= scaleTol,
                     scaling(M),
                     1.0)

        X *= g
        M *= g * g

        Y = X
        N = np.linalg.inv(M)
        X = 0.5 * X @ (I + N)
        M = 0.5 * (I + 0.5 * (M + N))
        error = np.linalg.norm(M - I, 'fro')
        diff = np.linalg.norm(X - Y, 'fro') / np.linalg.norm(X, 'fro')
        k += 1
        return (X, M, error, k, diff)

    X0 = A
    M0 = A
    error0 = np.asarray(finfo.max, dtype=realDtype)
    k0 = 0
    # want to force scaling on first iteration
    diff0 = np.asarray(2.0*scaleTol, dtype=realDtype)
    loopData0 = (X0, M0, error0, k0, diff0)

    X, _, _, k, _ = jax.lax.while_loop(cond_f, body_f, loopData0)

    return X, k
| 80 | + |
| 81 | + |
@jax.custom_jvp
def logm_iss(A):
    """Matrix logarithm of ``A`` by inverse scaling and squaring.

    ``_logm_iss`` returns a matrix ``X``, a count ``k`` and a degree ``m``.
    The result is ``2**k`` times ``log_pade_pf`` evaluated at ``X - I`` with
    degree ``m``. Wrapped in ``jax.custom_jvp`` so that a separate
    derivative rule can be registered.
    """
    X, k, m = _logm_iss(A)
    eye = np.identity(A.shape[0])
    log_of_root = log_pade_pf(X - eye, m)
    return (1 << k) * log_of_root
| 86 | + |
| 87 | + |
@logm_iss.defjvp
def logm_jvp(primals, tangents):
    """Custom JVP rule for ``logm_iss``.

    The forward-mode Jacobian of ``jax.scipy.linalg.expm`` is evaluated at
    ``logA = logm_iss(A)`` and flattened to a ``(dim*dim, dim*dim)`` matrix.
    The output tangent is the solution of that linear system with the
    flattened input tangent ``H`` as right-hand side.

    Returns:
        The pair ``(logA, tangent_of_logA)``.
    """
    (A,) = primals
    (H,) = tangents
    n = A.shape[0]
    logA = logm_iss(A)
    expm_jacobian = jax.jacfwd(jax.scipy.linalg.expm)(logA)
    flat_jacobian = expm_jacobian.reshape(n*n, -1)
    flat_tangent = np.linalg.solve(flat_jacobian, H.ravel())
    return logA, flat_tangent.reshape(n, n)
| 97 | + |
| 98 | + |
def _logm_iss(A):
    """Logarithmic map by inverse scaling and squaring and Padé approximants

    Translated from the Matrix Function Toolbox
    http://www.ma.man.ac.uk/~higham/mftoolbox
    Nicholas J. Higham, Functions of Matrices: Theory and Computation,
    SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,

    Repeatedly replaces the working matrix by its square root (via
    ``sqrtm_dbp``) until a Padé degree has been selected or 16 square roots
    have been taken.

    Returns:
        Tuple ``(X, k, m)``: the working matrix after ``k`` square roots and
        the selected degree ``m``. Note that the loop body always takes a
        square root, including in the pass in which the degree is selected.
    """
    n = A.shape[0]
    threshold = log_pade_coefficients[15]
    candidates = log_pade_coefficients[2:16]

    def keep_iterating(state):
        _, _, k, _, _, converged = state
        return np.logical_and(np.logical_not(converged), k < 16)

    def choose_degree(diff, j, itk):
        # Manually force the return type of searchsorted to be 64-bit int, because it
        # returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
        # like a bug. I filed an issue (#11375) with Jax to correct this.
        # If they fix it, the conversions on p and q can be removed.
        p = 2 + np.searchsorted(candidates, diff, side='right').astype(np.int64)
        q = 2 + np.searchsorted(candidates, diff/2.0, side='right').astype(np.int64)
        j = j + 1
        accept = (2 * (p - q) // 3 < itk) | (j == 2)
        # NOTE(review): if_then_else is a project helper; presumably it picks
        # the first tuple where `accept` holds and the second otherwise.
        m, j, converged = if_then_else(accept,
                                       (p+1, j, True), (0, j, False))
        return m, j, converged

    def take_square_root(state):
        X, j, k, m, itk, converged = state
        diff = np.linalg.norm(X - np.identity(n), ord=1)
        # Only consider selecting a degree once X is close enough to the identity.
        m, j, converged = if_then_else(diff < threshold,
                                       choose_degree(diff, j, itk),
                                       (m, j, converged))
        X, itk = sqrtm_dbp(X)
        return X, j, k + 1, m, itk, converged

    # State layout: (X, j, k, m, itk, converged)
    initial_state = (A, 0, 0, 0, 5, False)
    X, _, k, m, _, _ = jax.lax.while_loop(keep_iterating, take_square_root, initial_state)
    return X, k, m
| 147 | + |
| 148 | + |
def log_pade_pf(A, n):
    """Logarithmic map by Padé approximant and partial fractions

    Evaluates ``sum_j w_j * A @ inv(I + x_j * A)`` over the nodes ``x_j`` and
    weights ``w_j`` returned by ``create_padded_quadrature_rule_1D(2*n - 1)``.

    Args:
        A: square matrix; ``logm_iss`` passes ``X - I`` here.
        n: degree of the approximant; sets the requested quadrature
            precision ``2*n - 1``.

    Returns:
        Matrix of the same shape as ``A`` holding the weighted sum above.
    """
    I = np.identity(A.shape[0])
    # NOTE(review): presumably a Gauss-type rule on [0, 1] padded to a fixed
    # length so that `n` may be a traced value -- confirm in QuadratureRule.
    xs, ws = create_padded_quadrature_rule_1D(2*n - 1)

    def get_log_inc(x, w):
        # Solve with the transposed system so that the result is the
        # transpose of w * A @ inv(I + x*A); no explicit inverse is formed.
        B = I + x*A
        return w*np.linalg.solve(B.T, A.T)

    dXsTransposed = jax.vmap(get_log_inc)(xs, ws)
    return np.sum(dXsTransposed, axis=0).T
| 166 | + |
| 167 | + |
# Increasing table of 64 thresholds used by `_logm_iss` to pick the Padé
# degree: entry [15] is the cutoff below which degree selection starts, and
# entries [2:16] are searched with the 1-norm of (X - I).
# NOTE(review): presumably these are the `xvals` bounds from Higham's Matrix
# Function Toolbox (logm_iss_schur) -- confirm against that source.
log_pade_coefficients = np.array([
    1.100343044625278e-05, 1.818617533662554e-03, 1.620628479501567e-02, 5.387353263138127e-02,
    1.135280226762866e-01, 1.866286061354130e-01, 2.642960831111435e-01, 3.402172331985299e-01,
    4.108235000556820e-01, 4.745521256007768e-01, 5.310667521178455e-01, 5.806887133441684e-01,
    6.240414344012918e-01, 6.618482563071411e-01, 6.948266172489354e-01, 7.236382701437292e-01,
    7.488702930926310e-01, 7.710320825151814e-01, 7.905600074925671e-01, 8.078252198050853e-01,
    8.231422814010787e-01, 8.367774696147783e-01, 8.489562661576765e-01, 8.598698723737197e-01,
    8.696807597657327e-01, 8.785273397512191e-01, 8.865278635527148e-01, 8.937836659824918e-01,
    9.003818585631236e-01, 9.063975647545747e-01, 9.118957765024351e-01, 9.169328985287867e-01,
    9.215580354375991e-01, 9.258140669835052e-01, 9.297385486977516e-01, 9.333644683151422e-01,
    9.367208829050256e-01, 9.398334570841484e-01, 9.427249190039424e-01, 9.454154478075423e-01,
    9.479230038146050e-01, 9.502636107090112e-01, 9.524515973891873e-01, 9.544998058228285e-01,
    9.564197701703862e-01, 9.582218715590143e-01, 9.599154721638511e-01, 9.615090316568806e-01,
    9.630102085912245e-01, 9.644259488813590e-01, 9.657625632018019e-01, 9.670257948457799e-01,
    9.682208793510226e-01, 9.693525970039069e-01, 9.704253191689650e-01, 9.714430492527785e-01,
    9.724094589950460e-01, 9.733279206814576e-01, 9.742015357899175e-01, 9.750331605111618e-01,
    9.758254285248543e-01, 9.765807713611383e-01, 9.773014366339591e-01, 9.779895043950849e-01 ])
0 commit comments