-
Notifications
You must be signed in to change notification settings - Fork 7
several quality-of-life improvements #293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
cfdb982
76fbb1f
e3d7ce0
1de04c4
4306a42
b496e8e
d2118e9
968083d
94777bf
ea5550b
c3955f5
8d7a136
4d648d1
0f36ea6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ Modules in PBjam | |
| peakbagging | ||
| plotting | ||
| samplers | ||
|
|
||
| query | ||
|
|
||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import pbjam.distributions as dist | ||
| from dynesty import utils as dyfunc | ||
| jax.config.update('jax_enable_x64', True) | ||
| from functools import partial | ||
|
|
||
| class commonFuncs(jar.generalModelFuncs): | ||
| """ | ||
|
|
@@ -622,17 +623,17 @@ def setPriors(self,): | |
| self.DR.logpdf[i], | ||
| self.DR.cdf[i]) | ||
|
|
||
| AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
|
|
||
| self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
|
||
| # Core rotation prior | ||
| self.priors['nurot_c'] = dist.uniform(loc=-2., scale=2.) | ||
|
|
||
| self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) | ||
|
|
||
| # The inclination prior is a sine truncated between 0, and pi/2. | ||
| self.priors['inc'] = dist.truncsine() | ||
| self.priors['inc'] = dist.truncsine() | ||
|
|
||
| # override priors | ||
| AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
| self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
|
||
|
|
||
| def model(self, thetaU,): | ||
| """ | ||
|
|
@@ -1085,19 +1086,18 @@ def setPriors(self,): | |
| self.DR.logpdf[i], | ||
| self.DR.cdf[i]) | ||
|
|
||
| AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
|
|
||
| self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
|
||
| self.priors['q'] = dist.uniform(loc=0.01, scale=0.6) | ||
|
|
||
| # Core rotation prior | ||
| self.priors['nurot_c'] = dist.uniform(loc=-2., scale=3.) | ||
|
|
||
| self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) | ||
|
|
||
| # The inclination prior is a sine truncated between 0, and pi/2. | ||
| self.priors['inc'] = dist.truncsine() | ||
| self.priors['inc'] = dist.truncsine() | ||
|
|
||
| # override priors | ||
| AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
| self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
|
||
|
|
||
| def unpackParams(self, theta): | ||
| """ Cast the parameters in a dictionary | ||
|
|
@@ -1256,6 +1256,7 @@ def nearest(self, nu, nu_target): | |
|
|
||
| return nu_target[jnp.argmin(jnp.abs(nu[:, None] - nu_target[None, :]), axis=1)] | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def Theta_p(self, nu, Dnu, nu_p): | ||
| """ | ||
| Compute the p-mode phase function Theta_p. | ||
|
|
@@ -1279,6 +1280,7 @@ def Theta_p(self, nu, Dnu, nu_p): | |
| (nu - self.nearest(nu, nu_p)) / Dnu + jnp.round((self.nearest(nu, nu_p) - nu_p[0]) / Dnu) | ||
| ) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def Theta_g(self, nu, DPi1, nu_g): | ||
| """ | ||
| Compute the g-mode phase function Theta_g. | ||
|
|
@@ -1303,6 +1305,7 @@ def Theta_g(self, nu, DPi1, nu_g): | |
| (1 / self.nearest(nu, nu_g) - 1 / nu) / DPi1 | ||
| ) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): | ||
| """ | ||
| Compute the local mixing fraction zeta. | ||
|
|
@@ -1334,6 +1337,7 @@ def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): | |
|
|
||
| return 1 / (1 + DPi1 / Dnu * nu**2 / q * jnp.sin(Theta_g)**2 / jnp.cos(Theta_p)**2) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def zeta_p(self, nu, q, DPi1, Dnu, nu_p): | ||
| """ | ||
| Compute the mixing fraction zeta using only the p-mode phase function. Agrees with zeta only at the | ||
|
|
@@ -1361,6 +1365,7 @@ def zeta_p(self, nu, q, DPi1, Dnu, nu_p): | |
|
|
||
| return 1 / (1 + DPi1 / Dnu * nu**2 / (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def zeta_g(self, nu, q, DPi1, Dnu, nu_g): | ||
|
|
||
| """ | ||
|
|
@@ -1390,6 +1395,7 @@ def zeta_g(self, nu, q, DPi1, Dnu, nu_g): | |
|
|
||
| return 1 / (1 + DPi1 / Dnu * nu**2 * (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): | ||
| """ | ||
| Compute the characteristic function F such that F(nu) = 0 yields eigenvalues. | ||
|
|
@@ -1417,6 +1423,7 @@ def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): | |
|
|
||
| return jnp.tan(self.Theta_p(nu, Dnu, nu_p)) * jnp.tan(self.Theta_g(nu, DPi1, nu_g)) - q | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): | ||
| """ | ||
| Compute the first derivative dF/dnu of the characteristic function F. | ||
|
|
@@ -1446,6 +1453,7 @@ def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): | |
| + jnp.tan(self.Theta_p(nu, Dnu, nu_p)) / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 | ||
| - qp) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,)) | ||
| def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): | ||
| """ | ||
| Compute the second derivative d^2F / dnu^2of the characteristic function F. | ||
|
|
@@ -1477,6 +1485,7 @@ def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): | |
| + 2 / jnp.cos(self.Theta_p(nu, Dnu, nu_p))**2 * jnp.pi / Dnu / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 | ||
| - qpp) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,5)) | ||
| def halley_iteration(self, x, y, yp, ypp, lmbda=1.): | ||
| """ | ||
| Perform Halley's method (2nd order Householder) iteration, with damping | ||
|
|
@@ -1501,6 +1510,7 @@ def halley_iteration(self, x, y, yp, ypp, lmbda=1.): | |
| """ | ||
| return x - lmbda * 2 * y * yp / (2 * yp * yp - y * ypp) | ||
|
|
||
| @partial(jax.jit, static_argnums=(0,6)) | ||
| def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): | ||
| """ | ||
| Solve the characteristic equation using Halley's method to couple | ||
|
|
@@ -1528,21 +1538,20 @@ def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): | |
| num : array-like | ||
| Array of mixed mode frequencies. | ||
| """ | ||
|
|
||
| num_p = jnp.copy(nu_p) | ||
|
|
||
| num_g = jnp.copy(nu_g) | ||
|
|
||
| for _ in range(self.rootiter): | ||
| num_p = self.halley_iteration(num_p, | ||
| def _body(i, x0): | ||
| num_p, num_g = x0 | ||
| a = self.halley_iteration(num_p, | ||
| self.F(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_p), | ||
| self.Fp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), | ||
| self.Fpp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) | ||
| num_g = self.halley_iteration(num_g, | ||
| b = self.halley_iteration(num_g, | ||
| self.F(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_g), | ||
| self.Fp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), | ||
| self.Fpp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) | ||
| return a, b | ||
|
|
||
| num_p, num_g = jax.lax.fori_loop(0, self.rootiter, _body, (nu_p, nu_g)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this improve the compile-time? We of course also need to verify that results don't change.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this significantly improves compile times because of no loop unrolling, and allows parameterisation by number of iterations without recompilation. this also is (as far as I can tell) the preferred pattern for for loops. |
||
| return jnp.append(num_p, num_g) | ||
|
|
||
| def parseSamples(self, smp, Nmax=5000): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for changing the order of things here you'll need to check that it doesn't influence the sampling.
I remember having problems with that at some point, where I thought using dictionaries should have solved this, but apparently it didn't.
Just so we aren't sampling nurot_e when we think it's something else like eps_g.