Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.11
0.0.12
2 changes: 1 addition & 1 deletion docs/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ coefficients, latent_information, curves = sample_curves(dataset["example"], mea
```
By specifying a value for `measurement_scale`, Gaussian noise with the specified scale is applied
to each value of every curve. By default, $5\%$ of the mean of the curves is used. If you want to
omit the scale, set it to `0.0` explictly.
omit the scale, set it to `0.0` explicitly.
18 changes: 16 additions & 2 deletions driftbench/benchmarks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@


class Dataset:
"""
Represents a container class for a dataset specification for benchmarking purposes.
"""

def __init__(self, name, spec, f=None, w0=None, n_variations=5):
"""
Args:
name (str): The name of the dataset specification.
spec (dict): The yaml-specification of the dataset.
f (Callable): The function to fit the curves.
w0 (np.ndarray): The initial value for the internal parameters.
n_variations (int): The number of variations sampled for each dataset.
Each dataset is sampled `n_variations` times, each time with a
different random seed.
"""
self.spec = spec
self.name = name
self.n_variations = n_variations
self.w0 = w0
self.f = f

drift_bounds = self.spec['drifts'].get_individual_drift_bounds()
self.Y = transform_drift_segments_into_binary(drift_bounds, self.spec['N'])
drift_bounds = self.spec["drifts"].get_individual_drift_bounds()
self.Y = transform_drift_segments_into_binary(drift_bounds, self.spec["N"])

def _generate(self, random_state):
_, _, curves = sample_curves(
Expand Down
50 changes: 34 additions & 16 deletions driftbench/data_generation/drifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from abc import ABCMeta, abstractmethod
from itertools import groupby, combinations


class Drift(metaclass=ABCMeta):
"""
Represents a drift for 1d or 2d input.
"""

def __init__(self, start, end, feature=None, dimension=0) -> None:
"""
Args:
Expand Down Expand Up @@ -41,6 +43,7 @@ class DriftSequence:
"""
Represents a sequence of drifts, which will be applied on a latent information object.
"""

def __init__(self, drifts):
"""
Args:
Expand All @@ -52,25 +55,30 @@ def __init__(self, drifts):
def apply(self, X):
"""
Applies the transformation by the given drifts on the latent information input.

Args:
X (list[LatentInformation]): The list of latent information the drifts are applied on.

Returns (list): A list of drifted latent information according to the drift sequence.
Returns:
(list[LatentInformation]): A list of drifted latent information according to the drift sequence.
"""
drifted = copy.deepcopy(X)
for drift in self.drifts:
feature = np.array([getattr(x, drift.feature) for x in drifted])
feature[:, drift.dimension] = drift.transform(feature[:, drift.dimension]).flatten()
feature[:, drift.dimension] = drift.transform(
feature[:, drift.dimension]
).flatten()
for i, x in enumerate(drifted):
setattr(x, drift.feature, feature[i])
return drifted

def get_aggregated_drift_bounds(self):
"""
Returns the aggregated drift bounds, i.e. the maximum range where drifts are applied.

Returns:
A tuple of (int, int), where the first value denotes the start index and the second value the
end index of the aggregated drift bounds.
(tuple[int, int]): A tuple of (int, int), where the first value denotes the start
index and the second value the end index of the aggregated drift bounds.
"""
start = self.drifts[0].start
end = self.drifts[-1].end
Expand All @@ -79,9 +87,10 @@ def get_aggregated_drift_bounds(self):
def get_individual_drift_bounds(self):
"""
Returns the drift bounds for each individual drift in the drift sequence.

Returns:
A list of tuples of (int, int), where the first value denotes the start of the drift,
and the second value the end of the drift.
(list[tuple[int, int]]): A list of tuples of (int, int), where the first value denotes
the start of the drift, and the second value the end of the drift.
"""
return [(drift.start, drift.end) for drift in self.drifts]

Expand All @@ -90,15 +99,17 @@ def get_drift_intensities(self):
Returns the intensities for each range in the drift sequence. Each drift has a base intensity of 1,
and when multiple drifts overlap, the intensity becomes the number of the drifts present in the given
range.

Returns:
A dictionary with tuples as keys and ints as values.
(dict[tuple[int, int], int]): A dictionary with tuples as keys and ints as values.
The keys indicate the range of the drift intensity, and the values indicate the intensity.
"""
intensities = {}
drift_intensities_array = np.zeros((len(self.drifts),
np.max([drift.end for drift in self.drifts]) + 1))
drift_intensities_array = np.zeros(
(len(self.drifts), np.max([drift.end for drift in self.drifts]) + 1)
)
for i, drift in enumerate(self.drifts):
drift_intensities_array[i, drift.start:drift.end + 1] = 1
drift_intensities_array[i, drift.start : drift.end + 1] = 1
stacked_drift_intensities = np.sum(drift_intensities_array, axis=0)

for intensity in range(1, np.max(stacked_drift_intensities).astype(int) + 1):
Expand All @@ -111,24 +122,31 @@ def get_drift_intensities(self):

def _validate_drifts(self, drifts):
# Group drifts by their feature and their dimension they apply on.
drifts_sorted = sorted(drifts, key=lambda drift: (drift.feature, drift.dimension))
drifts_grouped = groupby(drifts_sorted, key=lambda drift: (drift.feature, drift.dimension))
drifts_sorted = sorted(
drifts, key=lambda drift: (drift.feature, drift.dimension)
)
drifts_grouped = groupby(
drifts_sorted, key=lambda drift: (drift.feature, drift.dimension)
)
# Check within these groups if an overlap exists.
for (feature, dimension), curr_drifts in drifts_grouped:
curr_drifts = list(curr_drifts)
for i, j in combinations(range(len(curr_drifts)), 2):
drift1 = curr_drifts[i]
drift2 = curr_drifts[j]
if drift1.start <= drift2.end and drift2.start <= drift1.end:
raise ValueError(f"Drifts are not allowed to overlap. "
f"Overlapping drift at feature {feature} in dimension {dimension}")
raise ValueError(
f"Drifts are not allowed to overlap. "
f"Overlapping drift at feature {feature} in dimension {dimension}"
)


class LinearDrift(Drift):
"""
Represents a linear drift for a 1d or 2d-input, i.e. a drift
where the input data is drifted in a linear fashion.
"""

def __init__(self, start, end, m, feature=None, dimension=0):
"""
Args:
Expand All @@ -147,8 +165,8 @@ def transform(self, X):
drifted = drifted.reshape(-1, 1)
# Use 0 based x indices for computing the slope at a given position
xs = np.arange(self.end - self.start + 1).reshape(-1, 1)
drifted[self.start:self.end + 1, :] += self.m * xs
drifted[self.start : self.end + 1, :] += self.m * xs
# Maintain data according to new data after drift happened.
after_drift_idx = drifted.shape[0] - self.end
drifted[-after_drift_idx + 1:, :] += self.m * xs[-1]
drifted[-after_drift_idx + 1 :, :] += self.m * xs[-1]
return drifted
22 changes: 22 additions & 0 deletions driftbench/data_generation/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ def sample_curves(
measurement_scale=None,
callback=None,
):
"""
Samples synthetic curves given a dataset specification.

Args:
dataset_specification (dict): A dataset specification which contains
all information to synthesize curves in yaml-format.
Each dataset is encoded with a name and requires latent information to be provided.
The function `f` to fit and the initial guess `w0` can be provided as well.
f (Callable): The function to fit the curves. Use this parameter if no function is specified
in `dataset_specification`.
w0 (np.ndarray): The initial guess for the optimization problem used to synthesize curves.
Use this parameter if no initial guess is specified in `dataset_specification`.
random_state (int): The random state for reproducibility.
measurement_scale (float): The scale for the noise applied on the evaluated curves. If not
set, 5% of the mean of the curves is used. Set to 0.0 if you want to omit
this noise.

Returns:
(np.ndarray): The coefficients for each sampled curve.
(list[LatentInformation]): The latent information for each sampled curve.
(np.ndarray): The evaluated sampled curves.
"""
dimensions = dataset_specification["dimensions"]
drifts = dataset_specification.get("drifts")
x_scale = dataset_specification.get("x_scale", 0.02)
Expand Down
52 changes: 45 additions & 7 deletions driftbench/data_generation/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@

jax.config.update("jax_enable_x64", True)


class Solver(metaclass=ABCMeta):
"""
Represents a backend for solving an optimization problem.
"""

@abstractmethod
def solve(self, X):
"""
Solves an optimization problem defined by the solver.

Args:
X (list-like): Input to optimize according to solver instance.

Returns:

(np.ndarray|jnp.ndarray): The parameters obtained by solving the optimization problem.
"""
pass

Expand All @@ -33,10 +36,11 @@ class JaxCurveGenerationSolver(Solver):
"""
Fits latent information according to a given polynomial.
"""

def __init__(self, p, w0, max_fit_attemps, random_seed):
"""
Args:
p (func): The polynomial.
p (Callable): The polynomial.
w0 (list-like): The initial guess for the solution.
max_fit_attemps (int): The maximum number of attempts to refit a curve, if optimization didn't succeed.
random_seed (int): The random seed for the random number generator.
Expand All @@ -55,12 +59,23 @@ def solve(self, X, callback=None):
dp_dx2 = jit(vmap(partial(self.dp_dx2), in_axes=(None, 0)))
solution = self.w0
for i, latent in enumerate(X):
result = _minimize(p, dp_dx, dp_dx2, solution, latent.y0, latent.x0, latent.y1, latent.x1, latent.y2, latent.x2)
result = _minimize(
p,
dp_dx,
dp_dx2,
solution,
latent.y0,
latent.x0,
latent.y1,
latent.x1,
latent.y2,
latent.x2,
)
if not result.success:
result = self._refit(p, dp_dx, dp_dx2, latent)
solution = result.x
if callback:
jax.debug.callback(callback,i, solution)
jax.debug.callback(callback, i, solution)
coefficients.append(solution)
return jnp.array(coefficients)

Expand All @@ -74,7 +89,18 @@ def _refit(self, p, dp_dx, dp_dx2, latent):
# for the same problem as starting point until convergence.
while not success and current_fit_attempts < self.max_fit_attempts:
current_fit_attempts += 1
result = _minimize(p, dp_dx, dp_dx2, solution, latent.y0, latent.x0, latent.y1, latent.x1, latent.y2, latent.x2)
result = _minimize(
p,
dp_dx,
dp_dx2,
solution,
latent.y0,
latent.x0,
latent.y1,
latent.x1,
latent.y2,
latent.x2,
)
solution = result.x
success = result.success
return result
Expand All @@ -86,7 +112,17 @@ def _minimize(p, dp_dx, dp_dx2, w, y0, x0, y1, x1, y2, x2):
_solve,
w,
method="BFGS",
args=(p, dp_dx, dp_dx2, jnp.array(y0), jnp.array(x0), jnp.array(y1), jnp.array(x1), jnp.array(y2), jnp.array(x2))
args=(
p,
dp_dx,
dp_dx2,
jnp.array(y0),
jnp.array(x0),
jnp.array(y1),
jnp.array(x1),
jnp.array(y2),
jnp.array(x2),
),
)


Expand All @@ -100,4 +136,6 @@ def _solve(w, p, dp_dx, dp_dx2, y0, x0, y1, x1, y2, x2):

@jit
def _loss(y0, y1, y2, px, dp_px, dp_px2):
return ((px - y0) ** 2).sum() + ((dp_px - y1) ** 2).sum() + ((dp_px2 - y2) ** 2).sum()
return (
((px - y0) ** 2).sum() + ((dp_px - y1) ** 2).sum() + ((dp_px2 - y2) ** 2).sum()
)
27 changes: 21 additions & 6 deletions driftbench/data_generation/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import matplotlib.pyplot as plt


def plot_curve_with_latent_information(coefficients, p, latent_information, title=None, ax=None, y_lim=None):
def plot_curve_with_latent_information(
coefficients, p, latent_information, title=None, ax=None, y_lim=None
):
"""
Plots the reconstructed wave with the given coefficients and a polynomial with the ground truth
defined by the latent information.
Expand All @@ -13,7 +15,7 @@ def plot_curve_with_latent_information(coefficients, p, latent_information, titl
the ground truth for the polynomial and its coefficients
title (str): The title for the plot.
ax (matplotlib.axes): External axes, if this function is used with an externally created figure.
y_lim (tuple(int, int): The y-lim for the plot.
y_lim (tuple[float, float]): The y-lim for the plot.

Returns:

Expand All @@ -29,20 +31,21 @@ def plot_curve_with_latent_information(coefficients, p, latent_information, titl

# Plot the given x-values
for xx in latent_information.x0:
ax.axvline(xx, linestyle='dashed', color='black')
ax.axvline(xx, linestyle="dashed", color="black")

# Plot slope according to first derivative
for slope, x_slope in zip(latent_information.y1, latent_information.x1):
xxs = [x for x in range(int(x_slope - 1), int(x_slope + 3.))]
xxs = [x for x in range(int(x_slope - 1), int(x_slope + 3.0))]
dx_vals = np.array(
[(slope * x) - (slope * x_slope - p(coefficients, x_slope)) for x in xxs])
[(slope * x) - (slope * x_slope - p(coefficients, x_slope)) for x in xxs]
)
ax.scatter(x_slope, p(coefficients, x_slope), alpha=0.4, color="green")
ax.plot(xxs, dx_vals, c="green")

# Plot curvature
for x_curvature, curvature in zip(latent_information.x2, latent_information.y2):
label = "convex" if curvature > 0.0 else "concave"
ax.axvline(x_curvature, linestyle='dashed', color='purple', label=label)
ax.axvline(x_curvature, linestyle="dashed", color="purple", label=label)

# Mark the corresponding y-values
for yy, xx in zip(latent_information.y0, latent_information.x0):
Expand All @@ -56,6 +59,18 @@ def plot_curve_with_latent_information(coefficients, p, latent_information, titl


def plot_curves(curves, xs, title=None, cmap="coolwarm", ylim=None):
"""
Plots curves with a given cmap, where the color mapping is applied over the temporal axis.

Args:
curves (np.ndarray): The curves array, of shape (N, m), where each of the N curves consists of m
timesteps.
xs (list[float]): The x-values for the curve, must be of length m.
title (str): The title of the plot.
cmap (str): The colormap for the color mapping over the temporal axis.
ylim (tuple[float, float]): The y-limit for the plot.

"""
fig, ax = plt.subplots()
cmap_obj = plt.get_cmap(name=cmap)
cycler = plt.cycler("color", cmap_obj(np.linspace(0, 1, curves.shape[0])))
Expand Down