4 changes: 1 addition & 3 deletions benchmark/pipelines/deep_ensembles.py
@@ -52,13 +52,11 @@ def setup_directories(workdir, seed, experiment, step_size):


def main(argv):
"""
Main function to train deep ensembles and save results.
"""Main function to train deep ensembles and save results.

Args:
argv (list): Command line arguments.
"""

logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
logging.info("JAX local devices: %r", jax.local_devices())

2 changes: 1 addition & 1 deletion benchmark/pipelines/kernelized_stein_discrepancies.py
@@ -147,7 +147,7 @@ def compute_discrepancies(

for j, file in tqdm(enumerate(files)):
dataset_idx = int(str(file).split("_")[-1].split(".")[0])
X_train, _, y_train, _, _, _= load_data_fn(dataset_idx=dataset_idx)
X_train, _, y_train, _, _, _ = load_data_fn(dataset_idx=dataset_idx)
X_train, y_train = jnp.array(X_train), jnp.array(y_train)

score_fn = jax.jit(jax.grad(logprob_fn, argnums=0))
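For context on what compute_discrepancies consumes: score_fn above is the gradient of the log-density, and the kernelized Stein discrepancy is built from it through a Stein kernel. The helper below is an illustrative sketch only (imq_kernel, stein_kernel and ksd_squared are hypothetical names, not functions from this repository), assuming score_fn(x) returns grad_x log p(x) for a single flat sample x.

import jax
import jax.numpy as jnp

def imq_kernel(x, y, c=1.0, beta=-0.5):
    # inverse multiquadric base kernel
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta

def stein_kernel(x, y, score_fn, kernel=imq_kernel):
    # u_p(x, y) = s(x)^T s(y) k(x, y) + s(x)^T d_y k + s(y)^T d_x k + tr(d_x d_y k)
    sx, sy = score_fn(x), score_fn(y)
    k = kernel(x, y)
    dk_dx = jax.grad(kernel, argnums=0)(x, y)
    dk_dy = jax.grad(kernel, argnums=1)(x, y)
    trace_term = jnp.trace(jax.jacfwd(jax.grad(kernel, argnums=0), argnums=1)(x, y))
    return (sx @ sy) * k + sx @ dk_dy + sy @ dk_dx + trace_term

def ksd_squared(samples, score_fn):
    # V-statistic estimate over samples of shape (n, d)
    u = lambda x, y: stein_kernel(x, y, score_fn)
    pairwise = jax.vmap(jax.vmap(u, in_axes=(None, 0)), in_axes=(0, None))
    return jnp.mean(pairwise(samples, samples))

The lengthscale implied by c would typically come from the median-heuristic pipeline below, but that wiring is an assumption, not something this diff shows.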
4 changes: 1 addition & 3 deletions benchmark/pipelines/laplace_pretrained.py
@@ -50,13 +50,11 @@ def train_model(model, train_loader, optimizer, loss_fn, num_epochs=1000):


def main(argv):
"""
Main function to train a model, fit a Laplace approximation, and save results.
"""Main function to train a model, fit a Laplace approximation, and save results.

Args:
argv (list): Command line arguments.
"""

workdir = FLAGS.workdir
step_size = FLAGS.step_size
num_datasets = FLAGS.num_datasets
40 changes: 19 additions & 21 deletions benchmark/pipelines/laplace_pretrained_heteroscedastic.py
@@ -3,17 +3,16 @@

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from absl import app, flags
from experiments import load_experiment
from hetreg.marglik import marglik_optimization
from laplace import KronLaplace

# from laplace import Laplace
# from laplace.curvature.backpack import BackPackGGN
from laplace.curvature.asdl import AsdlGGN
from laplace import KronLaplace
from rich.progress import track
from torch.utils.data import DataLoader, TensorDataset
from hetreg.marglik import marglik_optimization

FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", default=".", help="Directory where data will be stored")
@@ -53,13 +52,11 @@ def setup_directories(workdir, seed, experiment, step_size):


def main(argv):
"""
Main function to train a model, fit a Laplace approximation, and save results.
"""Main function to train a model, fit a Laplace approximation, and save results.

Args:
argv (list): Command line arguments.
"""

workdir = FLAGS.workdir
step_size = FLAGS.step_size
num_datasets = FLAGS.num_datasets
@@ -88,7 +85,7 @@ def main(argv):
)

t_initial = time()

lr = step_size
lr_min = 1e-5
lr_hyp = 1e-1
@@ -97,40 +94,41 @@
n_epochs = 10000
n_hypersteps = 50
marglik_frequency = 50
laplace = KronLaplace
optimizer = 'Adam'
laplace = KronLaplace
optimizer = "Adam"
backend = AsdlGGN
n_epochs_burnin = 100
prior_prec_init = 1e-3

la, model, margliksh, _, _ = marglik_optimization(
model,
train_loader,
likelihood='heteroscedastic_regression',
lr=lr, lr_min=lr_min,
model,
train_loader,
likelihood="heteroscedastic_regression",
lr=lr,
lr_min=lr_min,
lr_hyp=lr_hyp,
early_stopping=marglik_early_stopping,
lr_hyp_min=lr_hyp_min,
n_epochs=n_epochs,
n_hypersteps=n_hypersteps,
marglik_frequency=marglik_frequency,
laplace=laplace,
prior_structure='layerwise',
prior_structure="layerwise",
backend=backend,
n_epochs_burnin=n_epochs_burnin,
scheduler='cos',
scheduler="cos",
optimizer=optimizer,
prior_prec_init=prior_prec_init,
use_wandb=False
use_wandb=False,
)

t_final = time()

X_test = torch.tensor(np.array(X_test), dtype=torch.float32)
f_mu, f_var, y_var = la(X_test)

positions = la.sample(n_samples=2000).detach().cpu().numpy()

predictions = la.predictive_samples(X_test, n_samples=2000)

np.savez(
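As a reading aid for the block above: f_mu, f_var and y_var are treated here as the predictive mean, the function (epistemic) variance and the observation-noise variance per test point, which is an assumption about the hetreg/laplace return values rather than something this diff states. A minimal sketch of turning them into central Gaussian prediction intervals, assuming the inputs are NumPy arrays or detached CPU tensors:

import numpy as np
import scipy.stats as stats

def gaussian_interval(f_mu, f_var, y_var, alpha=0.1):
    # combine epistemic and aleatoric variance, then take a central (1 - alpha) interval
    mu = np.asarray(f_mu).squeeze()
    std = np.sqrt(np.asarray(f_var).squeeze() + np.asarray(y_var).squeeze())
    z = stats.norm.ppf(1.0 - alpha / 2.0)
    return mu - z * std, mu + z * std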
1 change: 0 additions & 1 deletion benchmark/pipelines/lengthscale_estimation.py
@@ -31,7 +31,6 @@ def estimate_lengthscale(experiment: str, filenames: list, n: int = 1000):
"""Estimates the median heuristic with a subsample of the whole available data.
It returns the median heuristics that should be used for computing KSD and SKSD.
"""

bins = [0]
repeated_filenames = []
idx_map = []
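For readers unfamiliar with the heuristic named in the docstring: the median heuristic picks the kernel lengthscale as the median of pairwise distances over a subsample. The sketch below is illustrative only; median_heuristic is a hypothetical helper, and the exact convention (squared vs. unsquared distances, scaling) used by estimate_lengthscale is not shown in this hunk.

import numpy as np

def median_heuristic(X, n=1000, seed=0):
    # subsample at most n points, then take the median pairwise Euclidean distance
    rng = np.random.default_rng(seed)
    sub = X[rng.choice(len(X), size=min(n, len(X)), replace=False)]
    d2 = np.sum((sub[:, None, :] - sub[None, :, :]) ** 2, axis=-1)
    return float(np.sqrt(np.median(d2[np.triu_indices_from(d2, k=1)])))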
6 changes: 3 additions & 3 deletions benchmark/pipelines/maximum_mean_discrepancies.py
@@ -127,15 +127,15 @@ def MaximumMeanDiscrepancy(X: jnp.ndarray, Y: jnp.ndarray):
params_hmc += [data_hmc["positions"][chain][100:]]
preds_hmc = np.concatenate(preds_hmc, axis=0).squeeze()
params_hmc = np.concatenate(params_hmc, axis=0).squeeze()

if preds_hmc.ndim == 3:
preds_hmc = preds_hmc[:, :, 0]

for i, (ikey, ival) in tqdm(enumerate(samples.items())):
data_i = jnp.load(ival)
x_params = jnp.array(data_i["positions"], dtype=jnp.float64)
x_preds = jnp.array(data_i["predictions"].squeeze(), dtype=jnp.float64)

if x_preds.ndim == 3:
x_preds = x_preds[:, :, 0]

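To make the quantity in MaximumMeanDiscrepancy concrete, here is a self-contained sketch of an unbiased squared-MMD estimate with an RBF kernel; the kernel choice and lengthscale handling in the actual pipeline may differ, and the function names below are illustrative.

import jax.numpy as jnp

def rbf_gram(X, Y, lengthscale=1.0):
    d2 = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-d2 / (2.0 * lengthscale**2))

def mmd_squared(X, Y, lengthscale=1.0):
    # unbiased estimate: diagonal terms of K_xx and K_yy are dropped
    Kxx = rbf_gram(X, X, lengthscale)
    Kyy = rbf_gram(Y, Y, lengthscale)
    Kxy = rbf_gram(X, Y, lengthscale)
    n, m = X.shape[0], Y.shape[0]
    return (
        (Kxx.sum() - jnp.trace(Kxx)) / (n * (n - 1))
        + (Kyy.sum() - jnp.trace(Kyy)) / (m * (m - 1))
        - 2.0 * Kxy.mean()
    )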
29 changes: 15 additions & 14 deletions benchmark/pipelines/metrics_prediction_intervals.py
@@ -3,7 +3,6 @@

import jax.numpy as jnp
import numpy as np
import numpy as np
import scipy.stats as stats
from experiments import Experiment, load_experiment
from tqdm import tqdm
@@ -31,12 +30,14 @@ def compute_interval_metrics(alphas, data, algorithm, y_test):

inside_dim = (y_test - qhigh[:, None] < 0) * (y_test - qlow[:, None] > 0)
inside.append(inside_dim)

widths_per_alpha.append(np.mean(qhigh - qlow))
else:
# get predictions
f_predictions = data["predictions"]
y_predictions = f_predictions + noise_level * np.random.randn(*f_predictions.shape)
y_predictions = f_predictions + noise_level * np.random.randn(
*f_predictions.shape
)

# compute coverage probabilities
inside = []
@@ -71,12 +72,14 @@ def compute_interval_metrics_heteroscedastic(alphas, data, algorithm, y_test):

inside_dim = (y_test - qhigh[:, None] < 0) * (y_test - qlow[:, None] > 0)
inside.append(inside_dim)

widths_per_alpha.append(np.mean(qhigh - qlow))
else:
# get predictions
f_predictions = data["predictions"]
y_predictions = f_predictions[:, :, 0] + noise_level * f_predictions[:, :, 1] * np.random.randn(*f_predictions[:, :, 0].shape)
y_predictions = f_predictions[:, :, 0] + noise_level * f_predictions[
:, :, 1
] * np.random.randn(*f_predictions[:, :, 0].shape)

# compute coverage probabilities
inside = []
@@ -95,8 +98,7 @@ def compute_interval_metrics_heteroscedastic(alphas, data, algorithm, y_test):


def compute_regression_metrics(y_true, y_pred):
"""
Args:
"""Args:
y_true (np.ndarray)
y_pred (np.ndarray)

@@ -114,8 +116,7 @@ def compute_regression_metrics(y_true, y_pred):


def compute_coverage_probabilities(alphas: np.ndarray, inside_list: np.ndarray):
"""
Args:
"""Args:
alphas (np.ndarray): Array of confidence levels
inside_list (np.ndarray): Array of bools such that True means that the points falls into the interval. Shape: [#_datasets, #_alphas, #_xtest, #_ytest_per_xtest]

@@ -136,8 +137,7 @@ def compute_coverage_probabilities(alphas: np.ndarray, inside_list: np.ndarray):


def post_process_data(experiment: Experiment, files: list, algorithm: str):
"""
Post-processes benchmark data to compute coverage probabilities, RMSE, Q2 scores, and other metrics.
"""Post-processes benchmark data to compute coverage probabilities, RMSE, Q2 scores, and other metrics.

Parameters:
- experiment (Experiment): An instance of the Experiment class with a method to load data.
@@ -152,7 +152,6 @@ def post_process_data(experiment: Experiment, files: list, algorithm: str):
- widths (np.ndarray): Average width of confidence intervals for each dataset.
- times (np.ndarray): Computational time for each dataset.
"""

inside_list = []
widths_list = []
q2_scores = []
@@ -177,8 +176,10 @@

# compute is_inside arrays
alphas = np.linspace(0.05, 0.95, 19)
inside_per_alpha, widths_per_alpha = compute_interval_metrics_heteroscedastic(
alphas, data, algorithm, y_test
inside_per_alpha, widths_per_alpha = (
compute_interval_metrics_heteroscedastic(
alphas, data, algorithm, y_test
)
)
else:
# mean prediction
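A short sketch of how the boolean arrays built above turn into a coverage curve; the axis layout follows the compute_coverage_probabilities docstring ([#datasets, #alphas, #xtest, #ytest_per_xtest]), and empirical_coverage is an illustrative helper, not a function from this module.

import numpy as np

def empirical_coverage(inside_list):
    # fraction of test points falling inside the interval, per dataset and per alpha,
    # then averaged over datasets
    inside = np.asarray(inside_list, dtype=float)
    per_dataset = inside.mean(axis=(2, 3))
    return per_dataset.mean(axis=0), per_dataset.std(axis=0)

For a well-calibrated model the first return value should track the confidence levels in alphas (np.linspace(0.05, 0.95, 19) above).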
23 changes: 13 additions & 10 deletions benchmark/pipelines/sgmcmc.py
@@ -5,13 +5,14 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import pbnn.mcmc as mcmc
from absl import app, flags, logging
from experiments import load_experiment
from logprobs import logprior_fn
from pbnn.utils.misc import thinning_fn
from rich.progress import track

import pbnn.mcmc as mcmc
from pbnn.utils.misc import thinning_fn

FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", default=".", help="Directory where data will be stored")
flags.DEFINE_string("algorithm", default="sgld", help="name of the sgmcmc algorithm")
@@ -31,7 +32,9 @@
flags.DEFINE_integer(
"seed", default=0, help="Initial seed that will be split accross the functions"
)
flags.DEFINE_string("init_method", default="map", help="Initialization method for MCMC.")
flags.DEFINE_string(
"init_method", default="map", help="Initialization method for MCMC."
)


def setup_directories(workdir, seed, experiment, algorithm, step_size):
@@ -47,7 +50,11 @@ def setup_directories(workdir, seed, experiment, algorithm, step_size):

WORKDIR = Path(workdir)
ROOT = (
WORKDIR / f"seed_{seed}" / experiment.name / algorithm.__name__ / f"lr_{step_size}"
WORKDIR
/ f"seed_{seed}"
/ experiment.name
/ algorithm.__name__
/ f"lr_{step_size}"
)
INIT_PARAMS_DIR = WORKDIR / f"seed_{seed}" / experiment.name / "init_params"

@@ -56,13 +63,11 @@ def setup_directories(workdir, seed, experiment, algorithm, step_size):


def main(argv):
"""
Main function for running SGMCMC algorithms and save results.
"""Main function for running SGMCMC algorithms and save results.

Args:
argv (list): Command line arguments.
"""

logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
logging.info("JAX local devices: %r", jax.local_devices())

@@ -96,13 +101,11 @@ def sgmcmc_fn(algorithm, X, y, init_positions, map_positions, rng_key):
rng_key
PRNGKey

Returns
Returns:
-------

Markov Chain of positions with burnin already removed

"""

match algorithm.__name__:
case "sgld":
hparams = {"num_iterations": 100_00}
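The docstring above promises a chain "with burnin already removed"; the pipeline delegates this to pbnn.utils.misc.thinning_fn, whose signature is not shown here. A generic stand-in for what that step does, kept deliberately minimal:

import jax.numpy as jnp

def drop_burnin_and_thin(positions, num_burnin, thin_every=1):
    # positions: array of shape (num_iterations, dim); discard the first
    # num_burnin draws, then keep every thin_every-th remaining draw
    positions = jnp.asarray(positions)
    return positions[num_burnin::thin_every]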
17 changes: 11 additions & 6 deletions benchmark/pipelines/sgmcmc_single_dataset.py
@@ -5,10 +5,11 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import pbnn.mcmc as mcmc
from absl import app, flags, logging
from experiments import load_experiment
from logprobs import logprior_fn

import pbnn.mcmc as mcmc
from pbnn.utils.misc import thinning_fn

FLAGS = flags.FLAGS
@@ -35,7 +36,9 @@
flags.DEFINE_integer(
"seed", default=0, help="Initial seed that will be split accross the functions"
)
flags.DEFINE_string("init_method", default="map", help="Initialization method for MCMC.")
flags.DEFINE_string(
"init_method", default="map", help="Initialization method for MCMC."
)


def setup_directories(workdir, seed, experiment, algorithm, step_size):
@@ -51,7 +54,11 @@ def setup_directories(workdir, seed, experiment, algorithm, step_size):

WORKDIR = Path(workdir)
ROOT = (
WORKDIR / f"seed_{seed}" / experiment.name / algorithm.__name__ / f"lr_{step_size}"
WORKDIR
/ f"seed_{seed}"
/ experiment.name
/ algorithm.__name__
/ f"lr_{step_size}"
)
INIT_PARAMS_DIR = WORKDIR / f"seed_{seed}" / experiment.name / "init_params"

@@ -94,13 +101,11 @@ def sgmcmc_fn(algorithm, X, y, init_positions, map_positions, rng_key):
rng_key
PRNGKey

Returns
Returns:
-------

Markov Chain of positions with burnin already removed

"""

match algorithm.__name__:
case "sgld":
hparams = {"num_iterations": 100_000}