diff --git a/benchmark/pipelines/deep_ensembles.py b/benchmark/pipelines/deep_ensembles.py
index d2a37ed..4efa09a 100644
--- a/benchmark/pipelines/deep_ensembles.py
+++ b/benchmark/pipelines/deep_ensembles.py
@@ -52,13 +52,11 @@ def setup_directories(workdir, seed, experiment, step_size):
 
 
 def main(argv):
-    """
-    Main function to train deep ensembles and save results.
+    """Main function to train deep ensembles and save results.
 
     Args:
         argv (list): Command line arguments.
     """
-
     logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
     logging.info("JAX local devices: %r", jax.local_devices())
 
diff --git a/benchmark/pipelines/kernelized_stein_discrepancies.py b/benchmark/pipelines/kernelized_stein_discrepancies.py
index d5ed42d..fb55a66 100644
--- a/benchmark/pipelines/kernelized_stein_discrepancies.py
+++ b/benchmark/pipelines/kernelized_stein_discrepancies.py
@@ -147,7 +147,7 @@ def compute_discrepancies(
     for j, file in tqdm(enumerate(files)):
         dataset_idx = int(str(file).split("_")[-1].split(".")[0])
 
-        X_train, _, y_train, _, _, _= load_data_fn(dataset_idx=dataset_idx)
+        X_train, _, y_train, _, _, _ = load_data_fn(dataset_idx=dataset_idx)
         X_train, y_train = jnp.array(X_train), jnp.array(y_train)
 
         score_fn = jax.jit(jax.grad(logprob_fn, argnums=0))
diff --git a/benchmark/pipelines/laplace_pretrained.py b/benchmark/pipelines/laplace_pretrained.py
index 3548eff..4236fb1 100644
--- a/benchmark/pipelines/laplace_pretrained.py
+++ b/benchmark/pipelines/laplace_pretrained.py
@@ -50,13 +50,11 @@ def train_model(model, train_loader, optimizer, loss_fn, num_epochs=1000):
 
 
 def main(argv):
-    """
-    Main function to train a model, fit a Laplace approximation, and save results.
+    """Main function to train a model, fit a Laplace approximation, and save results.
 
     Args:
         argv (list): Command line arguments.
     """
-
     workdir = FLAGS.workdir
     step_size = FLAGS.step_size
     num_datasets = FLAGS.num_datasets
diff --git a/benchmark/pipelines/laplace_pretrained_heteroscedastic.py b/benchmark/pipelines/laplace_pretrained_heteroscedastic.py
index e8e6ce5..9d552ae 100644
--- a/benchmark/pipelines/laplace_pretrained_heteroscedastic.py
+++ b/benchmark/pipelines/laplace_pretrained_heteroscedastic.py
@@ -3,17 +3,16 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.optim as optim
 from absl import app, flags
 from experiments import load_experiment
+from hetreg.marglik import marglik_optimization
+from laplace import KronLaplace
+
 # from laplace import Laplace
 # from laplace.curvature.backpack import BackPackGGN
 from laplace.curvature.asdl import AsdlGGN
-from laplace import KronLaplace
 from rich.progress import track
 from torch.utils.data import DataLoader, TensorDataset
-from hetreg.marglik import marglik_optimization
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("workdir", default=".", help="Directory where data will be stored")
@@ -53,13 +52,11 @@ def setup_directories(workdir, seed, experiment, step_size):
 
 
 def main(argv):
-    """
-    Main function to train a model, fit a Laplace approximation, and save results.
+    """Main function to train a model, fit a Laplace approximation, and save results.
 
     Args:
         argv (list): Command line arguments.
     """
-
     workdir = FLAGS.workdir
     step_size = FLAGS.step_size
     num_datasets = FLAGS.num_datasets
@@ -88,7 +85,7 @@ def main(argv):
     )
 
     t_initial = time()
-    
+
     lr = step_size
     lr_min = 1e-5
     lr_hyp = 1e-1
@@ -97,17 +94,18 @@ def main(argv):
     n_epochs = 10000
     n_hypersteps = 50
     marglik_frequency = 50
-    laplace = KronLaplace 
-    optimizer = 'Adam'
+    laplace = KronLaplace
+    optimizer = "Adam"
     backend = AsdlGGN
     n_epochs_burnin = 100
     prior_prec_init = 1e-3
-    
+
     la, model, margliksh, _, _ = marglik_optimization(
-        model, 
-        train_loader, 
-        likelihood='heteroscedastic_regression',
-        lr=lr, lr_min=lr_min,
+        model,
+        train_loader,
+        likelihood="heteroscedastic_regression",
+        lr=lr,
+        lr_min=lr_min,
         lr_hyp=lr_hyp,
         early_stopping=marglik_early_stopping,
         lr_hyp_min=lr_hyp_min,
@@ -115,22 +113,22 @@ def main(argv):
         n_hypersteps=n_hypersteps,
         marglik_frequency=marglik_frequency,
         laplace=laplace,
-        prior_structure='layerwise',
+        prior_structure="layerwise",
        backend=backend,
         n_epochs_burnin=n_epochs_burnin,
-        scheduler='cos',
+        scheduler="cos",
         optimizer=optimizer,
         prior_prec_init=prior_prec_init,
-        use_wandb=False
+        use_wandb=False,
     )
-    
+
     t_final = time()
 
     X_test = torch.tensor(np.array(X_test), dtype=torch.float32)
     f_mu, f_var, y_var = la(X_test)
-    
+
     positions = la.sample(n_samples=2000).detach().cpu().numpy()
-    
+
     predictions = la.predictive_samples(X_test, n_samples=2000)
 
     np.savez(
diff --git a/benchmark/pipelines/lengthscale_estimation.py b/benchmark/pipelines/lengthscale_estimation.py
index bd4a416..718c348 100644
--- a/benchmark/pipelines/lengthscale_estimation.py
+++ b/benchmark/pipelines/lengthscale_estimation.py
@@ -31,7 +31,6 @@ def estimate_lengthscale(experiment: str, filenames: list, n: int = 1000):
     """Estimates the median heuristic with a subsample of the whole available data.
     It returns the median heuristics that should be used for computing KSD and SKSD.
     """
-
     bins = [0]
     repeated_filenames = []
     idx_map = []
diff --git a/benchmark/pipelines/maximum_mean_discrepancies.py b/benchmark/pipelines/maximum_mean_discrepancies.py
index e75a61f..5cdd8ba 100644
--- a/benchmark/pipelines/maximum_mean_discrepancies.py
+++ b/benchmark/pipelines/maximum_mean_discrepancies.py
@@ -127,15 +127,15 @@ def MaximumMeanDiscrepancy(X: jnp.ndarray, Y: jnp.ndarray):
         params_hmc += [data_hmc["positions"][chain][100:]]
     preds_hmc = np.concatenate(preds_hmc, axis=0).squeeze()
     params_hmc = np.concatenate(params_hmc, axis=0).squeeze()
-    
+
     if preds_hmc.ndim == 3:
         preds_hmc = preds_hmc[:, :, 0]
-    
+
     for i, (ikey, ival) in tqdm(enumerate(samples.items())):
         data_i = jnp.load(ival)
         x_params = jnp.array(data_i["positions"], dtype=jnp.float64)
         x_preds = jnp.array(data_i["predictions"].squeeze(), dtype=jnp.float64)
-    
+
         if x_preds.ndim == 3:
             x_preds = x_preds[:, :, 0]
 
diff --git a/benchmark/pipelines/metrics_prediction_intervals.py b/benchmark/pipelines/metrics_prediction_intervals.py
index 61eea1e..fed1c8e 100644
--- a/benchmark/pipelines/metrics_prediction_intervals.py
+++ b/benchmark/pipelines/metrics_prediction_intervals.py
@@ -3,7 +3,6 @@
 
 import jax.numpy as jnp
 import numpy as np
-import numpy as np
 import scipy.stats as stats
 from experiments import Experiment, load_experiment
 from tqdm import tqdm
@@ -31,12 +30,14 @@ def compute_interval_metrics(alphas, data, algorithm, y_test):
             inside_dim = (y_test - qhigh[:, None] < 0) * (y_test - qlow[:, None] > 0)
 
             inside.append(inside_dim)
-            
+
             widths_per_alpha.append(np.mean(qhigh - qlow))
     else:
         # get predictions
         f_predictions = data["predictions"]
-        y_predictions = f_predictions + noise_level * np.random.randn(*f_predictions.shape)
+        y_predictions = f_predictions + noise_level * np.random.randn(
+            *f_predictions.shape
+        )
 
         # compute coverage probabilities
         inside = []
@@ -71,12 +72,14 @@ def compute_interval_metrics_heteroscedastic(alphas, data, algorithm, y_test):
             inside_dim = (y_test - qhigh[:, None] < 0) * (y_test - qlow[:, None] > 0)
 
             inside.append(inside_dim)
-            
+
             widths_per_alpha.append(np.mean(qhigh - qlow))
     else:
         # get predictions
         f_predictions = data["predictions"]
-        y_predictions = f_predictions[:, :, 0] + noise_level * f_predictions[:, :, 1] * np.random.randn(*f_predictions[:, :, 0].shape)
+        y_predictions = f_predictions[:, :, 0] + noise_level * f_predictions[
+            :, :, 1
+        ] * np.random.randn(*f_predictions[:, :, 0].shape)
 
         # compute coverage probabilities
         inside = []
@@ -95,8 +98,7 @@ def compute_interval_metrics_heteroscedastic(alphas, data, algorithm, y_test):
 
 
 def compute_regression_metrics(y_true, y_pred):
-    """
-    Args:
+    """Args:
         y_true (np.ndarray)
         y_pred (np.ndarray)
 
@@ -114,8 +116,7 @@ def compute_regression_metrics(y_true, y_pred):
 
 
 def compute_coverage_probabilities(alphas: np.ndarray, inside_list: np.ndarray):
-    """
-    Args:
+    """Args:
         alphas (np.ndarray): Array of confidence levels
         inside_list (np.ndarray): Array of bools such that True means that the points
             falls into the interval. Shape: [#_datasets, #_alphas, #_xtest, #_ytest_per_xtest]
@@ -136,8 +137,7 @@ def compute_coverage_probabilities(alphas: np.ndarray, inside_list: np.ndarray):
 
 
 def post_process_data(experiment: Experiment, files: list, algorithm: str):
-    """
-    Post-processes benchmark data to compute coverage probabilities, RMSE, Q2 scores, and other metrics.
+    """Post-processes benchmark data to compute coverage probabilities, RMSE, Q2 scores, and other metrics.
 
     Parameters:
     - experiment (Experiment): An instance of the Experiment class with a method to load data.
@@ -152,7 +152,6 @@ def post_process_data(experiment: Experiment, files: list, algorithm: str):
     - widths (np.ndarray): Average width of confidence intervals for each dataset.
     - times (np.ndarray): Computational time for each dataset.
     """
-
     inside_list = []
     widths_list = []
     q2_scores = []
@@ -177,8 +176,10 @@ def post_process_data(experiment: Experiment, files: list, algorithm: str):
 
             # compute is_inside arrays
             alphas = np.linspace(0.05, 0.95, 19)
-            inside_per_alpha, widths_per_alpha = compute_interval_metrics_heteroscedastic(
-                alphas, data, algorithm, y_test
+            inside_per_alpha, widths_per_alpha = (
+                compute_interval_metrics_heteroscedastic(
+                    alphas, data, algorithm, y_test
+                )
             )
         else:
             # mean prediction
diff --git a/benchmark/pipelines/sgmcmc.py b/benchmark/pipelines/sgmcmc.py
index 82998dc..46d1cc4 100644
--- a/benchmark/pipelines/sgmcmc.py
+++ b/benchmark/pipelines/sgmcmc.py
@@ -5,13 +5,14 @@
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-import pbnn.mcmc as mcmc
 from absl import app, flags, logging
 from experiments import load_experiment
 from logprobs import logprior_fn
-from pbnn.utils.misc import thinning_fn
 from rich.progress import track
 
+import pbnn.mcmc as mcmc
+from pbnn.utils.misc import thinning_fn
+
 FLAGS = flags.FLAGS
 flags.DEFINE_string("workdir", default=".", help="Directory where data will be stored")
 flags.DEFINE_string("algorithm", default="sgld", help="name of the sgmcmc algorithm")
@@ -31,7 +32,9 @@
 flags.DEFINE_integer(
     "seed", default=0, help="Initial seed that will be split accross the functions"
 )
-flags.DEFINE_string("init_method", default="map", help="Initialization method for MCMC.")
+flags.DEFINE_string(
+    "init_method", default="map", help="Initialization method for MCMC."
+)
 
 
 def setup_directories(workdir, seed, experiment, algorithm, step_size):
@@ -47,7 +50,11 @@ def setup_directories(workdir, seed, experiment, algorithm, step_size):
     WORKDIR = Path(workdir)
 
     ROOT = (
-        WORKDIR / f"seed_{seed}" / experiment.name / algorithm.__name__ / f"lr_{step_size}"
+        WORKDIR
+        / f"seed_{seed}"
+        / experiment.name
+        / algorithm.__name__
+        / f"lr_{step_size}"
     )
 
     INIT_PARAMS_DIR = WORKDIR / f"seed_{seed}" / experiment.name / "init_params"
@@ -56,13 +63,11 @@
 
 
 def main(argv):
-    """
-    Main function for running SGMCMC algorithms and save results.
+    """Main function for running SGMCMC algorithms and saving results.
 
     Args:
         argv (list): Command line arguments.
     """
-
     logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
     logging.info("JAX local devices: %r", jax.local_devices())
 
@@ -96,13 +101,11 @@ def sgmcmc_fn(algorithm, X, y, init_positions, map_positions, rng_key):
 
     rng_key
         PRNGKey
 
-    Returns
+    Returns:
    -------
-
    Markov Chain of positions with burnin already removed
    """
-
    match algorithm.__name__:
        case "sgld":
            hparams = {"num_iterations": 100_00}
diff --git a/benchmark/pipelines/sgmcmc_single_dataset.py b/benchmark/pipelines/sgmcmc_single_dataset.py
index 36c2d3a..f116010 100644
--- a/benchmark/pipelines/sgmcmc_single_dataset.py
+++ b/benchmark/pipelines/sgmcmc_single_dataset.py
@@ -5,10 +5,11 @@
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-import pbnn.mcmc as mcmc
 from absl import app, flags, logging
 from experiments import load_experiment
 from logprobs import logprior_fn
+
+import pbnn.mcmc as mcmc
 from pbnn.utils.misc import thinning_fn
 
 FLAGS = flags.FLAGS
@@ -35,7 +36,9 @@
 flags.DEFINE_integer(
     "seed", default=0, help="Initial seed that will be split accross the functions"
 )
-flags.DEFINE_string("init_method", default="map", help="Initialization method for MCMC.")
+flags.DEFINE_string(
+    "init_method", default="map", help="Initialization method for MCMC."
+)
 
 
 def setup_directories(workdir, seed, experiment, algorithm, step_size):
@@ -51,7 +54,11 @@ def setup_directories(workdir, seed, experiment, algorithm, step_size):
     WORKDIR = Path(workdir)
 
     ROOT = (
-        WORKDIR / f"seed_{seed}" / experiment.name / algorithm.__name__ / f"lr_{step_size}"
+        WORKDIR
+        / f"seed_{seed}"
+        / experiment.name
+        / algorithm.__name__
+        / f"lr_{step_size}"
     )
 
     INIT_PARAMS_DIR = WORKDIR / f"seed_{seed}" / experiment.name / "init_params"
@@ -94,13 +101,11 @@ def sgmcmc_fn(algorithm, X, y, init_positions, map_positions, rng_key):
 
     rng_key
        PRNGKey
 
-    Returns
+    Returns:
    -------
-
    Markov Chain of positions with burnin already removed
    """
-
    match algorithm.__name__:
        case "sgld":
            hparams = {"num_iterations": 100_000}
diff --git a/benchmark/read_results.py b/benchmark/read_results.py
index b635011..acd37a8 100644
--- a/benchmark/read_results.py
+++ b/benchmark/read_results.py
@@ -7,8 +7,7 @@
 
 
 def load_all_results_with_hparam_names(base_path, seed, test_case, algorithm):
-    """
-    Automatically load all results for a given algorithm and seed, storing the data in a dict
+    """Automatically load all results for a given algorithm and seed, storing the data in a dict
     where keys are dicts containing hyperparameters names and values, and the values are the loaded data.
 
     Parameters:
@@ -61,8 +60,7 @@ def load_all_results_with_hparam_names(base_path, seed, test_case, algorithm):
 
 
 def aggregate_metrics_with_shapes(base_path, test_case, algorithm):
-    """
-    Aggregates metrics across seeds and handles shape mismatches by padding shorter arrays.
+    """Aggregates metrics across seeds and handles shape mismatches by padding shorter arrays.
 
     Parameters:
     - base_path: str, base directory containing results.
@@ -208,7 +206,9 @@ def plot_metrics(aggregated_metrics, algorithm_name):
         print(
             f"Hyperparameters: {dict(hparams)}"
         )  # Convert frozenset back to dict for readability
-        print(f"Metrics: {np.mean(data['q2_scores'])}") # Replace 'metrics' with the actual key in your npz file
+        print(
+            f"Metrics: {np.mean(data['q2_scores'])}"
+        )  # Replace 'metrics' with the actual key in your npz file
         # print(data["mcp"])
 
 # # Example usage
diff --git a/benchmark/run_conformal_algorithms.py b/benchmark/run_conformal_algorithms.py
index 78c4797..7b9bbac 100644
--- a/benchmark/run_conformal_algorithms.py
+++ b/benchmark/run_conformal_algorithms.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 from rich.console import Console
-
 from utils import launch
 
 sys.path.append("pipelines")
@@ -37,13 +36,13 @@
 for step_size in step_sizes:
     params["step_size"] = step_size
     launch(template_name="conformal", **params)
-    
+
 params = common_params.copy()
 params["algorithm"] = "cv_plus"
 for step_size in step_sizes:
     params["step_size"] = step_size
     launch(template_name="conformal", **params)
-    
+
 params = common_params.copy()
 params["algorithm"] = "split_cqr"
 for step_size in step_sizes:
diff --git a/doc/conf.py b/doc/conf.py
index c0d75e6..3b2d5cc 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -23,19 +23,19 @@
 release = "0.0.1"
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.todo',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.doctest',
-    'sphinx.ext.ifconfig',
-    'sphinx.ext.duration',
-    'sphinx.ext.extlinks',
-    'sphinx.ext.napoleon',
-    'myst_nb',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.todo",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.doctest",
+    "sphinx.ext.ifconfig",
+    "sphinx.ext.duration",
+    "sphinx.ext.extlinks",
+    "sphinx.ext.napoleon",
+    "myst_nb",
     # 'myst_parser',  # imported by myst_nb
     # 'sphinxcontrib.apidoc',  # autoapi is better
-    'sphinx.ext.autosummary',
+    "sphinx.ext.autosummary",
     # 'sphinxcontrib.bibtex'
 ]
 
@@ -56,10 +56,10 @@
 # graphviz_output_format = 'svg'
 # myst_parser options
 source_suffix = {
-    '.rst': 'restructuredtext',
-    '.ipynb': 'myst-nb',
-    '.myst': 'myst-nb',
-    '.md': 'myst-nb',
+    ".rst": "restructuredtext",
+    ".ipynb": "myst-nb",
+    ".myst": "myst-nb",
+    ".md": "myst-nb",
 }
 myst_enable_extensions = [
     "amsmath",
@@ -99,7 +99,7 @@
     "source_directory": "doc/",
 }
 
-github_url = 'https://gitlab.com/drti/pbnn'
+github_url = "https://gitlab.com/drti/pbnn"
 
 autodoc_mock_imports = ["jax"]
-autodoc_member_order = "bysource"
\ No newline at end of file
+autodoc_member_order = "bysource"
diff --git a/doc/sources/notebooks/flax_mapie.ipynb b/doc/sources/notebooks/flax_mapie.ipynb
index ced1410..a7710a0 100644
--- a/doc/sources/notebooks/flax_mapie.ipynb
+++ b/doc/sources/notebooks/flax_mapie.ipynb
@@ -33,7 +33,7 @@
     "from pbnn.utils.analytical_functions import gramacy_function\n",
     "from pbnn.utils.plot import plot_on_axis\n",
     "\n",
-    "warnings.filterwarnings('ignore')\n",
+    "warnings.filterwarnings(\"ignore\")\n",
     "\n",
     "%load_ext watermark"
    ]
diff --git a/doc/sources/notebooks/full_example.ipynb b/doc/sources/notebooks/full_example.ipynb
index f6c527e..7d5ad22 100644
--- a/doc/sources/notebooks/full_example.ipynb
+++ b/doc/sources/notebooks/full_example.ipynb
@@ -25,7 +25,6 @@
    "source": [
     "from time import time\n",
     "\n",
-    "import blackjax\n",
     "import flax.linen as nn\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
@@ -438,7 +437,7 @@
     "    positions, y_prediction = sgmcmc_fn(\n",
     "        algorithm, burnin, thin_freq, init_pos, key, **hparams\n",
     "    )\n",
-    "    print(f\"Elapsed time for {algorithm.__name__}: {time()-t0}\")\n",
+    "    print(f\"Elapsed time for {algorithm.__name__}: {time() - t0}\")\n",
     "    y_predictions[algorithm.__name__] = y_prediction"
    ]
   },
@@ -480,7 +479,7 @@
     "    num_integration_steps=40,\n",
     "    rng_key=key,\n",
     ")\n",
-    "print(f\"Elapsed time for HMC: {time()-t0}\")\n",
+    "print(f\"Elapsed time for HMC: {time() - t0}\")\n",
     "\n",
     "# predict\n",
     "f_predictions = predict_fn(network, positions, X_test).squeeze()\n",
diff --git a/doc/sources/notebooks/sgld_example.ipynb b/doc/sources/notebooks/sgld_example.ipynb
index d6f9f97..6ba9a72 100644
--- a/doc/sources/notebooks/sgld_example.ipynb
+++ b/doc/sources/notebooks/sgld_example.ipynb
@@ -13,7 +13,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import blackjax\n",
     "import flax.linen as nn\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
diff --git a/doc/sources/notebooks/thinning_example.ipynb b/doc/sources/notebooks/thinning_example.ipynb
index f86ac98..82c24bf 100644
--- a/doc/sources/notebooks/thinning_example.ipynb
+++ b/doc/sources/notebooks/thinning_example.ipynb
@@ -15,7 +15,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import blackjax\n",
     "import flax.linen as nn\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
@@ -28,8 +27,8 @@
     "\n",
     "from pbnn.mcmc.langevin import sgld\n",
     "from pbnn.utils.analytical_functions import gramacy_function\n",
-    "from pbnn.utils.plot import plot_on_axis\n",
     "from pbnn.utils.misc import thinning_fn\n",
+    "from pbnn.utils.plot import plot_on_axis\n",
     "\n",
     "%load_ext watermark"
    ]
@@ -238,7 +237,7 @@
     "qhigh = jnp.quantile(y_predictions, (1 - 0.5 * alpha), axis=0)\n",
     "\n",
     "ax = fig.add_subplot(gs[0])\n",
-    "plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title=f\"sgld\")"
+    "plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title=\"sgld\")"
    ]
   },
  {
diff --git a/src/pbnn/core/api.py b/src/pbnn/core/api.py
new file mode 100644
index 0000000..159e14b
--- /dev/null
+++ b/src/pbnn/core/api.py
@@ -0,0 +1,96 @@
+# src/pbnn/core/api.py
+"""Core API definitions for probabilistic models and inference methods."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, Protocol, Tuple, runtime_checkable
+
+from jax import Array
+
+__all__ = ["SupervisedBatch", "Posterior", "InferenceMethod"]
+
+
+@dataclass
+class SupervisedBatch:
+    """A batch of supervised data.
+
+    Attributes:
+        x (Array): Input features with leading batch dimension,
+            e.g., shape ``(B, d_in, ...)``.
+        y (Array): Targets compatible with ``x``, e.g., shape ``(B, d_out, ...)``.
+    """
+
+    x: Array
+    y: Array
+
+
+@runtime_checkable
+class Posterior(Protocol):
+    """Posterior distribution over model predictions or parameters."""
+
+    def predictive_mean_var(self, x: Array, **kwargs) -> Tuple[Array, Array]:
+        """Compute predictive mean and variance at inputs.
+
+        Args:
+            x (Array): Inputs with leading batch dimension.
+            **kwargs: Method-specific keyword arguments
+                (e.g., ``train=False``, RNGs, Flax `mutable`).
+
+        Returns:
+            Tuple[Array, Array]: A tuple ``(mean, var)`` where
+            - ``mean`` has the same shape as the model outputs for ``x``.
+            - ``var`` is the elementwise predictive variance
+              (zero for point estimates such as MAP).
+        """
+        ...
+
+    def predict(self, x: Array, **kwargs) -> Array:
+        """Deterministic prediction (typically the predictive mean).
+
+        Args:
+            x (Array): Inputs with leading batch dimension.
+            **kwargs: Method-specific keyword arguments.
+
+        Returns:
+            Array: Predicted mean values with the same shape as
+                ``predictive_mean_var(x)[0]``.
+        """
+        ...
+
+
+@runtime_checkable
+class InferenceMethod(Protocol):
+    """Interface for training probabilistic models.
+
+    Implementations consume a model and supervised data, and return
+    a `Posterior` object that exposes a method-agnostic predictive API.
+    """
+
+    def fit(
+        self,
+        model: Any,
+        train_ds: Dict[str, Array],
+        valid_ds: Dict[str, Array] | None = None,
+        **kwargs,
+    ) -> Posterior:
+        """Train on supervised data and return a posterior.
+
+        Args:
+            model (Any): Model or architecture to be trained
+                (e.g., a Flax `nn.Module`).
+            train_ds (Dict[str, Array]): Training dataset with keys:
+                - ``"x"``: inputs (Array with leading batch dimension).
+                - ``"y"``: targets (Array compatible with ``"x"``).
+            valid_ds (Dict[str, Array] | None, optional): Validation dataset
+                in the same format as ``train_ds``. Defaults to None.
+            **kwargs: Additional method-specific parameters (optimizers,
+                schedules, seeds, etc.).
+
+        Returns:
+            Posterior: Posterior object exposing
+                ``predict`` and ``predictive_mean_var``.
+
+        Notes:
+            Implementations should avoid side effects outside the returned
+            posterior object (e.g., no global state).
+        """
+        ...
diff --git a/src/pbnn/inference/map/__init__.py b/src/pbnn/inference/map/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pbnn/inference/map/method.py b/src/pbnn/inference/map/method.py
new file mode 100644
index 0000000..e678075
--- /dev/null
+++ b/src/pbnn/inference/map/method.py
@@ -0,0 +1,33 @@
+"""MAP method adapter to the unified API."""
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional
+
+from jax import Array
+
+from pbnn.core.api import InferenceMethod, Posterior
+
+from .optimizer import MAPConfig, fit_map
+
+
+@dataclass
+class MAP(InferenceMethod):
+    """Method adapter so MAP conforms to the unified API."""
+
+    logposterior_fn: Callable[[Dict[str, Any], Dict[str, Array]], Array]
+    cfg: MAPConfig
+
+    def fit(
+        self,
+        model,
+        train_ds: Dict[str, Array],
+        valid_ds: Optional[Dict[str, Array]] = None,  # noqa: ARG002
+        **kwargs,  # noqa: ARG002
+    ) -> Posterior:
+        """Fit MAP and return a Posterior object."""
+        return fit_map(
+            logposterior_fn=self.logposterior_fn,
+            network=model,
+            train_ds=train_ds,
+            cfg=self.cfg,
+        )
diff --git a/src/pbnn/inference/map/optimizer.py b/src/pbnn/inference/map/optimizer.py
new file mode 100644
index 0000000..2b135ee
--- /dev/null
+++ b/src/pbnn/inference/map/optimizer.py
@@ -0,0 +1,199 @@
+"""MAP estimation with JAX/Flax: fast, minimal, and API-friendly."""
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import optax
+from flax.training import train_state
+from jax import Array
+
+from pbnn.utils.data import NumpyDataset, NumpyLoader
+
+
+@dataclass
+class MAPConfig:
+    """Configuration for MAP training."""
+
+    learning_rate: float = 1e-3
+    optimizer: str = "adam"  # "adam" | "sgd"
+    batch_size: int = 128
+    num_epochs: int = 100
+    clip_grad_norm: Optional[float] = None
+    weight_decay: float = 0.0  # L2-style penalty added to the gradient if > 0 (see _make_tx)
+    seed: int = 0
+
+
+class PosteriorMAP:
+    """Posterior wrapper for MAP: point estimate + module apply."""
+
+    def __init__(self, params, apply_fn: Callable):
+        self.params = params
+        self._apply = apply_fn
+
+    def predict(self, x: Array, **apply_kwargs) -> Array:
+        """Deterministic prediction (mean)."""
+        return self._apply({"params": self.params}, x, **apply_kwargs)
+
+    def predictive_mean_var(self, x: Array, **apply_kwargs) -> Tuple[Array, Array]:
+        """(mean, variance) with zero predictive variance at the parameter level."""
+        mean = self.predict(x, **apply_kwargs)
+        var = jnp.zeros_like(mean)
+        return mean, var
+
+
+def _make_tx(cfg: MAPConfig) -> optax.GradientTransformation:
+    # LR schedule (constant for now; optax optimizers also accept a schedule here)
+    schedule = cfg.learning_rate
+
+    # Base optimizer
+    if cfg.optimizer.lower() == "adam":
+        base = optax.adam(schedule)
+    elif cfg.optimizer.lower() == "sgd":
+        base = optax.sgd(schedule, momentum=0.0, nesterov=False)
+    else:
+        raise ValueError(f"Unknown optimizer: {cfg.optimizer}")
+
+    # Optional weight decay (added before the base update, i.e. L2-style, not decoupled;
+    # swap in optax.adamw for true AdamW-style decay)
+    wd = (
+        optax.add_decayed_weights(cfg.weight_decay)
+        if cfg.weight_decay > 0
+        else optax.identity()
+    )
+
+    # Optional grad clipping
+    clip = (
+        optax.clip_by_global_norm(cfg.clip_grad_norm)
+        if cfg.clip_grad_norm
+        else optax.identity()
+    )
+
+    return optax.chain(clip, wd, base)
+
+
+def _create_train_state(
+    rng: Array,
+    flax_module: nn.Module,
+    init_input: Array,
+    tx: optax.GradientTransformation,
+) -> train_state.TrainState:
+    variables = flax_module.init(rng, init_input)
+    params = variables["params"]
+    return train_state.TrainState.create(
+        apply_fn=flax_module.apply, params=params, tx=tx
+    )
+
+
+def fit_map(
+    *,
+    logposterior_fn: Callable[[Dict[str, Any], Dict[str, Array]], Array],
+    network: nn.Module,
+    train_ds: Dict[str, Array],
+    cfg: MAPConfig,
+) -> PosteriorMAP:
+    """Estimate MAP params by maximizing a user-provided log-posterior.
+
+    Args:
+        logposterior_fn: (params, batch) -> scalar log-posterior (sum or mean over batch).
+            You define model likelihood/prior inside this function.
+        network: Flax module used only for initialization & apply.
+        train_ds: dict with "x" and "y" arrays.
+        cfg: training configuration.
+
+    Returns:
+        PosteriorMAP: wraps the point-estimate params + apply_fn.
+    """
+
+    # Negative log-posterior (we minimize)
+    def nlp(params, batch):
+        return -logposterior_fn(params, batch)
+
+    # Single step (jit-able)
+    nlp_and_grad = jax.value_and_grad(nlp)
+
+    def train_step(state, batch):
+        loss, grads = nlp_and_grad(state.params, batch)
+        state = state.apply_gradients(grads=grads)
+        return state, loss
+
+    @jax.jit
+    def run_epoch(state, batches):
+        def step_fn(carry, batch):
+            state = carry
+            state, loss = train_step(state, batch)
+            return state, loss
+
+        state, losses = jax.lax.scan(step_fn, state, batches)
+        # return the final state and the mean loss over the scanned steps
+        return state, jnp.mean(losses)
+
+    # Init state
+    rng = jax.random.PRNGKey(cfg.seed)
+    rng, init_rng = jax.random.split(rng)
+    tx = _make_tx(cfg)
+    state = _create_train_state(init_rng, network, train_ds["x"][: cfg.batch_size], tx)
+
+    # Pre-batch to enable scan
+    dataset = NumpyDataset(train_ds["x"], train_ds["y"])
+    loader = NumpyLoader(
+        dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True
+    )
+    batches = list(loader)
+    if len(batches) == 0:
+        raise ValueError("Empty loader: check batch_size vs dataset size.")
+    # From a list of dicts to dict of stacked arrays with leading 'num_batches'
+    batches = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *batches)
+
+    # Train
+    for _ in range(cfg.num_epochs):
+        state, _ = run_epoch(state, batches)
+
+    return PosteriorMAP(state.params, apply_fn=network.apply)
+
+
+# Backwards-compatible surface
+
+
+def create_train_state(
+    rng: Array,
+    flax_module: nn.Module,
+    init_input: Array,
+    learning_rate: float,
+    optimizer: Optional[str] = "adam",
+) -> train_state.TrainState:
+    """Kept for backward-compat; prefer _create_train_state + MAPConfig."""
+    tx = _make_tx(MAPConfig(learning_rate=learning_rate, optimizer=optimizer or "adam"))
+    return _create_train_state(rng, flax_module, init_input, tx)
+
+
+def train_fn(
+    logposterior_fn: Callable,
+    network: nn.Module,
+    train_ds: Dict[str, Array],
+    batch_size: int,
+    num_epochs: int,
+    learning_rate: float,
+    rng_key: Array,  # kept for signature parity; used only to derive cfg.seed
+    optimizer: str = "adam",
+):
+    """Backward-compatible wrapper: returns MAP params.
+
+    Prefer `fit_map(logposterior_fn=..., network=..., train_ds=..., cfg=MAPConfig(...))`.
+    """
+    posterior = fit_map(
+        logposterior_fn=logposterior_fn,
+        network=network,
+        train_ds=train_ds,
+        cfg=MAPConfig(
+            learning_rate=learning_rate,
+            optimizer=optimizer,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            seed=int(
+                jax.random.randint(rng_key, (), 0, 2**31 - 1)
+            ),  # keep rng influence
+        ),
+    )
+    return posterior.params
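
Reviewer note (not part of the patch): the new `pbnn.core.api` protocols and the `pbnn.inference.map` adapter are easiest to sanity-check end to end, so a minimal usage sketch follows. It is illustrative only: `MLP`, `logposterior_fn`, and the synthetic data are invented for the example, and it assumes the batches built inside `fit_map` keep the {"x": ..., "y": ...} layout declared in the api.py type hints; if `NumpyLoader` yields `(x, y)` tuples instead, unpack accordingly.

import flax.linen as nn
import jax
import jax.numpy as jnp

from pbnn.inference.map.method import MAP
from pbnn.inference.map.optimizer import MAPConfig


class MLP(nn.Module):
    """Tiny regression network (illustrative only)."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)


network = MLP()


def logposterior_fn(params, batch):
    """Gaussian likelihood (unit noise) plus a standard normal prior on weights."""
    # Assumes dict-style batches, matching the type hints in api.py.
    preds = network.apply({"params": params}, batch["x"])
    loglik = -0.5 * jnp.sum((preds - batch["y"]) ** 2)
    logprior = -0.5 * sum(
        jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(params)
    )
    return loglik + logprior


# Synthetic data with a leading batch dimension, as SupervisedBatch documents.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (256, 4))
y = jnp.sum(X, axis=1, keepdims=True)

method = MAP(logposterior_fn=logposterior_fn, cfg=MAPConfig(num_epochs=10))
posterior = method.fit(network, train_ds={"x": X, "y": y})

mean, var = posterior.predictive_mean_var(X)  # var is all zeros for MAP
print(mean.shape, float(var.max()))

Because `MAP` implements the `InferenceMethod` protocol and returns a `Posterior`, swapping in a different inference method should leave this call site unchanged, which appears to be the intent of the unified API.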