Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.14
0.1.0
52 changes: 31 additions & 21 deletions driftbench/data_generation/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,39 @@ def __init__(self, f, w0, max_fit_attemps):
self.w0 = jnp.array(w0)
self.max_fit_attempts = max_fit_attemps

def _latents_to_array(self, latents):
l_x0_mat = jnp.array([latent.x0 for latent in latents])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something like would be a bit more code efficient (while harder to understand though)

return (
    jnp.array([attr(latent, "dim") for latent in latents]) for dim in ["x0", "x1",....]
)

l_x1_mat = jnp.array([latent.x1 for latent in latents])
l_x2_mat = jnp.array([latent.x2 for latent in latents])
l_y0_mat = jnp.array([latent.y0 for latent in latents])
l_y1_mat = jnp.array([latent.y1 for latent in latents])
l_y2_mat = jnp.array([latent.y2 for latent in latents])
return l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat

def solve(self, X, callback=None):
coefficients = []
solution = self.w0
for i, latent in enumerate(X):
result = _minimize(
self.f,
self.df_dx,
self.df_dx2,
solution,
latent.y0,
latent.x0,
latent.y1,
latent.x1,
latent.y2,
latent.x2,
)
if not result.success:
result = self._refit(self.f, self.df_dx, self.df_dx2, latent)
solution = result.x
if callback:
jax.debug.callback(callback, i, solution)
coefficients.append(solution)
return jnp.array(coefficients)
l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat = (
self._latents_to_array(X)
)
min_vmapped = jit(
vmap(
partial(_minimize), in_axes=(None, None, None, None, 0, 0, 0, 0, 0, 0)
),
static_argnums=(0, 1, 2),
)
result = min_vmapped(
self.f,
self.df_dx,
self.df_dx2,
solution,
l_y0_mat,
l_x0_mat,
l_y1_mat,
l_x1_mat,
l_y2_mat,
l_x2_mat,
)
return result.x

def _refit(self, p, dp_dx, dp_dx2, latent):
# Restart with initial guess in order to be independent of previous solutions.
Expand Down