Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Modules in PBjam
peakbagging
plotting
samplers

query



Expand Down
42 changes: 34 additions & 8 deletions pbjam/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,16 @@ class star(plotting):
Dictionary of additional keyword arguments for either modeID or peakbag.
"""

def __init__(self, name, f, s, obs, outpath=None, **kwargs):
def __init__(self, name, f, s, obs, outpath=None, mask=None, **kwargs):

# sanitize inputs

if mask is None:
mask = f > 0
f = f[mask]
s = s[mask]

self.__dict__.update((k, v) for k, v in locals().items() if k not in ['self'])
self.__dict__.update((k, v) for k, v in locals().items() if k not in ['self', 'mask'])

self.__dict__.update(kwargs)

Expand All @@ -326,12 +333,11 @@ def __init__(self, name, f, s, obs, outpath=None, **kwargs):
assert isinstance(val, Iterable), 'Entries in obs must be of the form (value, error)'
assert len(val) == 2, 'Entries in obs must be of the form (value, error)'

def runModeID(self, modeID_kwargs={}):
def makeModeID(self, **modeID_kwargs):
""" Run the mode identification process using the provided or default keyword arguments.

This method creates a `modeID` instance and executes it with the arguments provided in
`modeID_kwargs` or from the current object's attributes. If `priorpath` is not specified,
it fetches the path to the prior file.
This method creates a `modeID` instance ONLY. This function is for advanced usage only;
it is automatically called when running star.runModeID().

Parameters
----------
Expand All @@ -346,7 +352,6 @@ def runModeID(self, modeID_kwargs={}):
"""

_modeID_kwargs = copy.deepcopy(self.__dict__)

_modeID_kwargs.update(modeID_kwargs)

if not 'priorpath' in _modeID_kwargs:
Expand All @@ -355,8 +360,29 @@ def runModeID(self, modeID_kwargs={}):
_modeID_kwargs['priorpath'] = self.priorpath

self.modeID = modeID(**_modeID_kwargs)
self._modeID_kwargs = _modeID_kwargs

def runModeID(self, modeID_kwargs={}):
""" Run the mode identification process using the provided or default keyword arguments.

This method creates a `modeID` instance and executes it with the arguments provided in
`modeID_kwargs` or from the current object's attributes. If `priorpath` is not specified,
it fetches the path to the prior file.

self.modeID(**_modeID_kwargs)
Parameters
----------
modeID_kwargs : dict, optional
Dictionary of additional keyword arguments to update or override the current object's attributes
when initializing the `modeID` instance. Default is an empty dictionary.

Raises
------
KeyError
If required parameters for mode identification are missing.
"""
if not hasattr(self, "modeID"):
self.makeModeID(**modeID_kwargs)
self.modeID(**self._modeID_kwargs)

def runPeakbag(self, peakbag_kwargs={}):
""" Run the peakbagging process using the provided or default keyword arguments.
Expand Down
13 changes: 13 additions & 0 deletions pbjam/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def __init__(self, a=1, b=1, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Beta(a={self.a}, b={self.b}, x0={self.loc}, x1={self.loc + self.scale})'

def rv(self):
""" Draw random variable from distribution

Expand Down Expand Up @@ -411,6 +414,8 @@ def __init__(self, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Uniform(x1={self.a}, x2={self.b})'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -544,6 +549,8 @@ def __init__(self, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Normal(μ={self.loc}, σ={self.scale})'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -656,6 +663,9 @@ def __init__(self,):
"""

self._set_stdatt()

def __repr__(self):
return f'Distribution: TruncatedSine'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -773,6 +783,9 @@ def __init__(self, low, high):

self._set_stdatt()

def __repr__(self):
return f'Distribution: RandInt(min={self.low}, max={self.high})'

def rv(self):
""" Draw random variable from distribution

Expand Down
49 changes: 29 additions & 20 deletions pbjam/l1models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pbjam.distributions as dist
from dynesty import utils as dyfunc
jax.config.update('jax_enable_x64', True)
from functools import partial

class commonFuncs(jar.generalModelFuncs):
"""
Expand Down Expand Up @@ -622,17 +623,17 @@ def setPriors(self,):
self.DR.logpdf[i],
self.DR.cdf[i])

AddKeys = [k for k in self.variables if k in self.addPriors.keys()]

self.priors.update({key : self.addPriors[key] for key in AddKeys})

# Core rotation prior
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=2.)

self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.)

# The inclination prior is a sine truncated between 0, and pi/2.
self.priors['inc'] = dist.truncsine()
self.priors['inc'] = dist.truncsine()

# override priors
AddKeys = [k for k in self.variables if k in self.addPriors.keys()]
self.priors.update({key : self.addPriors[key] for key in AddKeys})


def model(self, thetaU,):
"""
Expand Down Expand Up @@ -1085,19 +1086,18 @@ def setPriors(self,):
self.DR.logpdf[i],
self.DR.cdf[i])

AddKeys = [k for k in self.variables if k in self.addPriors.keys()]

self.priors.update({key : self.addPriors[key] for key in AddKeys})

self.priors['q'] = dist.uniform(loc=0.01, scale=0.6)

# Core rotation prior
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=3.)

self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.)

# The inclination prior is a sine truncated between 0, and pi/2.
self.priors['inc'] = dist.truncsine()
self.priors['inc'] = dist.truncsine()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for changing the order of things here you'll need to check that it doesn't influence the sampling.

I remember having problems with that at some point, where I thought using dictionaries should have solved this, but apparently it didn't.

Just so we aren't sampling nurot_e when we think it's something else like eps_g.


# override priors
AddKeys = [k for k in self.variables if k in self.addPriors.keys()]
self.priors.update({key : self.addPriors[key] for key in AddKeys})


def unpackParams(self, theta):
""" Cast the parameters in a dictionary
Expand Down Expand Up @@ -1256,6 +1256,7 @@ def nearest(self, nu, nu_target):

