Merged
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.12
0.0.13
1 change: 1 addition & 0 deletions docs/api.md
Expand Up @@ -3,6 +3,7 @@
::: driftbench.data_generation.sample
::: driftbench.data_generation.drifts
::: driftbench.data_generation.solvers
::: driftbench.data_generation.latent_information

# Drift detection

Expand Down
31 changes: 31 additions & 0 deletions docs/how_it_works.md
Expand Up @@ -2,3 +2,34 @@


# Technical implementation in `jax`
This package uses [JAX](https://github.com/jax-ml/jax) as its backend for generating synthetic curves.
In particular, `jax` is used for:

- Solving non-linear optimization problems.
- Automatic differentiation for calculating partial derivatives of arbitrary functions.
- Just-in-time (JIT)-compilation with XLA for performance optimization.
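A minimal sketch of these three building blocks in `jax` (the curve function and its parameters below are illustrative only, not driftbench's API):

```python
import jax
import jax.numpy as jnp

# Hypothetical curve shape: w[0] * sin(w[1] * x).
def f(w, x):
    return w[0] * jnp.sin(w[1] * x)

# Automatic differentiation: partial derivative of f with respect to x.
df_dx = jax.grad(f, argnums=1)

# JIT compilation with XLA: traced once, then runs as compiled code.
df_dx_fast = jax.jit(df_dx)

w = jnp.array([2.0, 3.0])
print(df_dx_fast(w, 0.0))  # d/dx [2 sin(3x)] at x=0 is 2*3*cos(0) = 6.0
```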

For more detailed information regarding XLA compilation, please see the
[official JAX documentation](https://docs.jax.dev/en/latest/index.html)
or the [XLA documentation](https://openxla.org/xla/tf2xla?hl=en).

Together, these three points make the generation of curves efficient while
retaining control over the latent information used and the behaviour of the
drifts applied to the curves.

The optimization problems are solved with the
[LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) algorithm, which is supported by `jax`.
The functions that compute the error term, run an iteration of the
optimization solver, and compute the gradients are all compiled
just-in-time and can be run on a GPU.
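The error-and-gradient pattern can be sketched as follows (the error term here is a hypothetical least-squares residual, not driftbench's actual one):

```python
import jax
import jax.numpy as jnp

# Hypothetical error term for fitting the curve w[0]*sin(w[1]*x) to targets.
def error(w, x, y_target):
    return jnp.sum((w[0] * jnp.sin(w[1] * x) - y_target) ** 2)

# Gradient with respect to the parameters, JIT-compiled so that repeated
# solver iterations run as fused XLA code (on a GPU when one is available).
error_grad = jax.jit(jax.grad(error))

x = jnp.linspace(0.0, 1.0, 20)
y_target = 2.0 * jnp.sin(3.0 * x)

# At the true parameters the residual is zero, so the gradient vanishes.
g = error_grad(jnp.array([2.0, 3.0]), x, y_target)
```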
The procedure can be described as follows:

- Choose a function $f(w(t), x)$, which describes the shape of the curves to generate, with
  initial internal parameters $w_0(t)$.
- Provide problem constraints encoded in
[`LatentInformation`][driftbench.data_generation.latent_information.LatentInformation]
objects.
- Compute the partial derivatives of $f(w(t), x)$ with respect to $x_i$.
- For each latent information object, compute $w(t)$ according to the LBFGS-algorithm.
- Return all computed solutions $w(t)$ for each curve.
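The steps above can be sketched end to end. Everything here is a hedged stand-in: the curve shape, the dict-based "latent information", and the use of BFGS from `jax.scipy.optimize` (driftbench's solver uses LBFGS) are illustrative assumptions, not the package's API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Step 1: choose a curve shape f(w, x); a quadratic is used for illustration.
def f(w, x):
    return w[0] * x**2 + w[1] * x + w[2]

# Step 3: partial derivative of f with respect to x via autodiff.
df_dx = jax.grad(f, argnums=1)

# Step 2: toy stand-in for latent information -- target values of f and f'.
constraints = [
    {"x0": 1.0, "y0": 2.0, "x1": 0.0, "y1": 1.0},
    {"x0": 2.0, "y0": 5.0, "x1": 1.0, "y1": 3.0},
]

def error(w, c):
    # Penalize deviation from both function values and derivative values.
    return (f(w, c["x0"]) - c["y0"]) ** 2 + (df_dx(w, c["x1"]) - c["y1"]) ** 2

# Steps 4-5: solve one optimization problem per constraint set and
# return all solutions (BFGS stands in for LBFGS here).
w0 = jnp.zeros(3)
solutions = [minimize(lambda w: error(w, c), w0, method="BFGS").x for c in constraints]
```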

14 changes: 10 additions & 4 deletions driftbench/data_generation/latent_information.py
Expand Up @@ -19,8 +19,8 @@ class LatentInformation:
y2 (list-like): The y-values of the second derivative of a function.
x2 (list-like): The x-values of the second derivative of a function.
Hence, no duplicates are allowed.

"""

y0: np.ndarray
x0: np.ndarray
y1: np.ndarray
Expand All @@ -35,11 +35,17 @@ def __post_init__(self):

    def _validate_matching_shapes(self):
        if self.y0.shape != self.x0.shape:
            raise ValueError(
                "Features y0 and x0 are not allowed to have different shape"
            )
        if self.y1.shape != self.x1.shape:
            raise ValueError(
                "Features y1 and x1 are not allowed to have different shape"
            )
        if self.y2.shape != self.x2.shape:
            raise ValueError(
                "Features y2 and x2 are not allowed to have different shape"
            )

def _validate_1d_array(self):

Expand Down
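Condensed, the validation changed in this file behaves as sketched below. This is a minimal mirror of the diff excerpt: the field list and error messages follow the excerpt, while the omitted pieces (e.g. `_validate_1d_array`) are left out:

```python
from dataclasses import dataclass

import numpy as np

# Hedged sketch of driftbench's LatentInformation validation pattern.
@dataclass
class LatentInformation:
    y0: np.ndarray
    x0: np.ndarray
    y1: np.ndarray
    x1: np.ndarray
    y2: np.ndarray
    x2: np.ndarray

    def __post_init__(self):
        self._validate_matching_shapes()

    def _validate_matching_shapes(self):
        # Each (y, x) pair must have identical shapes.
        pairs = [("y0", "x0"), ("y1", "x1"), ("y2", "x2")]
        for y_name, x_name in pairs:
            if getattr(self, y_name).shape != getattr(self, x_name).shape:
                raise ValueError(
                    f"Features {y_name} and {x_name} are not allowed "
                    "to have different shape"
                )
```

Constructing an instance with mismatched pairs, e.g. `y0` of shape `(3,)` and `x0` of shape `(2,)`, raises a `ValueError` immediately in `__post_init__`.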