diff --git a/VERSION b/VERSION index 5a5831a..d169b2f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.7 +0.0.8 diff --git a/driftbench/data_generation/sample.py b/driftbench/data_generation/sample.py index 2d93431..30f5179 100644 --- a/driftbench/data_generation/sample.py +++ b/driftbench/data_generation/sample.py @@ -1,10 +1,17 @@ import numpy as np -import jax # noqa +import jax # noqa from driftbench.data_generation.latent_information import LatentInformation from driftbench.data_generation.data_generator import CurveGenerator -def sample_curves(dataset_specification, f=None, w0=None, random_state=2024, measurement_scale=None, callback=None): +def sample_curves( + dataset_specification, + f=None, + w0=None, + random_state=2024, + measurement_scale=None, + callback=None, +): dimensions = dataset_specification["dimensions"] drifts = dataset_specification.get("drifts") x_scale = dataset_specification.get("x_scale", 0.02) @@ -12,13 +19,21 @@ def sample_curves(dataset_specification, f=None, w0=None, random_state=2024, mea func = _get_func(dataset_specification, f) w_init = _get_w_init(dataset_specification, w0) rng = np.random.RandomState(random_state) - latent_information = _generate_latent_information(dataset_specification, rng, x_scale, y_scale) + latent_information = _generate_latent_information( + dataset_specification, rng, x_scale, y_scale + ) if drifts is not None: latent_information = drifts.apply(latent_information) data_generator = CurveGenerator(func, w_init) w = data_generator.run(latent_information, callback=callback) - x_min = int(np.min(dataset_specification["latent_information"].x0)) - x_max = int(np.max(dataset_specification["latent_information"].x0)) + x_range = np.concatenate( + ( + dataset_specification["latent_information"].x0, + dataset_specification["latent_information"].x1, + dataset_specification["latent_information"].x2, + ) + ) + x_min, x_max = int(np.min(x_range)), int(np.max(x_range)) xs = np.linspace(x_min, x_max, dimensions) curves = np.array([func(w_i, xs) for w_i in w]) # Apply a default noise of 5% of the mean of the sampled curves @@ -34,17 +49,20 @@ def _generate_latent_information(dataset_specification, rng, x_scale, y_scale): N = dataset_specification["N"] base_latent_information = dataset_specification["latent_information"] latent_information = [] + xis = [f"x{i}" for i in range(3)] + yis = [f"y{i}" for i in range(3)] for i in range(N): # Apply some random noise on the base values - x0 = base_latent_information.x0 + rng.normal(size=len(base_latent_information.x0), scale=x_scale) - y0 = base_latent_information.y0 + rng.normal(size=len(base_latent_information.y0), scale=y_scale) - - x1 = base_latent_information.x1 + rng.normal(size=len(base_latent_information.x1), scale=x_scale) - y1 = base_latent_information.y1 + rng.normal(size=len(base_latent_information.y1), scale=y_scale) - - x2 = base_latent_information.x2 + rng.normal(size=len(base_latent_information.x2), scale=x_scale) - y2 = base_latent_information.y2 + rng.normal(size=len(base_latent_information.y2), scale=y_scale) - latent_information.append(LatentInformation(y0, x0, y1, x1, y2, x2)) + latent_dict = {} + for xi in xis: + latent_dict[xi] = getattr(base_latent_information, xi) + rng.normal( + size=len(getattr(base_latent_information, xi)), scale=x_scale + ) + for yi in yis: + latent_dict[yi] = getattr(base_latent_information, yi) + rng.normal( + size=len(getattr(base_latent_information, yi)), scale=y_scale + ) + latent_information.append(LatentInformation(**latent_dict)) return latent_information @@ -55,9 +73,11 @@ def _get_func(dataset_specification, f): func_expr = dataset_specification["func"] return eval(f"lambda w, x: {func_expr}") else: - raise ValueError("""No function provided. Either specify function in yaml + raise ValueError( + """No function provided. Either specify function in yaml file, or provide a function as argument to the sample - function.""") + function.""" + ) def _get_w_init(dataset_specification, w0): @@ -71,5 +91,7 @@ def _get_w_init(dataset_specification, w0): w_init = w_init_expr return np.array(w_init, dtype=np.float64) else: - raise ValueError("""No initial guess provided. Either specify initial guess in - yaml file, or provide an inital guess as argument to the sample function.""") + raise ValueError( + """No initial guess provided. Either specify initial guess in + yaml file, or provide an inital guess as argument to the sample function.""" + )