diff --git a/VERSION b/VERSION index 8cbf02c..43b2961 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.12 +0.0.13 diff --git a/docs/api.md b/docs/api.md index 893ad3c..82aae1c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -3,6 +3,7 @@ ::: driftbench.data_generation.sample ::: driftbench.data_generation.drifts ::: driftbench.data_generation.solvers +::: driftbench.data_generation.latent_information # Drift detection diff --git a/docs/how_it_works.md b/docs/how_it_works.md index a7c85de..8760fed 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -2,3 +2,34 @@ # Technical implementation in `jax` +This package uses [JAX](https://github.com/jax-ml/jax) as its backend for generating synthetic curves. +In particular, `jax` is used for: + +- Solving non-linear optimization problems. +- Automatic differentiation for calculating partial derivatives of arbitrary functions. +- Just-in-time (JIT)-compilation with XLA for performance optimization. + +For more detailed information regarding the XLA-compilation, please see the +[official JAX documentation](https://docs.jax.dev/en/latest/index.html) +or [XLA-documentation](https://openxla.org/xla/tf2xla?hl=en). + +These three points ensure an efficient generation of curves, while being +able to control the latent information used and the behaviour of drifts applied +on the curves. + +The method used to solve optimization problems is the +[LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) algorithm, which is supported by `jax`. +Computing the error term, running an iteration of the +optimization procedure, and computing the gradients are all done in functions which are +compiled just-in-time and can be run on a GPU. +The procedure can be described as follows: + +- Choose a function $f(w(t), x)$, which describes the shape of the curves to generate with +initial internal parameters $w_0(t)$. 
+- Provide problem constraints encoded in +[`LatentInformation`][driftbench.data_generation.latent_information.LatentInformation] +objects. +- Compute the partial derivatives of $f(w(t), x)$ with respect to $x_i$. +- For each latent information object, compute $w(t)$ according to the LBFGS-algorithm. +- Return all computed solutions $w(t)$ for each curve. + diff --git a/driftbench/data_generation/latent_information.py b/driftbench/data_generation/latent_information.py index 12c8080..b6abd6b 100644 --- a/driftbench/data_generation/latent_information.py +++ b/driftbench/data_generation/latent_information.py @@ -19,8 +19,8 @@ class LatentInformation: y2 (list-like): The y-values of the derivative of a function. x2 (list-like): The x-values of the second derivative of a function. Hence, no duplicates are allowed. - """ + y0: np.ndarray x0: np.ndarray y1: np.ndarray @@ -35,11 +35,17 @@ def __post_init__(self): def _validate_matching_shapes(self): if self.y0.shape != self.x0.shape: - raise ValueError("Features y0 and x0 are not allowed to have different shape") + raise ValueError( + "Features y0 and x0 are not allowed to have different shape" + ) if self.y1.shape != self.x1.shape: - raise ValueError("Features y1 and x1 are not allowed to have different shape") + raise ValueError( + "Features y1 and x1 are not allowed to have different shape" + ) if self.y2.shape != self.x2.shape: - raise ValueError("Features y2 and x2 are not allowed to have different shape") + raise ValueError( + "Features y2 and x2 are not allowed to have different shape" + ) def _validate_1d_array(self):