From bee9baff64c37926e013f1731afcdcb1a4be0a47 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 20:59:21 -0400 Subject: [PATCH 01/13] Add CmdStanMCMC.create_inits() --- cmdstanpy/stanfit/mcmc.py | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index f96ff023..991028ae 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -105,6 +105,46 @@ def __init__( if not self._is_fixed_param: self._check_sampler_diagnostics() + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model by + randomly selecting draws from the MCMC samples. If the samples + contain draws from multiple chains, each draw will be from + a different chain, if possible. Otherwise the chain is randomly + selected. + + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. 
+ """ + self._assemble_draws() + rng = np.random.default_rng(seed) + n_draws, n_chains = self._draws.shape[:2] + draw_idxs = rng.choice(n_draws, size=chains, replace=False) + chain_idxs = rng.choice( + n_chains, size=chains, replace=(n_chains <= chains) + ) + if chains == 1: + draw = self._draws[draw_idxs[0], chain_idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self._draws[d, i]) + for name, var in self._metadata.stan_vars.items() + } + for d, i in zip(draw_idxs, chain_idxs) + ] + def __repr__(self) -> str: repr = 'CmdStanMCMC: model={} chains={}{}'.format( self.runset.model, From e0b35fb5cc2dc520b03d9f28e0063ba67b5423f4 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 21:00:31 -0400 Subject: [PATCH 02/13] Fix inconsistent string quoting --- cmdstanpy/stanfit/mcmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 991028ae..c467c796 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -725,7 +725,7 @@ def draws_xr( ) if inc_warmup and not self._save_warmup: get_logger().warning( - "Draws from warmup iterations not available," + 'Draws from warmup iterations not available,' ' must run sampler with "save_warmup=True".' 
) if vars is None: From 617b5cba954e38d171db37793c9a10bbbd25ab30 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 21:23:13 -0400 Subject: [PATCH 03/13] Add CmdStanMCMC.create_inits() test --- test/test_sample.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/test_sample.py b/test/test_sample.py index 3c987ddc..ad548d66 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2083,3 +2083,33 @@ def test_serialization(stanfile='bernoulli.stan'): assert set(variables1) == set(variables2) for key, value1 in variables1.items(): np.testing.assert_array_equal(value1, variables2[key]) + + +def test_mcmc_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + mcmc = bern_model.sample(data=jdata) + + inits = mcmc.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = mcmc.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = mcmc.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = mcmc.create_inits(seed=1234) + seeded2 = mcmc.create_inits(seed=1234) + assert all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) From 409b13fef683a64e0d6b35394c54cecf55a3ec17 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 21:49:50 -0400 Subject: [PATCH 04/13] Add CmdStanLaplace.create_inits() --- cmdstanpy/stanfit/laplace.py | 35 ++++++++++++++++++++++++++++++++++- test/test_laplace.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/cmdstanpy/stanfit/laplace.py b/cmdstanpy/stanfit/laplace.py index 741593e7..99ce0d53 100644 --- a/cmdstanpy/stanfit/laplace.py +++ b/cmdstanpy/stanfit/laplace.py @@ 
-1,5 +1,5 @@ """ - Container for the result of running a laplace approximation. +Container for the result of running a laplace approximation. """ from typing import ( @@ -52,6 +52,39 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None: config = scan_generic_csv(runset.csv_files[0]) self._metadata = InferenceMetadata(config) + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model + by randomly selecting draws from the Laplace approximation. + + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. + """ + self._assemble_draws() + rng = np.random.default_rng(seed) + idxs = rng.choice(self._draws.shape[0], size=chains, replace=False) + if chains == 1: + draw = self._draws[idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self._draws[idx]) + for name, var in self._metadata.stan_vars.items() + } + for idx in idxs + ] + def _assemble_draws(self) -> None: if self._draws.shape != (0,): return diff --git a/test/test_laplace.py b/test/test_laplace.py index 14f00eb6..35e5d170 100644 --- a/test/test_laplace.py +++ b/test/test_laplace.py @@ -105,3 +105,33 @@ def test_laplace_outputs(): assert 'x' in fit_pd.columns assert 'y' in fit_pd.columns assert fit_pd['x'].shape == (123,) + + +def test_laplace_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + laplace = 
bern_model.laplace_sample(data=jdata) + + inits = laplace.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = laplace.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = laplace.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = laplace.create_inits(seed=1234) + seeded2 = laplace.create_inits(seed=1234) + assert all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) From 62f46d2adc1f3581d1d4f904f96aeebd401ac749 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 22:16:21 -0400 Subject: [PATCH 05/13] Add CmdStanMLE.create_inits() --- cmdstanpy/stanfit/mle.py | 26 +++++++++++++++++++++++++- test/test_optimize.py | 23 +++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index 3a50ba10..bcd003ba 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -1,7 +1,7 @@ """Container for the result of running optimization""" from collections import OrderedDict -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -36,6 +36,30 @@ def __init__(self, runset: RunSet) -> None: self._save_iterations: bool = optimize_args.save_iterations self._set_mle_attrs(runset.csv_files[0]) + def create_inits( + self, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model + from the MLE. + + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. 
+ + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. + """ + mle_inits = { + name: var.extract_reshape(self.optimized_params_np) + for name, var in self._metadata.stan_vars.items() + } + + if chains == 1: + return mle_inits + else: + return [mle_inits for _ in range(chains)] + def __repr__(self) -> str: repr = 'CmdStanMLE: model={}{}'.format( self.runset.model, self.runset._args.method_args.compose(0, cmd=[]) diff --git a/test/test_optimize.py b/test/test_optimize.py index d5024ff5..7b408d95 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -671,3 +671,26 @@ def test_serialization() -> None: np.testing.assert_array_equal( mle1.optimized_params_np, mle2.optimized_params_np ) + + +def test_optimize_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + mle = bern_model.optimize(data=jdata) + + inits = mle.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = mle.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = mle.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 From 40c96e3568ed9ef98e15eeddb5cb79dcc0d66442 Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 22:25:00 -0400 Subject: [PATCH 06/13] Add CmdStanVB.create_inits() --- cmdstanpy/stanfit/vb.py | 37 ++++++++++++++++++++++++++++++++++++- test/test_variational.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/cmdstanpy/stanfit/vb.py b/cmdstanpy/stanfit/vb.py index 102f292c..ed7b5c92 100644 --- a/cmdstanpy/stanfit/vb.py +++ b/cmdstanpy/stanfit/vb.py @@ -1,7 +1,7 @@ 
"""Container for the results of running autodiff variational inference""" from collections import OrderedDict -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -30,6 +30,41 @@ def __init__(self, runset: RunSet) -> None: self.runset = runset self._set_variational_attrs(runset.csv_files[0]) + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model + by randomly selecting draws from the variational approximation + draws. + + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument. of :meth:`CmdStanModel.sample`. 
+ """ + rng = np.random.default_rng(seed) + idxs = rng.choice( + self.variational_sample.shape[0], size=chains, replace=False + ) + if chains == 1: + draw = self.variational_sample[idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self.variational_sample[idx]) + for name, var in self._metadata.stan_vars.items() + } + for idx in idxs + ] + def __repr__(self) -> str: repr = 'CmdStanVB: model={}{}'.format( self.runset.model, self.runset._args.method_args.compose(0, cmd=[]) diff --git a/test/test_variational.py b/test/test_variational.py index 0b859e2f..c0e39cf7 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -348,3 +348,33 @@ def test_serialization() -> None: variational1.variational_params_dict == variational2.variational_params_dict ) + + +def test_variational_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + vb = bern_model.variational(data=jdata) + + inits = vb.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = vb.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = vb.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = vb.create_inits(seed=1234) + seeded2 = vb.create_inits(seed=1234) + assert all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) From ab73751dab4a9a176400b9f2363bda15bd9141bd Mon Sep 17 00:00:00 2001 From: amas Date: Tue, 6 May 2025 22:39:55 -0400 Subject: [PATCH 07/13] Add sampling tests from *.create_inits() for MLE, VB, Laplace, and MCMC --- test/test_laplace.py | 14 ++++++++++++++ test/test_optimize.py | 14 ++++++++++++++ 
test/test_sample.py | 14 ++++++++++++++ test/test_variational.py | 14 ++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/test/test_laplace.py b/test/test_laplace.py index 35e5d170..14729407 100644 --- a/test/test_laplace.py +++ b/test/test_laplace.py @@ -135,3 +135,17 @@ def test_laplace_create_inits(): init1['theta'] == init2['theta'] for init1, init2 in zip(seeded, seeded2) ) + + +def test_laplace_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = cmdstanpy.CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + laplace = logistic_model.laplace_sample(data=logistic_data) + inits = laplace.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_optimize.py b/test/test_optimize.py index 7b408d95..341db684 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -694,3 +694,17 @@ def test_optimize_create_inits(): assert isinstance(inits_1, dict) assert 'theta' in inits_1 assert len(inits_1) == 1 + + +def test_optimize_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + mle = logistic_model.optimize(data=logistic_data) + inits = mle.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_sample.py b/test/test_sample.py index ad548d66..8ec155d8 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2113,3 +2113,17 @@ def test_mcmc_create_inits(): init1['theta'] == init2['theta'] for init1, init2 in zip(seeded, seeded2) ) + + +def test_mcmc_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = cmdstanpy.CmdStanModel(stan_file=stan) + logistic_data = 
os.path.join(DATAFILES_PATH, 'logistic.data.R') + + initial_mcmc = logistic_model.sample(data=logistic_data) + inits = initial_mcmc.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_variational.py b/test/test_variational.py index c0e39cf7..f16a1e63 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -378,3 +378,17 @@ def test_variational_create_inits(): init1['theta'] == init2['theta'] for init1, init2 in zip(seeded, seeded2) ) + + +def test_variational_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + vb = logistic_model.sample(data=logistic_data) + inits = vb.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) From 162ed0b4ec2ee849c9b57d7508da6e87a286a1c9 Mon Sep 17 00:00:00 2001 From: amas Date: Wed, 7 May 2025 21:51:12 -0400 Subject: [PATCH 08/13] Update sampler init user's guide example --- .../examples/VI as Sampler Inits.ipynb | 1478 ++++++++++++++++- 1 file changed, 1405 insertions(+), 73 deletions(-) diff --git a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb index 03f5c527..914b0c78 100644 --- a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb +++ b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb @@ -4,13 +4,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Using Variational Estimates to Initialize the NUTS-HMC Sampler\n", + "## Using Estimates from Variational, Laplace, or Optimization Methods to Initialize the NUTS-HMC Sampler\n", "\n", - "In this example we show how to use the parameter estimates return by Stan's variational inference algorithms\n", - "[pathfinder 
](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) and \n", - "[ADVI ](https://mc-stan.org/docs/cmdstan-guide/variational_config.html) \n", - "as the initial parameter values for Stan's NUTS-HMC sampler.\n", - "By default, the sampler algorithm randomly initializes all model parameters in the range uniform\\[-2, 2\\]. When the true parameter value is outside of this range, starting from the estimates from Pathfinder and ADVI will speed up and improve adaptation.\n", + "In this example, we show how to use parameter estimates returned by Stan's various posterior approximation or optimization algorithms as initial values for Stan's NUTS-HMC sampler. These include:\n", + "\n", + "* [Pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) \n", + "* [ADVI ](https://mc-stan.org/docs/cmdstan-guide/variational_config.html) \n", + "* [Laplace](https://mc-stan.org/docs/cmdstan-guide/laplace_sample_config.html)\n", + "* [Optimization](https://mc-stan.org/docs/cmdstan-guide/optimize_config.html)\n", + "\n", + "By default, the NUTS-HMC sampler randomly initializes all model parameters uniformly in the interval $(-2, 2)$. If this interval is far from the typical set of the posterior, initializing sampling from these approximation algorithms can speed up and improve adaptation.\n", "\n", "### Model and data\n", "\n", @@ -20,7 +23,7 @@ "a Bayesian standard linear regression model with noninformative priors,\n", "and its corresponding simulated dataset [sblri.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/data/data/sblri.json.zip),\n", "which was simulated via script [sblr.R](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/data/data-raw/sblr/sblr.R).\n", - "For conveince, we have copied the posteriordb model and data to this directory, in files `blr.stan` and `sblri.json`." + "For convenience, this example assumes the posteriordb model and data are local, in files `blr.stan` and `sblri.json`." 
] }, { @@ -70,22 +73,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Run Stan's `pathfinder` or `variational` algorithm, obtain fitted estimates\n", + "### Demonstration with Stan's `pathfinder` method\n", "\n", - "The [CmdStanModel pathfinder](https://mc-stan.org/cmdstanpy/api.html#cmdstanpy.CmdStanModel.pathfinder ) method\n", - "wraps the CmdStan [pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) method.\n", + "The approximation methods all follow the same general pattern of usage. First, we call the \n", + "corresponding method on the `CmdStanModel` object. From the resulting fit, we call the `.create_inits()` \n", + "method to construct a set of per-chain initializations for the model parameters. To make it explicit, \n", + "we will walk through the process using the `pathfinder` method (which wraps the \n", + "CmdStan [pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) method).\n", "\n", "Pathfinder locates normal approximations to the target\n", "density along a quasi-Newton optimization path, with local covariance\n", "estimated using the negative inverse Hessian estimates produced by the\n", - "LBFGS optimizer. Pathfinder returns draws from the Gaussian approximation\n", + "LBFGS optimizer. 
Pathfinder returns draws from the Gaussian approximation\n", "with the lowest estimated Kullback-Leibler (KL) divergence to the true\n", "posterior.\n", - "By default, CmdStanPy runs multi-path Pathfinder which returns an importance-resampled set of draws over the outputs of 4 independent single-path Pathfinders.\n", + "By default, CmdStanPy runs multi-path Pathfinder which returns an importance-resampled \n", + "set of draws over the outputs of 4 independent single-path Pathfinders.\n", "This better matches non-normal target densities and also mitigates\n", "the problem of L-BFGS getting stuck at local optima or in saddle points on plateaus.\n", "\n", - "The method [create_inits](https://mc-stan.org/cmdstanpy/api.html#cmdstanpy.CmdStanPathfinder.create_inits) returns a Python Dict containing a set of per-chain initializations for the model parameters. Each set of initializations is a random draw from the Pathfinder sample." + "We obtain Pathfinder estimates by calling the `.pathfinder()` method which returns a `CmdStanPathfinder` object:" ] }, { @@ -97,8 +104,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "17:01:02 - cmdstanpy - INFO - Chain [1] start processing\n", - "17:01:02 - cmdstanpy - INFO - Chain [1] done processing\n" + "21:47:10 - cmdstanpy - INFO - Chain [1] start processing\n", + "21:47:10 - cmdstanpy - INFO - Chain [1] done processing\n" ] } ], @@ -110,12 +117,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Posteriordb provides reference posteriors for all models. For the blr model, conditioned on the dataset `sblri.json`, the reference posteriors are in file [sblri-blr.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/reference_posteriors/summary_statistics/mean/mean/sblri-blr.json)\n", + "Posteriordb provides reference posteriors for all models. 
For the blr model, conditioned on the dataset `sblri.json`, the reference posteriors can be found in the [sblri-blr.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/reference_posteriors/summary_statistics/mean/mean/sblri-blr.json) file.\n", "\n", "The reference posteriors for all elements of `beta` and `sigma` are all very close to $1.0$.\n", "\n", - "The experiments reported in Figure 3 of the paper [Pathfinder: Parallel quasi-Newton variational inference](https://arxiv.org/abs/2108.03782) by Zhang et al. show that Pathfinder provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the warmup Phase I algorithm used by the NUTS-HMC sampler.\n", - "furthermore, Pathfinder is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the estimates from Pathfinder to initialize the parameter values for the NUTS-HMC sampler will allow the sampler to do a better job of adapting the stepsize and metric during warmup, resulting in better performance and estimation." + "The experiments reported in Figure 3 of the paper [Pathfinder: Parallel quasi-Newton variational inference](https://arxiv.org/abs/2108.03782) by Zhang et al. show that Pathfinder provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the warmup Phase I algorithm used by the NUTS-HMC sampler.\n", + "Furthermore, Pathfinder is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the Pathfinder estimates to initialize the parameter values for the NUTS-HMC sampler can allow the sampler to do a better job of adapting the stepsize and metric during warmup, resulting in better performance and estimation.\n", + "\n", + "We construct the parameter inits for full MCMC sampling below. 
The `.create_inits()` default behavior is to create inits for four chains to correspond with the sampling defaults. You can request more or fewer by modifying the `chains` keyword argument." ] }, { @@ -127,13 +136,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "[{'beta': array([0.996649, 0.999455, 1.00093 , 0.99873 , 1.00207 ]), 'sigma': array(0.934232)}, {'beta': array([1.00016 , 0.998764, 1.00055 , 1.00212 , 1.00047 ]), 'sigma': array(1.04441)}, {'beta': array([1.00139 , 0.997917, 1.00134 , 1.00123 , 1.00116 ]), 'sigma': array(0.946814)}, {'beta': array([0.999491, 0.999225, 1.00114 , 0.999147, 0.998943]), 'sigma': array(0.977812)}]\n" + "{'beta': array([1.00019 , 0.999721, 0.999143, 1.00221 , 1.00193 ]), 'sigma': array(0.955428)}\n", + "{'beta': array([0.998311, 1.00282 , 1.00017 , 1.00119 , 1.00148 ]), 'sigma': array(0.829495)}\n", + "{'beta': array([1.0007 , 1.00177 , 0.999522, 1.00289 , 0.999926]), 'sigma': array(0.904491)}\n", + "{'beta': array([0.998958, 1.00013 , 1.00095 , 0.999549, 1.00184 ]), 'sigma': array(0.895088)}\n" ] } ], "source": [ "pathfinder_inits = pathfinder_fit.create_inits()\n", - "print(pathfinder_inits)" + "for chain_init in pathfinder_inits:\n", + " print(chain_init)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the Pathfinder inits are close to the reference posteriors for the parameters. 
To use these inits, we pass the `pathfinder_inits` object to the `inits` kwarg:" ] }, { @@ -145,18 +165,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "17:01:05 - cmdstanpy - INFO - CmdStan start processing\n" + "21:47:10 - cmdstanpy - INFO - CmdStan start processing\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d8a75128e05e4cf88f037897a38d0173", + "model_id": "ab1a8782ca28489ba4c18fa4483fb85a", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "chain 1 | | 00:00 Status" + "chain 1: 0%| | 0/1075 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-156.9230000.0697911.822880-160.397000-156.590000-154.59900682.212003832.650001.009740
beta[1]0.9994750.0000140.0009790.9978220.9994711.001114972.6900027936.500000.999750
beta[2]1.0002400.0000180.0011690.9983251.0002201.002164377.8800024594.900000.999390
beta[3]1.0004300.0000140.0009570.9988181.0004401.001984804.9000026993.800000.999950
beta[4]1.0011500.0000160.0010620.9994481.0011501.002914289.5500024098.600000.999605
beta[5]1.0015800.0000150.0010380.9998961.0015701.003284676.1400026270.400001.000200
sigma0.9621700.0042100.0713300.8542700.9591301.08159286.756631610.992311.010100
\n", + "" + ], + "text/plain": [ + " Mean MCSE StdDev 5% 50% 95% \\\n", + "lp__ -156.923000 0.069791 1.822880 -160.397000 -156.590000 -154.59900 \n", + "beta[1] 0.999475 0.000014 0.000979 0.997822 0.999471 1.00111 \n", + "beta[2] 1.000240 0.000018 0.001169 0.998325 1.000220 1.00216 \n", + "beta[3] 1.000430 0.000014 0.000957 0.998818 1.000440 1.00198 \n", + "beta[4] 1.001150 0.000016 0.001062 0.999448 1.001150 1.00291 \n", + "beta[5] 1.001580 0.000015 0.001038 0.999896 1.001570 1.00328 \n", + "sigma 0.962170 0.004210 0.071330 0.854270 0.959130 1.08159 \n", + "\n", + " N_Eff N_Eff/s R_hat \n", + "lp__ 682.21200 3832.65000 1.009740 \n", + "beta[1] 4972.69000 27936.50000 0.999750 \n", + "beta[2] 4377.88000 24594.90000 0.999390 \n", + "beta[3] 4804.90000 26993.80000 0.999950 \n", + "beta[4] 4289.55000 24098.60000 0.999605 \n", + "beta[5] 4676.14000 26270.40000 1.000200 \n", + "sigma 286.75663 1610.99231 1.010100 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "mcmc_pathfinder_inits_fit.summary()" ] @@ -250,32 +451,301 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using the default random parameter initializations, we need to run more warmup iterations. If we only run 75 warmup iterations with random inits, the result fails to estimate `sigma` correctly. It is necessary to run the model with at least 150 warmup iterations to produce a good set of estimates." + "If we were to instead use the default random parameter initializations, we would need to run more warmup iterations to produce useful samples. 
For example, if we only run the same 75 warmup iterations with random inits, the result fails to estimate `sigma` correctly: " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "21:47:11 - cmdstanpy - INFO - CmdStan start processing\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ab7377170b7149feb0fbf4ff025920d9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "chain 1: 0%| | 0/1075 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-191.92800024.77020035.245800-232.300000-165.541000-155.461002.0246735.5206012.75920
beta[1]0.9998590.0001480.0020860.9966080.9999241.00357198.160003476.480001.03553
beta[2]0.9999490.0002660.0026750.9960960.9995251.00457101.400001778.950001.04521
beta[3]1.0004800.0001470.0021870.9969661.0004101.00452222.452003902.660001.00934
beta[4]1.0014100.0001770.0026590.9964441.0016601.00611225.987003964.680001.03112
beta[5]1.0016800.0001920.0025250.9977401.0012901.00655173.717003047.670001.03154
sigma1.9789200.7140101.0188800.9174902.7005703.173462.0362735.7240310.27506
\n", + "" + ], + "text/plain": [ + " Mean MCSE StdDev 5% 50% 95% \\\n", + "lp__ -191.928000 24.770200 35.245800 -232.300000 -165.541000 -155.46100 \n", + "beta[1] 0.999859 0.000148 0.002086 0.996608 0.999924 1.00357 \n", + "beta[2] 0.999949 0.000266 0.002675 0.996096 0.999525 1.00457 \n", + "beta[3] 1.000480 0.000147 0.002187 0.996966 1.000410 1.00452 \n", + "beta[4] 1.001410 0.000177 0.002659 0.996444 1.001660 1.00611 \n", + "beta[5] 1.001680 0.000192 0.002525 0.997740 1.001290 1.00655 \n", + "sigma 1.978920 0.714010 1.018880 0.917490 2.700570 3.17346 \n", + "\n", + " N_Eff N_Eff/s R_hat \n", + "lp__ 2.02467 35.52060 12.75920 \n", + "beta[1] 198.16000 3476.48000 1.03553 \n", + "beta[2] 101.40000 1778.95000 1.04521 \n", + "beta[3] 222.45200 3902.66000 1.00934 \n", + "beta[4] 225.98700 3964.68000 1.03112 \n", + "beta[5] 173.71700 3047.67000 1.03154 \n", + "sigma 2.03627 35.72403 10.27506 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "mcmc_random_inits_fit.summary()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing csv files: /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_1.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_2.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_3.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_4.csv\n", + "\n", + "Checking sampler transitions treedepth.\n", + "Treedepth satisfactory for all transitions.\n", + "\n", + "Checking sampler transitions for divergences.\n", + "597 of 4000 (14.93%) transitions ended with a divergence.\n", + "These divergent transitions indicate that HMC is not fully able to explore the posterior distribution.\n", + "Try increasing adapt delta closer to 1.\n", + "If this doesn't remove all divergences, try to reparameterize the model.\n", + "\n", + "Checking E-BFMI - sampler transitions 
HMC potential energy.\n", + "The E-BFMI, 0.01, is below the nominal threshold of 0.30 which suggests that HMC may have trouble exploring the target distribution.\n", + "If possible, try to reparameterize the model.\n", + "\n", + "The following parameters had fewer than 0.001 effective draws per transition:\n", + " sigma\n", + "Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted.\n", + "\n", + "The following parameters had split R-hat greater than 1.05:\n", + " sigma\n", + "Such high values indicate incomplete mixing and biased estimation.\n", + "You should consider regularizating your model with additional prior information or a more effective parameterization.\n", + "\n", + "Processing complete.\n", + "\n" + ] + } + ], "source": [ "print(mcmc_random_inits_fit.diagnose())" ] @@ -284,74 +754,936 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `CmdStanModel` method `variational` runs CmdStan's ADVI algorithm.\n", - "Because this algorithm is unstable and may fail to converge, we run it with argument `require_converged` set to `False`. We also specify a seed, to avoid instabilities as well as for reproducibility." + "The diagnostics clearly indicate problems with estimating `sigma`. In this case, it is necessary to run the model with at least 150 warmup iterations to produce a good set of estimates when starting from the default initialization." 
] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "vb_fit = model.variational(data=data_file, require_converged=False, seed=123)" + "### Other approximation algorithms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The ADVI algorithm provides estimates of all model parameters.\n", - "\n", - "The `variational` method returns a `CmdStanVB` object, with method `stan_variables`, which\n", - "returns the approximat posterior samples of all model parameters as a Python dictionary. \n", - "Here, we report the approximate posterior mean." + "We can follow the same pattern with Stan's ADVI algorithm by first using the `CmdStanModel.variational` method. Because this algorithm is unstable and may fail to converge, we run it with argument `require_converged` set to `False`. We also specify a seed, to avoid instabilities as well as for reproducibility." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "21:47:12 - cmdstanpy - INFO - Chain [1] start processing\n", + "21:47:12 - cmdstanpy - INFO - Chain [1] done processing\n", + "21:47:12 - cmdstanpy - WARNING - The algorithm may not have converged.\n", + "Proceeding because require_converged is set to False\n" + ] + } + ], "source": [ - "vb_mean = {var: samples.mean(axis=0) for var, samples in vb_fit.stan_variables(mean=False).items()}\n", - "print(vb_mean)" + "vb_fit = model.variational(data=data_file, require_converged=False, seed=123)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "mcmc_vb_inits_fit = model.sample(\n", - " data=data_file, inits=vb_mean, iter_warmup=75, seed=12345\n", - ")" + "The ADVI algorithm provides estimates of all model parameters.\n", + "\n", + "The `variational` method returns a 
`CmdStanVB` object, which similarly can construct a set of inits with the `.create_inits()` method:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'beta': array([0.997148, 0.992516, 0.991829, 0.991095, 1.01057 ]), 'sigma': array(1.84433)}\n", + "{'beta': array([0.996249, 0.990954, 0.992313, 0.993369, 1.01155 ]), 'sigma': array(1.92087)}\n", + "{'beta': array([0.997361, 0.992357, 0.989631, 0.995749, 1.0083 ]), 'sigma': array(1.49741)}\n", + "{'beta': array([0.995738, 0.994643, 0.993908, 0.993482, 1.00921 ]), 'sigma': array(1.60191)}\n" + ] + } + ], "source": [ - "mcmc_vb_inits_fit.summary()" + "vb_inits = vb_fit.create_inits()\n", + "for chain_init in vb_inits:\n", + " print(chain_init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The sampler estimates match the reference posterior." + "Which can be passed to the `inits` keyword argument of the sample method." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "21:47:13 - cmdstanpy - INFO - CmdStan start processing\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1a23f210d88248d7b86353a78ecd3281", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "chain 1: 0%| | 0/1075 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-156.9260000.0622531.775690-160.352000-156.589000-154.65400813.615004596.700001.002930
beta[1]0.9994890.0000140.0009550.9979170.9994941.001064623.0500026118.900000.999487
beta[2]1.0002400.0000180.0011480.9983631.0002601.002144167.4000023544.600001.000340
beta[3]1.0004200.0000140.0009410.9988771.0004201.001964680.4600026443.300000.999984
beta[4]1.0011500.0000170.0010820.9993941.0011301.002984017.7500022699.200000.999891
beta[5]1.0015700.0000160.0010720.9998361.0015901.003324777.0600026989.100001.000730
sigma0.9627400.0047200.0724700.8493000.9599501.08129235.938761332.987341.012130
\n", + "" + ], + "text/plain": [ + " Mean MCSE StdDev 5% 50% 95% \\\n", + "lp__ -156.926000 0.062253 1.775690 -160.352000 -156.589000 -154.65400 \n", + "beta[1] 0.999489 0.000014 0.000955 0.997917 0.999494 1.00106 \n", + "beta[2] 1.000240 0.000018 0.001148 0.998363 1.000260 1.00214 \n", + "beta[3] 1.000420 0.000014 0.000941 0.998877 1.000420 1.00196 \n", + "beta[4] 1.001150 0.000017 0.001082 0.999394 1.001130 1.00298 \n", + "beta[5] 1.001570 0.000016 0.001072 0.999836 1.001590 1.00332 \n", + "sigma 0.962740 0.004720 0.072470 0.849300 0.959950 1.08129 \n", + "\n", + " N_Eff N_Eff/s R_hat \n", + "lp__ 813.61500 4596.70000 1.002930 \n", + "beta[1] 4623.05000 26118.90000 0.999487 \n", + "beta[2] 4167.40000 23544.60000 1.000340 \n", + "beta[3] 4680.46000 26443.30000 0.999984 \n", + "beta[4] 4017.75000 22699.20000 0.999891 \n", + "beta[5] 4777.06000 26989.10000 1.000730 \n", + "sigma 235.93876 1332.98734 1.012130 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcmc_vb_inits_fit.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The sampler estimates match the reference posterior with no diagnostic issues." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inits can also be constructed from the `laplace` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "21:47:14 - cmdstanpy - INFO - Chain [1] start processing\n", + "21:47:14 - cmdstanpy - INFO - Chain [1] done processing\n", + "21:47:14 - cmdstanpy - INFO - Chain [1] start processing\n", + "21:47:14 - cmdstanpy - INFO - Chain [1] done processing\n" + ] + } + ], + "source": [ + "laplace_inits = model.laplace_sample(data=data_file, seed=123).create_inits()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "21:47:14 - cmdstanpy - INFO - CmdStan start processing\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c157bc1c9be442e4b12038e5d0c152e1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "chain 1: 0%| | 0/1075 [00:00 Date: Thu, 8 May 2025 16:36:47 -0400 Subject: [PATCH 09/13] Fix errant period in docstrings --- cmdstanpy/stanfit/laplace.py | 2 +- cmdstanpy/stanfit/mcmc.py | 2 +- cmdstanpy/stanfit/mle.py | 2 +- cmdstanpy/stanfit/pathfinder.py | 2 +- cmdstanpy/stanfit/vb.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmdstanpy/stanfit/laplace.py b/cmdstanpy/stanfit/laplace.py index 99ce0d53..00e5199a 100644 --- a/cmdstanpy/stanfit/laplace.py +++ b/cmdstanpy/stanfit/laplace.py @@ -65,7 +65,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. 
""" self._assemble_draws() rng = np.random.default_rng(seed) diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index c467c796..15cb91d9 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -121,7 +121,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. """ self._assemble_draws() rng = np.random.default_rng(seed) diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index bcd003ba..6db4b15b 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -48,7 +48,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. """ mle_inits = { name: var.extract_reshape(self.optimized_params_np) diff --git a/cmdstanpy/stanfit/pathfinder.py b/cmdstanpy/stanfit/pathfinder.py index 8e63f85f..5ac4d213 100644 --- a/cmdstanpy/stanfit/pathfinder.py +++ b/cmdstanpy/stanfit/pathfinder.py @@ -45,7 +45,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. """ self._assemble_draws() rng = np.random.default_rng(seed) diff --git a/cmdstanpy/stanfit/vb.py b/cmdstanpy/stanfit/vb.py index ed7b5c92..8d7ac552 100644 --- a/cmdstanpy/stanfit/vb.py +++ b/cmdstanpy/stanfit/vb.py @@ -44,7 +44,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. 
""" rng = np.random.default_rng(seed) idxs = rng.choice( From 65ffc9b5cf1e193bfe3e4eb94563c0a0a266e5af Mon Sep 17 00:00:00 2001 From: amas Date: Thu, 8 May 2025 17:11:31 -0400 Subject: [PATCH 10/13] Update CmdStanMLE.create_inits() compatibility This change simplifies the create_inits() for the MLE params to always return a dictionary of inits, rather than a list for multiple chains. The default behavior of the sample method is to initialize all chains at the same init if only one is given per param. The chains paramter is kept and the seed parameter is added to the signature, despite being no-ops, for the purposes of maintaining uniformity across the other create_inits() methods on other stanfit objects. --- cmdstanpy/stanfit/mle.py | 30 +++++++++++++++--------------- test/test_optimize.py | 16 +++------------- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index 6db4b15b..fd599dbf 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -1,7 +1,7 @@ """Container for the result of running optimization""" from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import pandas as pd @@ -37,28 +37,28 @@ def __init__(self, runset: RunSet) -> None: self._set_mle_attrs(runset.csv_files[0]) def create_inits( - self, chains: int = 4 - ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + self, seed: Optional[int] = None, chains: int = 4 + ) -> Dict[str, np.ndarray]: """ Create initial values for the parameters of the model from the MLE. - :param chains: Number of initial values to return, defaults to 4 + :param seed: Unused. Kept for compatibility with other + create_inits methods. + :param chains: Unused. Kept for compatibility with other + create_inits methods. :return: The initial values for the parameters of the model. 
- If ``chains`` is 1, a dictionary is returned, otherwise a list - of dictionaries is returned, in the format expected for the - ``inits`` argument of :meth:`CmdStanModel.sample`. + Returns a dictionary of MLE estimates in the format expected + for the ``inits`` argument of :meth:`CmdStanModel.sample`. + When running multi-chain sampling, all chains will be initialized + at the same points. """ - mle_inits = { - name: var.extract_reshape(self.optimized_params_np) - for name, var in self._metadata.stan_vars.items() - } + # pylint: disable=unused-argument - if chains == 1: - return mle_inits - else: - return [mle_inits for _ in range(chains)] + return { + name: np.array(val) for name, val in self.stan_variables().items() + } def __repr__(self) -> str: repr = 'CmdStanMLE: model={}{}'.format( diff --git a/test/test_optimize.py b/test/test_optimize.py index 341db684..034aca87 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -681,19 +681,9 @@ def test_optimize_create_inits(): mle = bern_model.optimize(data=jdata) inits = mle.create_inits() - assert isinstance(inits, list) - assert len(inits) == 4 - assert isinstance(inits[0], dict) - assert 'theta' in inits[0] - - inits_10 = mle.create_inits(chains=10) - assert isinstance(inits_10, list) - assert len(inits_10) == 10 - - inits_1 = mle.create_inits(chains=1) - assert isinstance(inits_1, dict) - assert 'theta' in inits_1 - assert len(inits_1) == 1 + assert isinstance(inits, dict) + assert 'theta' in inits + assert len(inits) == 1 def test_optimize_init_sampling(): From 8ce6ebe7ab5aa5d16c0d7533e1ad9f172a0e69e7 Mon Sep 17 00:00:00 2001 From: amas Date: Thu, 8 May 2025 17:26:35 -0400 Subject: [PATCH 11/13] Strip output from sampler inits notebook --- .../examples/VI as Sampler Inits.ipynb | 1390 +---------------- 1 file changed, 46 insertions(+), 1344 deletions(-) diff --git a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb index 914b0c78..96ff8037 
100644 --- a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb +++ b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb @@ -28,35 +28,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data {\n", - " int N;\n", - " int D;\n", - " matrix[N, D] X;\n", - " vector[N] y;\n", - "}\n", - "parameters {\n", - " vector[D] beta;\n", - " real sigma;\n", - "}\n", - "model {\n", - " // prior\n", - " target += normal_lpdf(beta | 0, 10);\n", - " target += normal_lpdf(sigma | 0, 10);\n", - " // likelihood\n", - " target += normal_lpdf(y | X * beta, sigma);\n", - "}\n", - "\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "from cmdstanpy import CmdStanModel\n", @@ -97,18 +71,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:10 - cmdstanpy - INFO - Chain [1] start processing\n", - "21:47:10 - cmdstanpy - INFO - Chain [1] done processing\n" - ] - } - ], + "outputs": [], "source": [ "pathfinder_fit = model.pathfinder(data=data_file, seed=123)" ] @@ -129,20 +94,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'beta': array([1.00019 , 0.999721, 0.999143, 1.00221 , 1.00193 ]), 'sigma': array(0.955428)}\n", - "{'beta': array([0.998311, 1.00282 , 1.00017 , 1.00119 , 1.00148 ]), 'sigma': array(0.829495)}\n", - "{'beta': array([1.0007 , 1.00177 , 0.999522, 1.00289 , 0.999926]), 'sigma': array(0.904491)}\n", - "{'beta': array([0.998958, 1.00013 , 1.00095 , 0.999549, 1.00184 ]), 'sigma': array(0.895088)}\n" - ] - } - ], + "outputs": [], "source": [ "pathfinder_inits = pathfinder_fit.create_inits()\n", "for chain_init in pathfinder_inits:\n", @@ -158,94 +112,9 @@ }, { "cell_type": 
"code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:10 - cmdstanpy - INFO - CmdStan start processing\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ab1a8782ca28489ba4c18fa4483fb85a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 1: 0%| | 0/1075 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-156.9230000.0697911.822880-160.397000-156.590000-154.59900682.212003832.650001.009740
beta[1]0.9994750.0000140.0009790.9978220.9994711.001114972.6900027936.500000.999750
beta[2]1.0002400.0000180.0011690.9983251.0002201.002164377.8800024594.900000.999390
beta[3]1.0004300.0000140.0009570.9988181.0004401.001984804.9000026993.800000.999950
beta[4]1.0011500.0000160.0010620.9994481.0011501.002914289.5500024098.600000.999605
beta[5]1.0015800.0000150.0010380.9998961.0015701.003284676.1400026270.400001.000200
sigma0.9621700.0042100.0713300.8542700.9591301.08159286.756631610.992311.010100
\n", - "" - ], - "text/plain": [ - " Mean MCSE StdDev 5% 50% 95% \\\n", - "lp__ -156.923000 0.069791 1.822880 -160.397000 -156.590000 -154.59900 \n", - "beta[1] 0.999475 0.000014 0.000979 0.997822 0.999471 1.00111 \n", - "beta[2] 1.000240 0.000018 0.001169 0.998325 1.000220 1.00216 \n", - "beta[3] 1.000430 0.000014 0.000957 0.998818 1.000440 1.00198 \n", - "beta[4] 1.001150 0.000016 0.001062 0.999448 1.001150 1.00291 \n", - "beta[5] 1.001580 0.000015 0.001038 0.999896 1.001570 1.00328 \n", - "sigma 0.962170 0.004210 0.071330 0.854270 0.959130 1.08159 \n", - "\n", - " N_Eff N_Eff/s R_hat \n", - "lp__ 682.21200 3832.65000 1.009740 \n", - "beta[1] 4972.69000 27936.50000 0.999750 \n", - "beta[2] 4377.88000 24594.90000 0.999390 \n", - "beta[3] 4804.90000 26993.80000 0.999950 \n", - "beta[4] 4289.55000 24098.60000 0.999605 \n", - "beta[5] 4676.14000 26270.40000 1.000200 \n", - "sigma 286.75663 1610.99231 1.010100 " - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mcmc_pathfinder_inits_fit.summary()" ] @@ -456,296 +155,27 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:11 - cmdstanpy - INFO - CmdStan start processing\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ab7377170b7149feb0fbf4ff025920d9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 1: 0%| | 0/1075 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-191.92800024.77020035.245800-232.300000-165.541000-155.461002.0246735.5206012.75920
beta[1]0.9998590.0001480.0020860.9966080.9999241.00357198.160003476.480001.03553
beta[2]0.9999490.0002660.0026750.9960960.9995251.00457101.400001778.950001.04521
beta[3]1.0004800.0001470.0021870.9969661.0004101.00452222.452003902.660001.00934
beta[4]1.0014100.0001770.0026590.9964441.0016601.00611225.987003964.680001.03112
beta[5]1.0016800.0001920.0025250.9977401.0012901.00655173.717003047.670001.03154
sigma1.9789200.7140101.0188800.9174902.7005703.173462.0362735.7240310.27506
\n", - "" - ], - "text/plain": [ - " Mean MCSE StdDev 5% 50% 95% \\\n", - "lp__ -191.928000 24.770200 35.245800 -232.300000 -165.541000 -155.46100 \n", - "beta[1] 0.999859 0.000148 0.002086 0.996608 0.999924 1.00357 \n", - "beta[2] 0.999949 0.000266 0.002675 0.996096 0.999525 1.00457 \n", - "beta[3] 1.000480 0.000147 0.002187 0.996966 1.000410 1.00452 \n", - "beta[4] 1.001410 0.000177 0.002659 0.996444 1.001660 1.00611 \n", - "beta[5] 1.001680 0.000192 0.002525 0.997740 1.001290 1.00655 \n", - "sigma 1.978920 0.714010 1.018880 0.917490 2.700570 3.17346 \n", - "\n", - " N_Eff N_Eff/s R_hat \n", - "lp__ 2.02467 35.52060 12.75920 \n", - "beta[1] 198.16000 3476.48000 1.03553 \n", - "beta[2] 101.40000 1778.95000 1.04521 \n", - "beta[3] 222.45200 3902.66000 1.00934 \n", - "beta[4] 225.98700 3964.68000 1.03112 \n", - "beta[5] 173.71700 3047.67000 1.03154 \n", - "sigma 2.03627 35.72403 10.27506 " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mcmc_random_inits_fit.summary()" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing csv files: /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_1.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_2.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_3.csv, /tmp/tmplo9k1zwz/blr8kgp791v/blr-20250507214711_4.csv\n", - "\n", - "Checking sampler transitions treedepth.\n", - "Treedepth satisfactory for all transitions.\n", - "\n", - "Checking sampler transitions for divergences.\n", - "597 of 4000 (14.93%) transitions ended with a divergence.\n", - "These divergent transitions indicate that HMC is not fully able to explore the posterior distribution.\n", - "Try increasing adapt delta closer to 1.\n", - "If this doesn't remove all divergences, try to reparameterize the model.\n", - "\n", - "Checking E-BFMI - sampler transitions 
HMC potential energy.\n", - "The E-BFMI, 0.01, is below the nominal threshold of 0.30 which suggests that HMC may have trouble exploring the target distribution.\n", - "If possible, try to reparameterize the model.\n", - "\n", - "The following parameters had fewer than 0.001 effective draws per transition:\n", - " sigma\n", - "Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted.\n", - "\n", - "The following parameters had split R-hat greater than 1.05:\n", - " sigma\n", - "Such high values indicate incomplete mixing and biased estimation.\n", - "You should consider regularizating your model with additional prior information or a more effective parameterization.\n", - "\n", - "Processing complete.\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(mcmc_random_inits_fit.diagnose())" ] @@ -773,20 +203,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:12 - cmdstanpy - INFO - Chain [1] start processing\n", - "21:47:12 - cmdstanpy - INFO - Chain [1] done processing\n", - "21:47:12 - cmdstanpy - WARNING - The algorithm may not have converged.\n", - "Proceeding because require_converged is set to False\n" - ] - } - ], + "outputs": [], "source": [ "vb_fit = model.variational(data=data_file, require_converged=False, seed=123)" ] @@ -802,20 +221,9 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'beta': array([0.997148, 0.992516, 0.991829, 0.991095, 1.01057 ]), 'sigma': array(1.84433)}\n", - "{'beta': array([0.996249, 0.990954, 0.992313, 0.993369, 1.01155 ]), 'sigma': array(1.92087)}\n", - "{'beta': array([0.997361, 0.992357, 0.989631, 0.995749, 1.0083 ]), 'sigma': array(1.49741)}\n", - "{'beta': 
array([0.995738, 0.994643, 0.993908, 0.993482, 1.00921 ]), 'sigma': array(1.60191)}\n" - ] - } - ], + "outputs": [], "source": [ "vb_inits = vb_fit.create_inits()\n", "for chain_init in vb_inits:\n", @@ -831,100 +239,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:13 - cmdstanpy - INFO - CmdStan start processing\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1a23f210d88248d7b86353a78ecd3281", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 1: 0%| | 0/1075 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
MeanMCSEStdDev5%50%95%N_EffN_Eff/sR_hat
lp__-156.9260000.0622531.775690-160.352000-156.589000-154.65400813.615004596.700001.002930
beta[1]0.9994890.0000140.0009550.9979170.9994941.001064623.0500026118.900000.999487
beta[2]1.0002400.0000180.0011480.9983631.0002601.002144167.4000023544.600001.000340
beta[3]1.0004200.0000140.0009410.9988771.0004201.001964680.4600026443.300000.999984
beta[4]1.0011500.0000170.0010820.9993941.0011301.002984017.7500022699.200000.999891
beta[5]1.0015700.0000160.0010720.9998361.0015901.003324777.0600026989.100001.000730
sigma0.9627400.0047200.0724700.8493000.9599501.08129235.938761332.987341.012130
\n", - "" - ], - "text/plain": [ - " Mean MCSE StdDev 5% 50% 95% \\\n", - "lp__ -156.926000 0.062253 1.775690 -160.352000 -156.589000 -154.65400 \n", - "beta[1] 0.999489 0.000014 0.000955 0.997917 0.999494 1.00106 \n", - "beta[2] 1.000240 0.000018 0.001148 0.998363 1.000260 1.00214 \n", - "beta[3] 1.000420 0.000014 0.000941 0.998877 1.000420 1.00196 \n", - "beta[4] 1.001150 0.000017 0.001082 0.999394 1.001130 1.00298 \n", - "beta[5] 1.001570 0.000016 0.001072 0.999836 1.001590 1.00332 \n", - "sigma 0.962740 0.004720 0.072470 0.849300 0.959950 1.08129 \n", - "\n", - " N_Eff N_Eff/s R_hat \n", - "lp__ 813.61500 4596.70000 1.002930 \n", - "beta[1] 4623.05000 26118.90000 0.999487 \n", - "beta[2] 4167.40000 23544.60000 1.000340 \n", - "beta[3] 4680.46000 26443.30000 0.999984 \n", - "beta[4] 4017.75000 22699.20000 0.999891 \n", - "beta[5] 4777.06000 26989.10000 1.000730 \n", - "sigma 235.93876 1332.98734 1.012130 " - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mcmc_vb_inits_fit.summary()" ] @@ -1135,117 +282,18 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:14 - cmdstanpy - INFO - Chain [1] start processing\n", - "21:47:14 - cmdstanpy - INFO - Chain [1] done processing\n", - "21:47:14 - cmdstanpy - INFO - Chain [1] start processing\n", - "21:47:14 - cmdstanpy - INFO - Chain [1] done processing\n" - ] - } - ], + "outputs": [], "source": [ "laplace_inits = model.laplace_sample(data=data_file, seed=123).create_inits()" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21:47:14 - cmdstanpy - INFO - CmdStan start processing\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": 
"c157bc1c9be442e4b12038e5d0c152e1", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 1: 0%| | 0/1075 [00:00 Date: Fri, 9 May 2025 18:03:47 -0400 Subject: [PATCH 12/13] Fix test_variational test on creating inits --- test/test_variational.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_variational.py b/test/test_variational.py index f16a1e63..3070d48c 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -355,7 +355,7 @@ def test_variational_create_inits(): bern_model = CmdStanModel(stan_file=stan) jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') - vb = bern_model.variational(data=jdata) + vb = bern_model.variational(data=jdata, seed=11235) inits = vb.create_inits() assert isinstance(inits, list) @@ -385,7 +385,7 @@ def test_variational_init_sampling(): logistic_model = CmdStanModel(stan_file=stan) logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') - vb = logistic_model.sample(data=logistic_data) + vb = logistic_model.variational(data=logistic_data, seed=11235) inits = vb.create_inits() fit = logistic_model.sample(data=logistic_data, inits=inits) From d6d5bd3f6625e29d16e673a3d82132a78463e1c0 Mon Sep 17 00:00:00 2001 From: amas Date: Wed, 14 May 2025 18:18:08 -0400 Subject: [PATCH 13/13] Clarify docs on using sampler output as inits --- .../examples/VI as Sampler Inits.ipynb | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb index 96ff8037..a80886fc 100644 --- a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb +++ b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb @@ -4,16 +4,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Using Estimates from Variational, Laplace, or Optimization Methods to Initialize the NUTS-HMC Sampler\n", + "## Initializing the NUTS-HMC sampler\n", "\n", - "In this example, we 
show how to use parameter estimates returned by Stan's various posterior approximation or optimization algorithms as initial values for Stan's NUTS-HMC sampler. These include:\n", + "In this example, we show how to use parameter estimates returned by any of Stan's inference algorithms as initial values for Stan's NUTS-HMC sampler. These include:\n", "\n", "* [Pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) \n", "* [ADVI ](https://mc-stan.org/docs/cmdstan-guide/variational_config.html) \n", "* [Laplace](https://mc-stan.org/docs/cmdstan-guide/laplace_sample_config.html)\n", "* [Optimization](https://mc-stan.org/docs/cmdstan-guide/optimize_config.html)\n", + "* [NUTS-HMC MCMC](https://mc-stan.org/docs/cmdstan-guide/mcmc_config.html)\n", "\n", - "By default, the NUTS-HMC sampler randomly initializes all model parameters uniformly in the interval $(-2, 2)$. If this interval is far from the typical set of the posterior, initializing sampling from these approximation algorithms can speed up and improve adaptation.\n", + "By default, the NUTS-HMC sampler randomly initializes all (unconstrained) model parameters uniformly in the interval (-2, 2). If this interval is far from the typical set of the posterior, initializing the sampler from the output of one of these inference algorithms can speed up and improve adaptation.\n", "\n", "### Model and data\n", "\n", @@ -49,7 +50,7 @@ "source": [ "### Demonstration with Stan's `pathfinder` method\n", "\n", - "The approximation methods all follow the same general pattern of usage. First, we call the \n", + "Initializing the sampler with estimates from any previous inference algorithm follows the same general usage pattern. First, we call the \n", "corresponding method on the `CmdStanModel` object. From the resulting fit, we call the `.create_inits()` \n", "method to construct a set of per-chain initializations for the model parameters. 
To make it explicit, \n", "we will walk through the process using the `pathfinder` method (which wraps the \n", @@ -84,7 +85,7 @@ "source": [ "Posteriordb provides reference posteriors for all models. For the blr model, conditioned on the dataset `sblri.json`, the reference posteriors can be found in the [sblri-blr.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/reference_posteriors/summary_statistics/mean/mean/sblri-blr.json) file.\n", "\n", - "The reference posteriors for all elements of `beta` and `sigma` are all very close to $1.0$.\n", + "The reference posterior means for all elements of `beta` and for `sigma` are very close to 1.0.\n", "\n", "The experiments reported in Figure 3 of the paper [Pathfinder: Parallel quasi-Newton variational inference](https://arxiv.org/abs/2108.03782) by Zhang et al. show that Pathfinder provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the warmup Phase I algorithm used by the NUTS-HMC sampler.\n", "Furthermore, Pathfinder is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the Pathfinder estimates to initialize the parameter values for the NUTS-HMC sampler can allow the sampler to do a better job of adapting the stepsize and metric during warmup, resulting in better performance and estimation.\n", @@ -191,7 +192,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Other approximation algorithms" + "### Other inference algorithms" ] }, { @@ -349,7 +350,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It is also possible to use the output of the `sample()` method itself to construct inits to be fed into a future sampling run:" ] }, {