return nu_target[jnp.argmin(jnp.abs(nu[:, None] - nu_target[None, :]), axis=1)]

@partial(jax.jit, static_argnums=(0,))
def Theta_p(self, nu, Dnu, nu_p):
"""
Compute the p-mode phase function Theta_p.
Expand All @@ -1279,6 +1280,7 @@ def Theta_p(self, nu, Dnu, nu_p):
(nu - self.nearest(nu, nu_p)) / Dnu + jnp.round((self.nearest(nu, nu_p) - nu_p[0]) / Dnu)
)

@partial(jax.jit, static_argnums=(0,))
def Theta_g(self, nu, DPi1, nu_g):
"""
Compute the g-mode phase function Theta_g.
Expand All @@ -1303,6 +1305,7 @@ def Theta_g(self, nu, DPi1, nu_g):
(1 / self.nearest(nu, nu_g) - 1 / nu) / DPi1
)

@partial(jax.jit, static_argnums=(0,))
def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g):
"""
Compute the local mixing fraction zeta.
Expand Down Expand Up @@ -1334,6 +1337,7 @@ def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g):

return 1 / (1 + DPi1 / Dnu * nu**2 / q * jnp.sin(Theta_g)**2 / jnp.cos(Theta_p)**2)

@partial(jax.jit, static_argnums=(0,))
def zeta_p(self, nu, q, DPi1, Dnu, nu_p):
"""
Compute the mixing fraction zeta using only the p-mode phase function. Agrees with zeta only at the
Expand Down Expand Up @@ -1361,6 +1365,7 @@ def zeta_p(self, nu, q, DPi1, Dnu, nu_p):

return 1 / (1 + DPi1 / Dnu * nu**2 / (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q))

@partial(jax.jit, static_argnums=(0,))
def zeta_g(self, nu, q, DPi1, Dnu, nu_g):

"""
Expand Down Expand Up @@ -1390,6 +1395,7 @@ def zeta_g(self, nu, q, DPi1, Dnu, nu_g):

return 1 / (1 + DPi1 / Dnu * nu**2 * (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q))

@partial(jax.jit, static_argnums=(0,))
def F(self, nu, nu_p, nu_g, Dnu, DPi1, q):
"""
Compute the characteristic function F such that F(nu) = 0 yields eigenvalues.
Expand Down Expand Up @@ -1417,6 +1423,7 @@ def F(self, nu, nu_p, nu_g, Dnu, DPi1, q):

return jnp.tan(self.Theta_p(nu, Dnu, nu_p)) * jnp.tan(self.Theta_g(nu, DPi1, nu_g)) - q

@partial(jax.jit, static_argnums=(0,))
def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0):
"""
Compute the first derivative dF/dnu of the characteristic function F.
Expand Down Expand Up @@ -1446,6 +1453,7 @@ def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0):
+ jnp.tan(self.Theta_p(nu, Dnu, nu_p)) / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2
- qp)

@partial(jax.jit, static_argnums=(0,))
def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0):
"""
Compute the second derivative d^2F / dnu^2of the characteristic function F.
Expand Down Expand Up @@ -1477,6 +1485,7 @@ def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0):
+ 2 / jnp.cos(self.Theta_p(nu, Dnu, nu_p))**2 * jnp.pi / Dnu / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2
- qpp)

@partial(jax.jit, static_argnums=(0,5))
def halley_iteration(self, x, y, yp, ypp, lmbda=1.):
"""
Perform Halley's method (2nd order Householder) iteration, with damping
Expand All @@ -1501,6 +1510,7 @@ def halley_iteration(self, x, y, yp, ypp, lmbda=1.):
"""
return x - lmbda * 2 * y * yp / (2 * yp * yp - y * ypp)

@partial(jax.jit, static_argnums=(0,6))
def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5):
"""
Solve the characteristic equation using Halley's method to couple
Expand Down Expand Up @@ -1528,21 +1538,20 @@ def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5):
num : array-like
Array of mixed mode frequencies.
"""

num_p = jnp.copy(nu_p)

num_g = jnp.copy(nu_g)

for _ in range(self.rootiter):
num_p = self.halley_iteration(num_p,
def _body(i, x0):
num_p, num_g = x0
a = self.halley_iteration(num_p,
self.F(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_p),
self.Fp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1),
self.Fpp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda)
num_g = self.halley_iteration(num_g,
b = self.halley_iteration(num_g,
self.F(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_g),
self.Fp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1),
self.Fpp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda)
return a, b

num_p, num_g = jax.lax.fori_loop(0, self.rootiter, _body, (nu_p, nu_g))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this improve the compile-time?

We of course also need to verify that results don't change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this significantly improves compile times because of no loop unrolling, and allows parameterisation by number of iterations without recompilation. this also is (as far as I can tell) the preferred pattern for for loops.

return jnp.append(num_p, num_g)

def parseSamples(self, smp, Nmax=5000):
Expand Down
Loading