2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.13
0.0.14
8 changes: 5 additions & 3 deletions driftbench/data_generation/data_generator.py
@@ -1,10 +1,12 @@
from abc import ABCMeta, abstractmethod
from driftbench.data_generation.solvers import JaxCurveGenerationSolver


class DataGenerator(metaclass=ABCMeta):
"""
Represents a generator for high-dimensional data.
"""

@abstractmethod
def run(self, X):
"""
@@ -24,15 +26,15 @@ class CurveGenerator(DataGenerator):
Based on a polynomial and an initial guess, the generator computes coefficients,
which meet the constraints provided by the latent information.
"""
def __init__(self, p, w0, max_fit_attempts=100, random_seed=42):

def __init__(self, p, w0, max_fit_attempts=100):
"""
Args:
p (func): The polynomial to fit.
w0 (list-like): The initial guess.
max_fit_attempts (int): The maximum number of attempts to refit a curve if the optimization does not succeed.
random_seed (int): The random seed for the random number generator.
"""
self.solver = JaxCurveGenerationSolver(p, w0, max_fit_attempts, random_seed)
self.solver = JaxCurveGenerationSolver(p, w0, max_fit_attempts)

def run(self, X, callback=None):
return self.solver.solve(X, callback=callback)
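With this change, `CurveGenerator` no longer accepts a `random_seed`. A minimal usage sketch of the updated constructor; the cubic `p`, the zero initial guess, and the input `X` are illustrative assumptions, not part of this diff:

```python
import jax.numpy as jnp

from driftbench.data_generation.data_generator import CurveGenerator


# Illustrative function with the (coefficients, x) signature used in the tests.
def p(w, x):
    return w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]


# No random_seed argument anymore; only the function, the initial guess,
# and the maximum number of refit attempts.
generator = CurveGenerator(p, w0=jnp.zeros(4), max_fit_attempts=10)

# X would be an iterable of latent-information objects (not constructed here):
# coefficients = generator.run(X)
```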
27 changes: 12 additions & 15 deletions driftbench/data_generation/solvers.py
@@ -1,5 +1,4 @@
from abc import ABCMeta, abstractmethod
import numpy as np
import jax
import jax.numpy as jnp
from jax import (
@@ -34,35 +33,33 @@ def solve(self, X):

class JaxCurveGenerationSolver(Solver):
"""
Fits latent information according to a given polynomial.
Fits latent information according to a given function.
"""

def __init__(self, p, w0, max_fit_attemps, random_seed):
def __init__(self, f, w0, max_fit_attemps):
"""
Args:
p (Callable): The polynomial.
f (Callable): The function.
w0 (list-like): The initial guess for the solution.
max_fit_attemps (int): The maximum number of attempts to refit a curve if the optimization does not succeed.
random_seed (int): The random seed for the random number generator.
"""
self.p = p
self.dp_dx = grad(p, argnums=1)
self.dp_dx2 = grad(self.dp_dx, argnums=1)
self.f = jit(vmap(partial(f), in_axes=(None, 0)))
df_dx = grad(f, argnums=1)
df_dx2 = grad(df_dx, argnums=1)
self.df_dx = jit(vmap(partial(df_dx), in_axes=(None, 0)))
self.df_dx2 = jit(vmap(partial(df_dx2), in_axes=(None, 0)))
self.w0 = jnp.array(w0)
self.max_fit_attempts = max_fit_attemps
self.rng = np.random.RandomState(random_seed)

def solve(self, X, callback=None):
coefficients = []
p = jit(vmap(partial(self.p), in_axes=(None, 0)))
dp_dx = jit(vmap(partial(self.dp_dx), in_axes=(None, 0)))
dp_dx2 = jit(vmap(partial(self.dp_dx2), in_axes=(None, 0)))
solution = self.w0
for i, latent in enumerate(X):
result = _minimize(
p,
dp_dx,
dp_dx2,
self.f,
self.df_dx,
self.df_dx2,
solution,
latent.y0,
latent.x0,
@@ -72,7 +69,7 @@ def solve(self, X, callback=None):
latent.x2,
)
if not result.success:
result = self._refit(p, dp_dx, dp_dx2, latent)
result = self._refit(self.f, self.df_dx, self.df_dx2, latent)
solution = result.x
if callback:
jax.debug.callback(callback, i, solution)
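The main change in this file: the batched function and its first and second derivatives are now built once with `grad`, `vmap`, and `jit` in `__init__` instead of on every call to `solve`, and the unused NumPy RNG is dropped. A self-contained sketch of that pattern; the cubic `f` and the sample inputs are assumptions for illustration, not code from this repository:

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, jit, vmap


def f(w, x):
    # Example function with the assumed (coefficients, x) signature.
    return w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]


# Build the batched function and its derivatives once, mirroring what
# JaxCurveGenerationSolver.__init__ now does.
f_batched = jit(vmap(partial(f), in_axes=(None, 0)))
df_dx = grad(f, argnums=1)
df_dx2 = grad(df_dx, argnums=1)
df_dx_batched = jit(vmap(partial(df_dx), in_axes=(None, 0)))
df_dx2_batched = jit(vmap(partial(df_dx2), in_axes=(None, 0)))

w = jnp.array([1.0, 0.0, 0.0, 0.0])   # f(x) = x**3
xs = jnp.array([0.0, 1.0, 2.0])
print(f_batched(w, xs))      # values:           [0. 1. 8.]
print(df_dx_batched(w, xs))  # slopes 3x**2:     [0. 3. 12.]
print(df_dx2_batched(w, xs)) # curvatures 6x:    [0. 6. 12.]
```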
16 changes: 9 additions & 7 deletions driftbench/data_generation/visualize.py
@@ -3,7 +3,7 @@


def plot_curve_with_latent_information(
coefficients, p, latent_information, title=None, ax=None, y_lim=None
coefficients, f, latent_information, title=None, ax=None, y_lim=None
):
"""
Plots the reconstructed wave with the given coefficients and a polynomial with the ground truth
@@ -27,7 +27,7 @@ def plot_curve_with_latent_information(
if not ax:
fig, ax = plt.subplots()

ax.plot(x, p(coefficients, x))
ax.plot(x, f(coefficients, x))

# Plot the given x-values
for xx in latent_information.x0:
@@ -37,9 +37,9 @@ def plot_curve_with_latent_information(
for slope, x_slope in zip(latent_information.y1, latent_information.x1):
xxs = [x for x in range(int(x_slope - 1), int(x_slope + 3.0))]
dx_vals = np.array(
[(slope * x) - (slope * x_slope - p(coefficients, x_slope)) for x in xxs]
[(slope * x) - (slope * x_slope - f(coefficients, x_slope)) for x in xxs]
)
ax.scatter(x_slope, p(coefficients, x_slope), alpha=0.4, color="green")
ax.scatter(x_slope, f(coefficients, x_slope), alpha=0.4, color="green")
ax.plot(xxs, dx_vals, c="green")

# Plot curvature
@@ -58,20 +58,22 @@
ax.set_title(title)


def plot_curves(curves, xs, title=None, cmap="coolwarm", ylim=None):
def plot_curves(xs, curves, title=None, cmap="coolwarm", ax=None, ylim=None):
"""
Plots curves with a given cmap, where the color mapping is applied over the temporal axis.

Args:
xs(list[float]): The x-values for the curve, must be of length m.
curves(np.ndarray): The curves array, of shape (N, m), where N curves consist of m
timesteps.
xs(list[float]): The x-values for the curve, must be of length m.
title (str): The title of the plot.
cmap (str): The colormap for the color mapping over the temporal axis.
ax (matplotlib.axes.Axes): Existing axes to draw into when the figure is created externally.
ylim(tuple[float, float]): The y-limit for the plot.

"""
fig, ax = plt.subplots()
if ax is None:
fig, ax = plt.subplots()
cmap_obj = plt.get_cmap(name=cmap)
cycler = plt.cycler("color", cmap_obj(np.linspace(0, 1, curves.shape[0])))
ax.set_prop_cycle(cycler)
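`plot_curves` now takes `xs` before `curves` and accepts an optional `ax`, so it can draw into an externally created figure. A usage sketch with made-up toy data, not taken from this diff:

```python
import matplotlib.pyplot as plt
import numpy as np

from driftbench.data_generation.visualize import plot_curves

# Toy data: N = 5 curves sampled at m = 100 x-values.
xs = np.linspace(0.0, 4.0, 100)
curves = np.stack([np.sin(xs + 0.2 * i) for i in range(5)])

# New call order (xs first), drawing into an externally created axes.
fig, ax = plt.subplots()
plot_curves(xs, curves, title="Drifting curves", cmap="coolwarm", ax=ax)
plt.show()
```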
16 changes: 8 additions & 8 deletions tests/data_generation/test_solvers.py
@@ -9,18 +9,18 @@

class TestJaxCurveGenerationSolver(unittest.TestCase):
def setUp(self):
self.p = lambda w, x: w[0] * x ** 3 + w[1] * x ** 2 + w[2] * x + w[3]
x0 = np.array([0., 2., 4.])
y0 = np.array([0., 8., 64.])
x1 = np.array([1., 3.])
y1 = np.array([3., 27.])
x2 = np.array([2.])
y2 = np.array([12.])
self.p = lambda w, x: w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]
x0 = np.array([0.0, 2.0, 4.0])
y0 = np.array([0.0, 8.0, 64.0])
x1 = np.array([1.0, 3.0])
y1 = np.array([3.0, 27.0])
x2 = np.array([2.0])
y2 = np.array([12.0])
self.latent_information = LatentInformation(y0, x0, y1, x1, y2, x2)

def test_solve(self):
w0 = jnp.zeros(4)
solver = JaxCurveGenerationSolver(self.p, w0, max_fit_attemps=1, random_seed=10)
solver = JaxCurveGenerationSolver(self.p, w0, max_fit_attemps=10)
coefficients = solver.solve([self.latent_information])
expected = np.array([[1.0, 0.0, 0.0, 0.0]])
self.assertIs(type(coefficients), jaxlib.ArrayImpl)
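For context on the fixture above: the latent information encodes function values, slopes, and curvatures of f(x) = x³ (for example 3·3² = 27 and 6·2 = 12), which is why the expected coefficient vector is [1, 0, 0, 0]. A quick standalone check, not part of the test suite:

```python
import numpy as np

# The fixture is consistent with f(x) = x**3: check values, slopes, curvatures.
x0, y0 = np.array([0.0, 2.0, 4.0]), np.array([0.0, 8.0, 64.0])
x1, y1 = np.array([1.0, 3.0]), np.array([3.0, 27.0])
x2, y2 = np.array([2.0]), np.array([12.0])

assert np.allclose(x0**3, y0)      # function values x**3
assert np.allclose(3 * x1**2, y1)  # first derivative 3x**2
assert np.allclose(6 * x2, y2)     # second derivative 6x
```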