3 changes: 2 additions & 1 deletion .gitignore
@@ -159,4 +159,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
.DS_Store
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.14
0.1.0
76 changes: 76 additions & 0 deletions docs/data.md
@@ -67,3 +67,79 @@ coefficients, latent_information, curves = sample_curves(dataset["example"], mea
By specifying a value for `measurement_scale`, Gaussian noise with that scale is applied
to each value of every curve. By default, $5\%$ of the mean of the curves is used. If you
want to omit the noise, set the scale to `0.0` explicitly.
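
For example, to disable the measurement noise entirely (a minimal sketch that reuses the `dataset` loaded above):

```python
from driftbench.data_generation.sample import sample_curves

# With measurement_scale=0.0 no noise is added to the evaluated curves.
coefficients, latent_information, curves = sample_curves(dataset["example"], measurement_scale=0.0)
```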

## Vectorized Curve Generation

The curve generation process can be significantly accelerated by using vectorized computation.
By default, `driftbench` uses JAX's `vmap` (vectorized map) to parallelize the optimization
across all latent information instances simultaneously, resulting in much faster curve generation
compared to sequential computation.

### Performance Comparison

| Mode | Description | Use Case |
|------|-------------|----------|
| **Vectorized** (default) | Optimizes all curves in parallel using `vmap` and `jit` | Large datasets, production use |
| **Sequential** | Optimizes curves one-by-one, using previous solution as starting point | Debugging, progress tracking, improved stability |

The vectorized mode is typically **orders of magnitude faster** for large datasets because:

1. JAX compiles the optimization function only once for all instances
2. Operations are batched and executed in parallel on the hardware (CPU/GPU)
3. Memory access patterns are optimized for vectorized operations
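
As an illustration, here is a minimal, self-contained sketch (not driftbench code; the loss function and array shapes are made up for the example) of how `vmap` and `jit` batch a per-curve computation so it is compiled once and evaluated for all curves in parallel:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared residual between a polynomial with coefficients `w` and the target curve `y`.
    return jnp.sum((jnp.polyval(w, x) - y) ** 2)

# Map over the leading axis of `w` and `y`; the evaluation grid `x` is shared.
batched_loss = jax.jit(jax.vmap(loss, in_axes=(0, None, 0)))

x = jnp.linspace(0.0, 1.0, 50)
w = jnp.ones((100, 4))      # 100 coefficient vectors, one per curve
y = jnp.zeros((100, 50))    # 100 target curves
print(batched_loss(w, x, y).shape)  # -> (100,)
```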

However, sequential mode can provide **more stable results** because it uses the solution
from the previous curve as the starting point for the next optimization. This warm-starting
approach can lead to smoother transitions across the execution dimension, especially when
curves are expected to have similar coefficients.
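
For intuition, sequential mode behaves like the following sketch, where `fit_single_curve` is a hypothetical stand-in for the internal solver call:

```python
# Sketch of warm-starting in sequential mode: each optimization starts
# from the previous curve's solution instead of the global initial guess.
solution = w0
coefficients = []
for latent in latent_information:
    solution = fit_single_curve(latent, initial_guess=solution)  # hypothetical helper
    coefficients.append(solution)
```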

### Using Vectorized Mode

Vectorized computation is enabled by default when using `sample_curves`:

```python
from driftbench.data_generation.sample import sample_curves

# Vectorized mode is used by default - fast computation
coefficients, latent_information, curves = sample_curves(dataset["example"])

# Explicitly enable vectorized mode
coefficients, latent_information, curves = sample_curves(dataset["example"], vectorize=True)
```

### Using Sequential Mode

To use sequential mode with better stability, set `vectorize=False` in `sample_curves`:

```python
from driftbench.data_generation.sample import sample_curves

# Sequential mode - slower but more stable results
coefficients, latent_information, curves = sample_curves(dataset["example"], vectorize=False)

# With a callback to track progress
def progress_callback(i, solution):
    print(f"Curve {i}: coefficients = {solution}")

coefficients, latent_information, curves = sample_curves(
    dataset["example"],
    vectorize=False,
    callback=progress_callback
)
```

!!! note
    The `callback` parameter is only supported in sequential mode (`vectorize=False`).
    When using vectorized mode, the callback is ignored since all curves are computed
    simultaneously.

### When to Use Each Mode

- **Vectorized mode** (default): Use this for production workloads and when generating
large numbers of curves where speed is the priority.

- **Sequential mode**: Use this when you need to:
    - Achieve more stable optimization results with smooth coefficient transitions
    - Debug the optimization process
    - Monitor progress with a callback function
    - Investigate individual curve fitting issues
7 changes: 5 additions & 2 deletions driftbench/data_generation/data_generator.py
@@ -27,14 +27,17 @@ class CurveGenerator(DataGenerator):
which meet the constraints provided by the latent information.
"""

def __init__(self, p, w0, max_fit_attempts=100):
def __init__(self, p, w0, max_fit_attempts=100, vectorize=True):
"""
Args:
p (func): The polynomial to fit.
w0 (list-like): The initial guess.
max_fit_attempts (int): The maximum number of attempts to refit a curve if optimization didn't succeed.
vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
"""
self.solver = JaxCurveGenerationSolver(p, w0, max_fit_attempts)
self.solver = JaxCurveGenerationSolver(
p, w0, max_fit_attempts, vectorize=vectorize
)

def run(self, X, callback=None):
return self.solver.solve(X, callback=callback)
10 changes: 6 additions & 4 deletions driftbench/data_generation/sample.py
@@ -10,21 +10,23 @@ def sample_curves(
w0=None,
random_state=2024,
measurement_scale=None,
vectorize=True,
callback=None,
):
"""
Samples synthetic curves given a dataset specification.

Args:
dataset_specification (dict): A dataset specification which contains
all information to syntethisize curves in yaml-format.
all information to synthesize curves in yaml-format.
Each dataset is encoded with a name and needs a latent information provided.
The function `f` to fit, as well as the initial guess `w0`, can also be provided.
f (Callable): The function to fit the curves. Use this parameter if no function is specified
in `dataset_specification`.
w0 (np.ndarray): The inital guess for the optimization problem used to synthesize curves.
w0 (np.ndarray): The initial guess for the optimization problem used to synthesize curves.
Use this parameter if no initial guess is specified in `dataset_specification`.
random_state (int): The random state for reproducablity.
random_state (int): The random state for reproducibility.
vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
measurement_scale (float): The scale of the noise applied to the evaluated curves. If not
set, 5% of the mean of the curves is used. Set to 0.0 if you want to omit
this noise.
@@ -46,7 +48,7 @@
)
if drifts is not None:
latent_information = drifts.apply(latent_information)
data_generator = CurveGenerator(func, w_init)
data_generator = CurveGenerator(func, w_init, vectorize=vectorize)
w = data_generator.run(latent_information, callback=callback)
x_range = np.concatenate(
(
84 changes: 65 additions & 19 deletions driftbench/data_generation/solvers.py
@@ -36,13 +36,13 @@ class JaxCurveGenerationSolver(Solver):
Fits latent information according to a given function.
"""

def __init__(self, f, w0, max_fit_attemps):
def __init__(self, f, w0, max_fit_attemps, vectorize=True):
"""
Args:
f (Callable): The function.
w0 (list-like): The initial guess for the solution.
max_fit_attemps (int): The maximum number of attempts to refit a curve if optimization didn't succeed.
random_seed (int): The random seed for the random number generator.
vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
"""
self.f = jit(vmap(partial(f), in_axes=(None, 0)))
df_dx = grad(f, argnums=1)
@@ -51,12 +51,33 @@ def __init__(self, f, w0, max_fit_attemps):
self.df_dx2 = jit(vmap(partial(df_dx2), in_axes=(None, 0)))
self.w0 = jnp.array(w0)
self.max_fit_attempts = max_fit_attemps
self.vectorize = vectorize
self.min_func = _minimize

# If vectorization is enabled, compile the batched minimizer up front so it is compiled only once.
if self.vectorize:
self.min_func = jit(
vmap(
partial(_minimize),
in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0),
),
static_argnums=(0, 1, 2),
)

def solve(self, X, callback=None):
def _latents_to_array(self, latents):
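# Stack each latent attribute across all instances so the batched solver can map over the leading axis.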
l_x0_mat = jnp.array([latent.x0 for latent in latents])

Collaborator review comment: something like this would be a bit more code-efficient (though harder to understand):

return (
    jnp.array([getattr(latent, dim) for latent in latents]) for dim in ["x0", "x1",....]
)

l_x1_mat = jnp.array([latent.x1 for latent in latents])
l_x2_mat = jnp.array([latent.x2 for latent in latents])
l_y0_mat = jnp.array([latent.y0 for latent in latents])
l_y1_mat = jnp.array([latent.y1 for latent in latents])
l_y2_mat = jnp.array([latent.y2 for latent in latents])
return l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat

def _solve_sequentially(self, X, callback=None):
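# Fit curves one at a time, warm-starting each optimization from the previous curve's solution.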
coefficients = []
solution = self.w0
for i, latent in enumerate(X):
result = _minimize(
result = self.min_func(
self.f,
self.df_dx,
self.df_dx2,
@@ -76,7 +97,32 @@ def solve(self, X, callback=None):
coefficients.append(solution)
return jnp.array(coefficients)

def _refit(self, p, dp_dx, dp_dx2, latent):
def _solve_vectorized(self, X):
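# Solve all curves in one batched call; every curve starts from the same initial guess w0.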
solution = jnp.tile(self.w0, (len(X), 1))
l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat = (
self._latents_to_array(X)
)
result = self.min_func(
self.f,
self.df_dx,
self.df_dx2,
solution,
l_y0_mat,
l_x0_mat,
l_y1_mat,
l_x1_mat,
l_y2_mat,
l_x2_mat,
)
return result.x

def solve(self, X, callback=None):
if self.vectorize:
return self._solve_vectorized(X)
else:
return self._solve_sequentially(X, callback)

def _refit(self, f, df_dx, df_dx2, latent):
# Restart with initial guess in order to be independent of previous solutions.
solution = self.w0
current_fit_attempts = 0
@@ -86,10 +132,10 @@ def _refit(self, p, dp_dx, dp_dx2, latent):
# for the same problem as starting point until convergence.
while not success and current_fit_attempts < self.max_fit_attempts:
current_fit_attempts += 1
result = _minimize(
p,
dp_dx,
dp_dx2,
result = self.min_func(
f,
df_dx,
df_dx2,
solution,
latent.y0,
latent.x0,


@partial(jit, static_argnums=(0, 1, 2))
def _minimize(p, dp_dx, dp_dx2, w, y0, x0, y1, x1, y2, x2):
def _minimize(f, df_dx, df_dx2, w, y0, x0, y1, x1, y2, x2):
return minimize(
_solve,
w,
method="BFGS",
args=(
p,
dp_dx,
dp_dx2,
f,
df_dx,
df_dx2,
jnp.array(y0),
jnp.array(x0),
jnp.array(y1),


@partial(jit, static_argnums=(1, 2, 3))
def _solve(w, p, dp_dx, dp_dx2, y0, x0, y1, x1, y2, x2):
px = p(w, x0)
dp_px = dp_dx(w, x1)
dp_px2 = dp_dx2(w, x2)
def _solve(w, f, df_dx, df_dx2, y0, x0, y1, x1, y2, x2):
px = f(w, x0)
dp_px = df_dx(w, x1)
dp_px2 = df_dx2(w, x2)
return _loss(y0, y1, y2, px, dp_px, dp_px2)


@jit
def _loss(y0, y1, y2, px, dp_px, dp_px2):
def _loss(y0, y1, y2, fx, df_fx, df_fx2):
return (
((px - y0) ** 2).sum() + ((dp_px - y1) ** 2).sum() + ((dp_px2 - y2) ** 2).sum()
((fx - y0) ** 2).sum() + ((df_fx - y1) ** 2).sum() + ((df_fx2 - y2) ** 2).sum()
)
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -25,6 +25,7 @@ plugins:
canonical_version: latest

markdown_extensions:
- admonition
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
26 changes: 17 additions & 9 deletions tests/data_generation/test_data_generator.py
@@ -7,18 +7,26 @@
class TestCurveGenerator(unittest.TestCase):

def setUp(self):
self.p = lambda w, x: w[0] * x ** 3 + w[1] * x ** 2 + w[2] * x + w[3]
x0 = np.array([0., 2., 4.])
y0 = np.array([0., 8., 64.])
x1 = np.array([1., 3.])
y1 = np.array([3., 27.])
x2 = np.array([2.])
y2 = np.array([12.])
self.p = lambda w, x: w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]
x0 = np.array([0.0, 2.0, 4.0])
y0 = np.array([0.0, 8.0, 64.0])
x1 = np.array([1.0, 3.0])
y1 = np.array([3.0, 27.0])
x2 = np.array([2.0])
y2 = np.array([12.0])
self.latent_information = LatentInformation(y0, x0, y1, x1, y2, x2)

def test_curve_generation(self):
def test_curve_generation_vectorized(self):
w0 = np.zeros(4)
curve_generator = CurveGenerator(self.p, w0)
curve_generator = CurveGenerator(self.p, w0, vectorize=True)
solution = curve_generator.run([self.latent_information])
expected = np.array([[1.0, 0.0, 0.0, 0.0]])
self.assertTupleEqual(solution.shape, (1, 4))
self.assertTrue(np.allclose(expected, solution))

def test_curve_generation_sequentially(self):
w0 = np.zeros(4)
curve_generator = CurveGenerator(self.p, w0, vectorize=False)
solution = curve_generator.run([self.latent_information])
expected = np.array([[1.0, 0.0, 0.0, 0.0]])
self.assertTupleEqual(solution.shape, (1, 4))
17 changes: 15 additions & 2 deletions tests/data_generation/test_solvers.py
@@ -18,9 +18,22 @@ def setUp(self):
y2 = np.array([12.0])
self.latent_information = LatentInformation(y0, x0, y1, x1, y2, x2)

def test_solve(self):
def test_solve_vectorized(self):
w0 = jnp.zeros(4)
solver = JaxCurveGenerationSolver(self.p, w0, max_fit_attemps=10)
solver = JaxCurveGenerationSolver(
self.p, w0, max_fit_attemps=10, vectorize=True
)
coefficients = solver.solve([self.latent_information])
expected = np.array([[1.0, 0.0, 0.0, 0.0]])
self.assertIs(type(coefficients), jaxlib.ArrayImpl)
self.assertTupleEqual(coefficients.shape, (1, 4))
self.assertTrue(np.allclose(expected, coefficients))

def test_solve_sequentially(self):
w0 = jnp.zeros(4)
solver = JaxCurveGenerationSolver(
self.p, w0, max_fit_attemps=10, vectorize=False
)
coefficients = solver.solve([self.latent_information])
expected = np.array([[1.0, 0.0, 0.0, 0.0]])
self.assertIs(type(coefficients), jaxlib.ArrayImpl)