Skip to content

Commit e920781

Browse files
authored
Merge pull request sandialabs#76 from btalamini/feature/tensor3x3ops
Some new and some improved tensor operations
2 parents de3f241 + 650ea04 commit e920781

14 files changed

+733
-676
lines changed

examples/adjoint_with_ivs/parameterized_j2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def compute_elastic_logarithmic_strain(dispGrad, state):
220220
Je = np.linalg.det(FeT) # = J since this model is isochoric plasticity
221221
traceEe = np.log(Je)
222222
CeIso = Je**(-2./3.)*FeT@FeT.T
223-
EeDev = TensorMath.mtk_log_sqrt(CeIso)
223+
EeDev = TensorMath.log_sqrt_symm(CeIso)
224224
return EeDev + traceEe/3.0*np.identity(3)
225225

226226

examples/adjoint_with_ivs/parameterized_linear_elastic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,5 @@ def log_strain(dispGrad):
6767
J = np.linalg.det(F)
6868
traceStrain = np.log(J)
6969
CIso = J**(-2.0/3.0)*F.T@F
70-
devStrain = TensorMath.mtk_log_sqrt(CIso)
70+
devStrain = TensorMath.log_sqrt_symm(CIso)
7171
return devStrain + traceStrain/3.0*np.identity(3)

optimism/J2PlasticPhaseField.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def compute_logarithmic_elastic_strain(dispGrad, state):
8484
Fp = state[PLASTIC_STRAIN].reshape((3,3))
8585
FeT = solve(Fp.T, F.T)
8686
Ce = FeT@FeT.T
87-
return TensorMath.mtk_log_sqrt(Ce)
87+
return TensorMath.log_sqrt_symm(Ce)
8888

8989

9090
def compute_state_increment(elasticTrialStrain, stateOld, props):

