Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.7
0.0.8
58 changes: 40 additions & 18 deletions driftbench/data_generation/sample.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
import numpy as np
import jax # noqa
import jax # noqa
from driftbench.data_generation.latent_information import LatentInformation
from driftbench.data_generation.data_generator import CurveGenerator


def sample_curves(dataset_specification, f=None, w0=None, random_state=2024, measurement_scale=None, callback=None):
def sample_curves(
dataset_specification,
f=None,
w0=None,
random_state=2024,
measurement_scale=None,
callback=None,
):
dimensions = dataset_specification["dimensions"]
drifts = dataset_specification.get("drifts")
x_scale = dataset_specification.get("x_scale", 0.02)
y_scale = dataset_specification.get("y_scale", 0.1)
func = _get_func(dataset_specification, f)
w_init = _get_w_init(dataset_specification, w0)
rng = np.random.RandomState(random_state)
latent_information = _generate_latent_information(dataset_specification, rng, x_scale, y_scale)
latent_information = _generate_latent_information(
dataset_specification, rng, x_scale, y_scale
)
if drifts is not None:
latent_information = drifts.apply(latent_information)
data_generator = CurveGenerator(func, w_init)
w = data_generator.run(latent_information, callback=callback)
x_min = int(np.min(dataset_specification["latent_information"].x0))
x_max = int(np.max(dataset_specification["latent_information"].x0))
x_range = np.concatenate(
(
dataset_specification["latent_information"].x0,
dataset_specification["latent_information"].x1,
dataset_specification["latent_information"].x2,
)
)
x_min, x_max = int(np.min(x_range)), int(np.max(x_range))
xs = np.linspace(x_min, x_max, dimensions)
curves = np.array([func(w_i, xs) for w_i in w])
# Apply a default noise of 5% of the mean of the sampled curves
Expand All @@ -34,17 +49,20 @@ def _generate_latent_information(dataset_specification, rng, x_scale, y_scale):
N = dataset_specification["N"]
base_latent_information = dataset_specification["latent_information"]
latent_information = []
xis = [f"x{i}" for i in range(3)]
yis = [f"y{i}" for i in range(3)]
for i in range(N):
# Apply some random noise on the base values
x0 = base_latent_information.x0 + rng.normal(size=len(base_latent_information.x0), scale=x_scale)
y0 = base_latent_information.y0 + rng.normal(size=len(base_latent_information.y0), scale=y_scale)

x1 = base_latent_information.x1 + rng.normal(size=len(base_latent_information.x1), scale=x_scale)
y1 = base_latent_information.y1 + rng.normal(size=len(base_latent_information.y1), scale=y_scale)

x2 = base_latent_information.x2 + rng.normal(size=len(base_latent_information.x2), scale=x_scale)
y2 = base_latent_information.y2 + rng.normal(size=len(base_latent_information.y2), scale=y_scale)
latent_information.append(LatentInformation(y0, x0, y1, x1, y2, x2))
latent_dict = {}
for xi in xis:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this may could also be a single for loop :-)

for v in xis+yis:

with x_scale and y_scale somehow accessed via v[0] and

scale = {'x': x_scale, 'y': y_scale}

But fine for me for now!

latent_dict[xi] = getattr(base_latent_information, xi) + rng.normal(
size=len(getattr(base_latent_information, xi)), scale=x_scale
)
for yi in yis:
latent_dict[yi] = getattr(base_latent_information, yi) + rng.normal(
size=len(getattr(base_latent_information, yi)), scale=y_scale
)
latent_information.append(LatentInformation(**latent_dict))
return latent_information


Expand All @@ -55,9 +73,11 @@ def _get_func(dataset_specification, f):
func_expr = dataset_specification["func"]
return eval(f"lambda w, x: {func_expr}")
else:
raise ValueError("""No function provided. Either specify function in yaml
raise ValueError(
"""No function provided. Either specify function in yaml
file, or provide a function as argument to the sample
function.""")
function."""
)


def _get_w_init(dataset_specification, w0):
Expand All @@ -71,5 +91,7 @@ def _get_w_init(dataset_specification, w0):
w_init = w_init_expr
return np.array(w_init, dtype=np.float64)
else:
raise ValueError("""No initial guess provided. Either specify initial guess in
yaml file, or provide an inital guess as argument to the sample function.""")
raise ValueError(
"""No initial guess provided. Either specify initial guess in
yaml file, or provide an inital guess as argument to the sample function."""
)