diff --git a/.gitignore b/.gitignore
index 82f9275..e90d309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,5 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+.DS_Store
diff --git a/VERSION b/VERSION
index 9789c4c..6e8bf73 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.14
+0.1.0
diff --git a/docs/data.md b/docs/data.md
index c880ed5..be412f9 100644
--- a/docs/data.md
+++ b/docs/data.md
@@ -67,3 +67,79 @@ coefficients, latent_information, curves = sample_curves(dataset["example"], mea
 By specifying a value for `measurement_scale` some gaussian noise with the specified scale is
 applied on each value for every curve. By default, $5\%$ of the mean of the curves is used.
 If you want to omit the scale, set it to `0.0` explicitly.
+
+## Vectorized Curve Generation
+
+The curve generation process can be significantly accelerated by using vectorized computation.
+By default, `driftbench` uses JAX's `vmap` (vectorized map) to parallelize the optimization
+across all latent information instances simultaneously, resulting in much faster curve generation
+compared to sequential computation.
+
+### Performance Comparison
+
+| Mode | Description | Use Case |
+|------|-------------|----------|
+| **Vectorized** (default) | Optimizes all curves in parallel using `vmap` and `jit` | Large datasets, production use |
+| **Sequential** | Optimizes curves one-by-one, using previous solution as starting point | Debugging, progress tracking, improved stability |
+
+The vectorized mode is typically **orders of magnitude faster** for large datasets because:
+
+1. JAX compiles the optimization function only once for all instances
+2. Operations are batched and executed in parallel on the hardware (CPU/GPU)
+3. Memory access patterns are optimized for vectorized operations
+
+However, sequential mode can provide **more stable results** because it uses the solution
+from the previous curve as the starting point for the next optimization. This warm-starting
+approach can lead to smoother transitions across the execution dimension, especially when
+curves are expected to have similar coefficients.
+
+### Using Vectorized Mode
+
+Vectorized computation is enabled by default when using `sample_curves`:
+
+```python
+from driftbench.data_generation.sample import sample_curves
+
+# Vectorized mode is used by default - fast computation
+coefficients, latent_information, curves = sample_curves(dataset["example"])
+
+# Explicitly enable vectorized mode
+coefficients, latent_information, curves = sample_curves(dataset["example"], vectorize=True)
+```
+
+### Using Sequential Mode
+
+To use sequential mode with better stability, set `vectorize=False` in `sample_curves`:
+
+```python
+from driftbench.data_generation.sample import sample_curves
+
+# Sequential mode - slower but more stable results
+coefficients, latent_information, curves = sample_curves(dataset["example"], vectorize=False)
+
+# With a callback to track progress
+def progress_callback(i, solution):
+    print(f"Curve {i}: coefficients = {solution}")
+
+coefficients, latent_information, curves = sample_curves(
+    dataset["example"],
+    vectorize=False,
+    callback=progress_callback
+)
+```
+
+!!! note
+    The `callback` parameter is only supported in sequential mode (`vectorize=False`).
+    When using vectorized mode, the callback is ignored since all curves are computed
+    simultaneously.
+
+### When to Use Each Mode
+
+- **Vectorized mode** (default): Use this for production workloads and when generating
+  large numbers of curves where speed is the priority.
+
+- **Sequential mode**: Use this when you need to:
+  - Achieve more stable optimization results with smooth coefficient transitions
+  - Debug the optimization process
+  - Monitor progress with a callback function
+  - Investigate individual curve fitting issues
diff --git a/driftbench/data_generation/data_generator.py b/driftbench/data_generation/data_generator.py
index aeec0d9..e1b6400 100644
--- a/driftbench/data_generation/data_generator.py
+++ b/driftbench/data_generation/data_generator.py
@@ -27,14 +27,17 @@ class CurveGenerator(DataGenerator):
     which meet the constraints provided by the latent information.
     """
 
-    def __init__(self, p, w0, max_fit_attempts=100):
+    def __init__(self, p, w0, max_fit_attempts=100, vectorize=True):
         """
         Args:
             p (func): The polynomial to fit.
             w0 (list-like): The initial guess.
             max_fit_attemps (int): The maxmium number of attempts to refit a curve, if optimization didn't succeed.
+            vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
         """
-        self.solver = JaxCurveGenerationSolver(p, w0, max_fit_attempts)
+        self.solver = JaxCurveGenerationSolver(
+            p, w0, max_fit_attempts, vectorize=vectorize
+        )
 
     def run(self, X, callback=None):
         return self.solver.solve(X, callback=callback)
diff --git a/driftbench/data_generation/sample.py b/driftbench/data_generation/sample.py
index 8a242b4..329c649 100644
--- a/driftbench/data_generation/sample.py
+++ b/driftbench/data_generation/sample.py
@@ -10,6 +10,7 @@ def sample_curves(
     w0=None,
     random_state=2024,
     measurement_scale=None,
+    vectorize=True,
     callback=None,
 ):
     """
@@ -17,14 +18,15 @@ def sample_curves(
 
     Args:
         dataset_specification (dict): A dataset specification which contains
-            all information to syntethisize curves in yaml-format.
+            all information to synthesize curves in yaml-format.
             Each dataset is encoded with a name and needs a latent information provided.
             The function `f` to fit and as well as initial guess `w0`can be provided as well.
         f (Callable): The function to fit the curves. Use this parameter if no function
             is specified in `dataset_specification`.
-        w0 (np.ndarray): The inital guess for the optimization problem used to synthesize curves.
+        w0 (np.ndarray): The initial guess for the optimization problem used to synthesize curves.
             Use this parameter if no initial guess is specified in `dataset_specification`.
-        random_state (int): The random state for reproducablity.
+        random_state (int): The random state for reproducibility.
+        vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
         measurement_scale (float): The scale for the noise applied on the evaluated curves.
             If not set, 5% percent of the mean of the curves is used. Set to 0.0
            if you want to omit this noise.
@@ -46,7 +48,7 @@ def sample_curves(
     )
     if drifts is not None:
         latent_information = drifts.apply(latent_information)
-    data_generator = CurveGenerator(func, w_init)
+    data_generator = CurveGenerator(func, w_init, vectorize=vectorize)
     w = data_generator.run(latent_information, callback=callback)
     x_range = np.concatenate(
         (
diff --git a/driftbench/data_generation/solvers.py b/driftbench/data_generation/solvers.py
index 1bd265f..45bdb44 100644
--- a/driftbench/data_generation/solvers.py
+++ b/driftbench/data_generation/solvers.py
@@ -36,13 +36,13 @@ class JaxCurveGenerationSolver(Solver):
     Fits latent information according to a given function.
     """
 
-    def __init__(self, f, w0, max_fit_attemps):
+    def __init__(self, f, w0, max_fit_attemps, vectorize=True):
         """
         Args:
             f (Callable): The function.
             w0 (list-like): The initial guess for the solution.
             max_fit_attemps (int): The maxmium number of attempts to refit a curve, if optimization didn't succeed.
-            random_seed (int): The random seed for the random number generator.
+            vectorize (bool): Whether to vectorize the optimization over multiple latent information instances.
         """
         self.f = jit(vmap(partial(f), in_axes=(None, 0)))
         df_dx = grad(f, argnums=1)
@@ -51,12 +51,33 @@ def __init__(self, f, w0, max_fit_attemps):
         self.df_dx2 = jit(vmap(partial(df_dx2), in_axes=(None, 0)))
         self.w0 = jnp.array(w0)
         self.max_fit_attempts = max_fit_attemps
+        self.vectorize = vectorize
+        self.min_func = _minimize
+
+        # Compile function beforehand if vectorization is enabled for one compilation only.
+        if self.vectorize:
+            self.min_func = jit(
+                vmap(
+                    partial(_minimize),
+                    in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0),
+                ),
+                static_argnums=(0, 1, 2),
+            )
 
-    def solve(self, X, callback=None):
+    def _latents_to_array(self, latents):
+        l_x0_mat = jnp.array([latent.x0 for latent in latents])
+        l_x1_mat = jnp.array([latent.x1 for latent in latents])
+        l_x2_mat = jnp.array([latent.x2 for latent in latents])
+        l_y0_mat = jnp.array([latent.y0 for latent in latents])
+        l_y1_mat = jnp.array([latent.y1 for latent in latents])
+        l_y2_mat = jnp.array([latent.y2 for latent in latents])
+        return l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat
+
+    def _solve_sequentially(self, X, callback=None):
         coefficients = []
         solution = self.w0
         for i, latent in enumerate(X):
-            result = _minimize(
+            result = self.min_func(
                 self.f,
                 self.df_dx,
                 self.df_dx2,
@@ -76,7 +97,32 @@ def solve(self, X, callback=None):
             coefficients.append(solution)
         return jnp.array(coefficients)
 
-    def _refit(self, p, dp_dx, dp_dx2, latent):
+    def _solve_vectorized(self, X):
+        solution = jnp.tile(self.w0, (len(X), 1))
+        l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat = (
+            self._latents_to_array(X)
+        )
+        result = self.min_func(
+            self.f,
+            self.df_dx,
+            self.df_dx2,
+            solution,
+            l_y0_mat,
+            l_x0_mat,
+            l_y1_mat,
+            l_x1_mat,
+            l_y2_mat,
+            l_x2_mat,
+        )
+        return result.x
+
+    def solve(self, X, callback=None):
+        if self.vectorize:
+            return self._solve_vectorized(X)
+        else:
+            return self._solve_sequentially(X, callback)
+
+    def _refit(self, f, df_dx, df_dx2, latent):
         # Restart with initial guess in order to be independent of previous solutions.
         solution = self.w0
         current_fit_attempts = 0
@@ -86,10 +132,10 @@
         # for the same problem as starting point until convergence.
         while not success and current_fit_attempts < self.max_fit_attempts:
             current_fit_attempts += 1
-            result = _minimize(
-                p,
-                dp_dx,
-                dp_dx2,
+            result = self.min_func(
+                f,
+                df_dx,
+                df_dx2,
                 solution,
                 latent.y0,
                 latent.x0,
@@ -104,15 +150,15 @@
 
 
 @partial(jit, static_argnums=(0, 1, 2))
-def _minimize(p, dp_dx, dp_dx2, w, y0, x0, y1, x1, y2, x2):
+def _minimize(f, df_dx, df_dx2, w, y0, x0, y1, x1, y2, x2):
     return minimize(
         _solve,
         w,
         method="BFGS",
         args=(
-            p,
-            dp_dx,
-            dp_dx2,
+            f,
+            df_dx,
+            df_dx2,
             jnp.array(y0),
             jnp.array(x0),
             jnp.array(y1),
@@ -124,15 +170,15 @@ def _minimize(p, dp_dx, dp_dx2, w, y0, x0, y1, x1, y2, x2):
 
 
 @partial(jit, static_argnums=(1, 2, 3))
-def _solve(w, p, dp_dx, dp_dx2, y0, x0, y1, x1, y2, x2):
-    px = p(w, x0)
-    dp_px = dp_dx(w, x1)
-    dp_px2 = dp_dx2(w, x2)
+def _solve(w, f, df_dx, df_dx2, y0, x0, y1, x1, y2, x2):
+    px = f(w, x0)
+    dp_px = df_dx(w, x1)
+    dp_px2 = df_dx2(w, x2)
     return _loss(y0, y1, y2, px, dp_px, dp_px2)
 
 
 @jit
-def _loss(y0, y1, y2, px, dp_px, dp_px2):
+def _loss(y0, y1, y2, fx, df_fx, df_fx2):
     return (
-        ((px - y0) ** 2).sum() + ((dp_px - y1) ** 2).sum() + ((dp_px2 - y2) ** 2).sum()
+        ((fx - y0) ** 2).sum() + ((df_fx - y1) ** 2).sum() + ((df_fx2 - y2) ** 2).sum()
     )
diff --git a/mkdocs.yml b/mkdocs.yml
index 350e19d..81a70b0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -25,6 +25,7 @@ plugins:
       canonical_version: latest
 
 markdown_extensions:
+  - admonition
   - pymdownx.highlight:
       anchor_linenums: true
       line_spans: __span
diff --git a/tests/data_generation/test_data_generator.py b/tests/data_generation/test_data_generator.py
index 7dc0d11..11d40eb 100644
--- a/tests/data_generation/test_data_generator.py
+++ b/tests/data_generation/test_data_generator.py
@@ -7,18 +7,26 @@
 class TestCurveGenerator(unittest.TestCase):
 
     def setUp(self):
-        self.p = lambda w, x: w[0] * x ** 3 + w[1] * x ** 2 + w[2] * x + w[3]
-        x0 = np.array([0., 2., 4.])
-        y0 = np.array([0., 8., 64.])
-        x1 = np.array([1., 3.])
-        y1 = np.array([3., 27.])
-        x2 = np.array([2.])
-        y2 = np.array([12.])
+        self.p = lambda w, x: w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]
+        x0 = np.array([0.0, 2.0, 4.0])
+        y0 = np.array([0.0, 8.0, 64.0])
+        x1 = np.array([1.0, 3.0])
+        y1 = np.array([3.0, 27.0])
+        x2 = np.array([2.0])
+        y2 = np.array([12.0])
         self.latent_information = LatentInformation(y0, x0, y1, x1, y2, x2)
 
-    def test_curve_generation(self):
+    def test_curve_generation_vectorized(self):
         w0 = np.zeros(4)
-        curve_generator = CurveGenerator(self.p, w0)
+        curve_generator = CurveGenerator(self.p, w0, vectorize=True)
+        solution = curve_generator.run([self.latent_information])
+        expected = np.array([[1.0, 0.0, 0.0, 0.0]])
+        self.assertTupleEqual(solution.shape, (1, 4))
+        self.assertTrue(np.allclose(expected, solution))
+
+    def test_curve_generation_sequentially(self):
+        w0 = np.zeros(4)
+        curve_generator = CurveGenerator(self.p, w0, vectorize=False)
         solution = curve_generator.run([self.latent_information])
         expected = np.array([[1.0, 0.0, 0.0, 0.0]])
         self.assertTupleEqual(solution.shape, (1, 4))
diff --git a/tests/data_generation/test_solvers.py b/tests/data_generation/test_solvers.py
index e14a599..53cf7cf 100644
--- a/tests/data_generation/test_solvers.py
+++ b/tests/data_generation/test_solvers.py
@@ -18,9 +18,22 @@ def setUp(self):
         y2 = np.array([12.0])
         self.latent_information = LatentInformation(y0, x0, y1, x1, y2, x2)
 
-    def test_solve(self):
+    def test_solve_vectorized(self):
         w0 = jnp.zeros(4)
-        solver = JaxCurveGenerationSolver(self.p, w0, max_fit_attemps=10)
+        solver = JaxCurveGenerationSolver(
+            self.p, w0, max_fit_attemps=10, vectorize=True
+        )
+        coefficients = solver.solve([self.latent_information])
+        expected = np.array([[1.0, 0.0, 0.0, 0.0]])
+        self.assertIs(type(coefficients), jaxlib.ArrayImpl)
+        self.assertTupleEqual(coefficients.shape, (1, 4))
+        self.assertTrue(np.allclose(expected, coefficients))
+
+    def test_solve_sequentially(self):
+        w0 = jnp.zeros(4)
+        solver = JaxCurveGenerationSolver(
+            self.p, w0, max_fit_attemps=10, vectorize=False
+        )
         coefficients = solver.solve([self.latent_information])
         expected = np.array([[1.0, 0.0, 0.0, 0.0]])
         self.assertIs(type(coefficients), jaxlib.ArrayImpl)
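The core of the change is the pre-compiled batched minimizer, `jit(vmap(partial(_minimize), in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0)), static_argnums=(0, 1, 2))`, where the three static arguments are the callables and the mapped axes each carry one latent-information instance. Below is a minimal, self-contained sketch of that pattern; the helper names (`fit_one`, `fit_batch`), the plain least-squares loss, and the toy data are illustrative assumptions, not part of the driftbench API.

```python
# Sketch of the jit(vmap(...)) batching idea behind vectorize=True:
# one BFGS fit per curve, mapped over the leading axis and compiled once.
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.optimize import minimize


def fit_one(w0, x, y):
    # Fit a cubic polynomial to a single curve; `loss` closes over (x, y).
    def loss(w):
        fx = w[0] * x**3 + w[1] * x**2 + w[2] * x + w[3]
        return ((fx - y) ** 2).sum()

    return minimize(loss, w0, method="BFGS").x


# Broadcast the shared initial guess (in_axes=None), map over curves (in_axes=0),
# and compile the whole batched solve a single time.
fit_batch = jit(vmap(fit_one, in_axes=(None, 0, 0)))

x = jnp.tile(jnp.linspace(0.0, 4.0, 5), (8, 1))  # 8 curves on the same grid
y = jnp.linspace(0.5, 2.0, 8)[:, None] * x**3    # 8 cubic target curves
w = fit_batch(jnp.zeros(4), x, y)                # coefficients, shape (8, 4)
```

The sequential path (`vectorize=False`) keeps the loop in Python so each fit can warm-start from the previous solution and report progress through the callback; the batched path trades both for a single compilation and parallel execution.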