optimism/LinAlg.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import jax
2+
import jax.numpy as np
3+
4+
from optimism.JaxConfig import if_then_else
5+
from optimism.QuadratureRule import create_padded_quadrature_rule_1D
6+
7+
@jax.custom_jvp
8+
def sqrtm(A):
9+
sqrtA,_ = sqrtm_dbp(A)
10+
return sqrtA
11+
12+
13+
@sqrtm.defjvp
14+
def jvp_sqrtm(primals, tangents):
15+
A, = primals
16+
H, = tangents
17+
sqrtA = sqrtm(A)
18+
dim = A.shape[0]
19+
# TODO(brandon): Use a stable algorithm for solving a Sylvester equation.
20+
# See https://en.wikipedia.org/wiki/Bartels%E2%80%93Stewart_algorithm
21+
# The following will only reliably work for small matrices.
22+
I = np.identity(dim)
23+
M = np.kron(sqrtA.T, I) + np.kron(I, sqrtA)
24+
Hvec = H.T.ravel()
25+
return sqrtA, (np.linalg.solve(M, Hvec)).reshape((dim,dim)).T
26+
27+
28+
def sqrtm_dbp(A):
29+
""" Matrix square root by product form of Denman-Beavers iteration.
30+
31+
Translated from the Matrix Function Toolbox
32+
http://www.ma.man.ac.uk/~higham/mftoolbox
33+
Nicholas J. Higham, Functions of Matrices: Theory and Computation,
34+
SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,
35+
"""
36+
dim = A.shape[0]
37+
tol = 0.5 * np.sqrt(dim) * np.finfo(np.dtype("float64")).eps
38+
maxIters = 32
39+
scaleTol = 0.01
40+
41+
def scaling(M):
42+
d = np.abs(np.linalg.det(M))**(1.0/(2.0*dim))
43+
g = 1.0 / d
44+
return g
45+
46+
def cond_f(loopData):
47+
_,_,error,k,_ = loopData
48+
p = np.array([k < maxIters, error > tol], dtype=bool)
49+
return np.all(p)
50+
51+
def body_f(loopData):
52+
X, M, error, k, diff = loopData
53+
g = np.where(diff >= scaleTol,
54+
scaling(M),
55+
1.0)
56+
57+
X *= g
58+
M *= g * g
59+
60+
Y = X
61+
N = np.linalg.inv(M)
62+
I = np.identity(dim)
63+
X = 0.5 * X @ (I + N)
64+
M = 0.5 * (I + 0.5 * (M + N))
65+
error = np.linalg.norm(M - I, 'fro')
66+
diff = np.linalg.norm(X - Y, 'fro') / np.linalg.norm(X, 'fro')
67+
k += 1
68+
return (X, M, error, k, diff)
69+
70+
X0 = A
71+
M0 = A
72+
error0 = np.finfo(np.dtype("float64")).max
73+
k0 = 0
74+
diff0 = 2.0*scaleTol # want to force scaling on first iteration
75+
loopData0 = (X0, M0, error0, k0, diff0)
76+
77+
X,_,_,k,_ = jax.lax.while_loop(cond_f, body_f, loopData0)
78+
79+
return X,k
80+
81+
82+
@jax.custom_jvp
83+
def logm_iss(A):
84+
X,k,m = _logm_iss(A)
85+
return (1 << k) * log_pade_pf(X - np.identity(A.shape[0]), m)
86+
87+
88+
@logm_iss.defjvp
89+
def logm_jvp(primals, tangents):
90+
A, = primals
91+
H, = tangents
92+
logA = logm_iss(A)
93+
DexpLogA = jax.jacfwd(jax.scipy.linalg.expm)(logA)
94+
dim = A.shape[0]
95+
JVP = np.linalg.solve(DexpLogA.reshape(dim*dim,-1), H.ravel())
96+
return logA, JVP.reshape(dim,dim)
97+
98+
99+
def _logm_iss(A):
100+
"""Logarithmic map by inverse scaling and squaring and Padé approximants
101+
102+
Translated from the Matrix Function Toolbox
103+
http://www.ma.man.ac.uk/~higham/mftoolbox
104+
Nicholas J. Higham, Functions of Matrices: Theory and Computation,
105+
SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,
106+
"""
107+
dim = A.shape[0]
108+
c15 = log_pade_coefficients[15]
109+
110+
def cond_f(loopData):
111+
_,_,k,_,_,converged = loopData
112+
conditions = np.array([~converged, k < 16], dtype = bool)
113+
return conditions.all()
114+
115+
def compute_pade_degree(diff, j, itk):
116+
j += 1
117+
# Manually force the return type of searchsorted to be 64-bit int, because it
118+
# returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
119+
# like a bug. I filed an issue (#11375) with Jax to correct this.
120+
# If they fix it, the conversions on p and q can be removed.
121+
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64)
122+
p += 2
123+
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64)
124+
q += 2
125+
m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2),
126+
(p+1,j,True), (0,j,False))
127+
return m,j,converged
128+
129+
def body_f(loopData):
130+
X,j,k,m,itk,converged = loopData
131+
diff = np.linalg.norm(X - np.identity(dim), ord=1)
132+
m,j,converged = if_then_else(diff < c15,
133+
compute_pade_degree(diff, j, itk),
134+
(m, j, converged))
135+
X,itk = sqrtm_dbp(X)
136+
k += 1
137+
return X,j,k,m,itk,converged
138+
139+
X = A
140+
j = 0
141+
k = 0
142+
m = 0
143+
itk = 5
144+
converged = False
145+
X,j,k,m,itk,converged = jax.lax.while_loop(cond_f, body_f, (X,j,k,m,itk,converged))
146+
return X,k,m
147+
148+
149+
def log_pade_pf(A, n):
150+
"""Logarithmic map by Padé approximant and partial fractions
151+
"""
152+
I = np.identity(A.shape[0])
153+
X = np.zeros_like(A)
154+
quadPrec = 2*n - 1
155+
xs,ws = create_padded_quadrature_rule_1D(quadPrec)
156+
157+
def get_log_inc(A, x, w):
158+
B = I + x*A
159+
dXT = w*np.linalg.solve(B.T, A.T)
160+
return dXT
161+
162+
dXsTransposed = jax.vmap(get_log_inc, (None, 0, 0))(A, xs, ws)
163+
X = np.sum(dXsTransposed, axis=0).T
164+
165+
return X
166+
167+
168+
log_pade_coefficients = np.array([
169+
1.100343044625278e-05, 1.818617533662554e-03, 1.620628479501567e-02, 5.387353263138127e-02,
170+
1.135280226762866e-01, 1.866286061354130e-01, 2.642960831111435e-01, 3.402172331985299e-01,
171+
4.108235000556820e-01, 4.745521256007768e-01, 5.310667521178455e-01, 5.806887133441684e-01,
172+
6.240414344012918e-01, 6.618482563071411e-01, 6.948266172489354e-01, 7.236382701437292e-01,
173+
7.488702930926310e-01, 7.710320825151814e-01, 7.905600074925671e-01, 8.078252198050853e-01,
174+
8.231422814010787e-01, 8.367774696147783e-01, 8.489562661576765e-01, 8.598698723737197e-01,
175+
8.696807597657327e-01, 8.785273397512191e-01, 8.865278635527148e-01, 8.937836659824918e-01,
176+
9.003818585631236e-01, 9.063975647545747e-01, 9.118957765024351e-01, 9.169328985287867e-01,
177+
9.215580354375991e-01, 9.258140669835052e-01, 9.297385486977516e-01, 9.333644683151422e-01,
178+
9.367208829050256e-01, 9.398334570841484e-01, 9.427249190039424e-01, 9.454154478075423e-01,
179+
9.479230038146050e-01, 9.502636107090112e-01, 9.524515973891873e-01, 9.544998058228285e-01,
180+
9.564197701703862e-01, 9.582218715590143e-01, 9.599154721638511e-01, 9.615090316568806e-01,
181+
9.630102085912245e-01, 9.644259488813590e-01, 9.657625632018019e-01, 9.670257948457799e-01,
182+
9.682208793510226e-01, 9.693525970039069e-01, 9.704253191689650e-01, 9.714430492527785e-01,
183+
9.724094589950460e-01, 9.733279206814576e-01, 9.742015357899175e-01, 9.750331605111618e-01,
184+
9.758254285248543e-01, 9.765807713611383e-01, 9.773014366339591e-01, 9.779895043950849e-01 ])

0 commit comments

Comments
 (0)