diff --git a/cmdstanpy/stanfit/laplace.py b/cmdstanpy/stanfit/laplace.py index 741593e7..00e5199a 100644 --- a/cmdstanpy/stanfit/laplace.py +++ b/cmdstanpy/stanfit/laplace.py @@ -1,5 +1,5 @@ """ - Container for the result of running a laplace approximation. +Container for the result of running a Laplace approximation. """ from typing import ( @@ -52,6 +52,39 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None: config = scan_generic_csv(runset.csv_files[0]) self._metadata = InferenceMetadata(config) + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model + by randomly selecting draws from the Laplace approximation. + + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. + """ + self._assemble_draws() + rng = np.random.default_rng(seed) + idxs = rng.choice(self._draws.shape[0], size=chains, replace=False) + if chains == 1: + draw = self._draws[idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self._draws[idx]) + for name, var in self._metadata.stan_vars.items() + } + for idx in idxs + ] + def _assemble_draws(self) -> None: if self._draws.shape != (0,): return diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index f96ff023..15cb91d9 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -105,6 +105,46 @@ def __init__( if not self._is_fixed_param: self._check_sampler_diagnostics() + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model by + randomly selecting draws from the MCMC samples. If the samples + contain draws from multiple chains, each draw will be from + a different chain, if possible; otherwise, chains are selected + randomly, with replacement. + + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. 
+ """ + self._assemble_draws() + rng = np.random.default_rng(seed) + n_draws, n_chains = self._draws.shape[:2] + draw_idxs = rng.choice(n_draws, size=chains, replace=False) + chain_idxs = rng.choice( + n_chains, size=chains, replace=(n_chains <= chains) + ) + if chains == 1: + draw = self._draws[draw_idxs[0], chain_idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self._draws[d, i]) + for name, var in self._metadata.stan_vars.items() + } + for d, i in zip(draw_idxs, chain_idxs) + ] + def __repr__(self) -> str: repr = 'CmdStanMCMC: model={} chains={}{}'.format( self.runset.model, @@ -685,7 +725,7 @@ def draws_xr( ) if inc_warmup and not self._save_warmup: get_logger().warning( - "Draws from warmup iterations not available," + 'Draws from warmup iterations not available,' ' must run sampler with "save_warmup=True".' ) if vars is None: diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index 3a50ba10..fd599dbf 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -36,6 +36,30 @@ def __init__(self, runset: RunSet) -> None: self._save_iterations: bool = optimize_args.save_iterations self._set_mle_attrs(runset.csv_files[0]) + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Dict[str, np.ndarray]: + """ + Create initial values for the parameters of the model + from the MLE. + + :param seed: Unused. Kept for compatibility with other + create_inits methods. + :param chains: Unused. Kept for compatibility with other + create_inits methods. + :return: The initial values for the parameters of the model. + + Returns a dictionary of MLE estimates in the format expected + for the ``inits`` argument of :meth:`CmdStanModel.sample`. + When running multi-chain sampling, all chains will be initialized + at the same points. + """ + # pylint: disable=unused-argument + + return { + name: np.array(val) for name, val in self.stan_variables().items() + } + def __repr__(self) -> str: repr = 'CmdStanMLE: model={}{}'.format( self.runset.model, self.runset._args.method_args.compose(0, cmd=[]) diff --git a/cmdstanpy/stanfit/pathfinder.py b/cmdstanpy/stanfit/pathfinder.py index 8e63f85f..5ac4d213 100644 --- a/cmdstanpy/stanfit/pathfinder.py +++ b/cmdstanpy/stanfit/pathfinder.py @@ -45,7 +45,7 @@ def create_inits( If ``chains`` is 1, a dictionary is returned, otherwise a list of dictionaries is returned, in the format expected for the - ``inits`` argument. of :meth:`CmdStanModel.sample`. + ``inits`` argument of :meth:`CmdStanModel.sample`. """ self._assemble_draws() rng = np.random.default_rng(seed) diff --git a/cmdstanpy/stanfit/vb.py b/cmdstanpy/stanfit/vb.py index 102f292c..8d7ac552 100644 --- a/cmdstanpy/stanfit/vb.py +++ b/cmdstanpy/stanfit/vb.py @@ -1,7 +1,7 @@ """Container for the results of running autodiff variational inference""" from collections import OrderedDict -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -30,6 +30,41 @@ def __init__(self, runset: RunSet) -> None: self.runset = runset self._set_variational_attrs(runset.csv_files[0]) + def create_inits( + self, seed: Optional[int] = None, chains: int = 4 + ) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + """ + Create initial values for the parameters of the model + by randomly selecting draws from the variational approximation + draws. 
+ + :param seed: Used for random selection, defaults to None + :param chains: Number of initial values to return, defaults to 4 + :return: The initial values for the parameters of the model. + + If ``chains`` is 1, a dictionary is returned, otherwise a list + of dictionaries is returned, in the format expected for the + ``inits`` argument of :meth:`CmdStanModel.sample`. + """ + rng = np.random.default_rng(seed) + idxs = rng.choice( + self.variational_sample.shape[0], size=chains, replace=False + ) + if chains == 1: + draw = self.variational_sample[idxs[0]] + return { + name: var.extract_reshape(draw) + for name, var in self._metadata.stan_vars.items() + } + else: + return [ + { + name: var.extract_reshape(self.variational_sample[idx]) + for name, var in self._metadata.stan_vars.items() + } + for idx in idxs + ] + def __repr__(self) -> str: repr = 'CmdStanVB: model={}{}'.format( self.runset.model, self.runset._args.method_args.compose(0, cmd=[]) diff --git a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb index 03f5c527..a80886fc 100644 --- a/docsrc/users-guide/examples/VI as Sampler Inits.ipynb +++ b/docsrc/users-guide/examples/VI as Sampler Inits.ipynb @@ -4,13 +4,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Using Variational Estimates to Initialize the NUTS-HMC Sampler\n", + "## Initializing the NUTS-HMC sampler\n", "\n", - "In this example we show how to use the parameter estimates return by Stan's variational inference algorithms\n", - "[pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) and \n", - "[ADVI ](https://mc-stan.org/docs/cmdstan-guide/variational_config.html) \n", - "as the initial parameter values for Stan's NUTS-HMC sampler.\n", - "By default, the sampler algorithm randomly initializes all model parameters in the range uniform\\[-2, 2\\]. When the true parameter value is outside of this range, starting from the estimates from Pathfinder and ADVI will speed up and improve adaptation.\n", + "In this example, we show how to use parameter estimates returned by any of Stan's inference algorithms as initial values for Stan's NUTS-HMC sampler. These include:\n", + "\n", + "* [Pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) \n", + "* [ADVI ](https://mc-stan.org/docs/cmdstan-guide/variational_config.html) \n", + "* [Laplace](https://mc-stan.org/docs/cmdstan-guide/laplace_sample_config.html)\n", + "* [Optimization](https://mc-stan.org/docs/cmdstan-guide/optimize_config.html)\n", + "* [NUTS-HMC MCMC](https://mc-stan.org/docs/cmdstan-guide/mcmc_config.html)\n", + "\n", + "By default, the NUTS-HMC sampler randomly initializes all (unconstrained) model parameters uniformly in the interval (-2, 2). If this interval is far from the typical set of the posterior, initializing sampling from these approximation algorithms can speed up and improve adaptation.\n", "\n", "### Model and data\n", "\n", @@ -20,40 +24,14 @@ "a Bayesian standard linear regression model with noninformative priors,\n", "and its corresponding simulated dataset [sblri.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/data/data/sblri.json.zip),\n", "which was simulated via script [sblr.R](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/data/data-raw/sblr/sblr.R).\n", - "For conveince, we have copied the posteriordb model and data to this directory, in files `blr.stan` and `sblri.json`." 
+ "For convenience, this example assumes the posteriordb model and data are local, in files `blr.stan` and `sblri.json`." ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data {\n", - " int N;\n", - " int D;\n", - " matrix[N, D] X;\n", - " vector[N] y;\n", - "}\n", - "parameters {\n", - " vector[D] beta;\n", - " real sigma;\n", - "}\n", - "model {\n", - " // prior\n", - " target += normal_lpdf(beta | 0, 10);\n", - " target += normal_lpdf(sigma | 0, 10);\n", - " // likelihood\n", - " target += normal_lpdf(y | X * beta, sigma);\n", - "}\n", - "\n", - "\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import os\n", "from cmdstanpy import CmdStanModel\n", @@ -70,38 +48,33 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Run Stan's `pathfinder` or `variational` algorithm, obtain fitted estimates\n", + "### Demonstration with Stan's `pathfinder` method\n", "\n", - "The [CmdStanModel pathfinder](https://mc-stan.org/cmdstanpy/api.html#cmdstanpy.CmdStanModel.pathfinder ) method\n", - "wraps the CmdStan [pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) method.\n", + "Initializing the sampler with estimates from any previous inference algorithm follows the same general usage pattern. First, we call the \n", + "corresponding method on the `CmdStanModel` object. From the resulting fit, we call the `.create_inits()` \n", + "method to construct a set of per-chain initializations for the model parameters. To make it explicit, \n", + "we will walk through the process using the `pathfinder` method (which wraps the \n", + "CmdStan [pathfinder ](https://mc-stan.org/docs/cmdstan-guide/pathfinder-config.html) method).\n", "\n", "Pathfinder locates normal approximations to the target\n", "density along a quasi-Newton optimization path, with local covariance\n", "estimated using the negative inverse Hessian estimates produced by the\n", - "LBFGS optimizer. Pathfinder returns draws from the Gaussian approximation\n", + "LBFGS optimizer. Pathfinder returns draws from the Gaussian approximation\n", "with the lowest estimated Kullback-Leibler (KL) divergence to the true\n", "posterior.\n", - "By default, CmdStanPy runs multi-path Pathfinder which returns an importance-resampled set of draws over the outputs of 4 independent single-path Pathfinders.\n", + "By default, CmdStanPy runs multi-path Pathfinder which returns an importance-resampled \n", + "set of draws over the outputs of 4 independent single-path Pathfinders.\n", "This better matches non-normal target densities and also mitigates\n", "the problem of L-BFGS getting stuck at local optima or in saddle points on plateaus.\n", "\n", - "The method [create_inits](https://mc-stan.org/cmdstanpy/api.html#cmdstanpy.CmdStanPathfinder.create_inits) returns a Python Dict containing a set of per-chain initializations for the model parameters. Each set of initializations is a random draw from the Pathfinder sample." 
+ "We obtain Pathfinder estimates by calling the `.pathfinder()` method which returns a `CmdStanPathfinder` object:" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "17:01:02 - cmdstanpy - INFO - Chain [1] start processing\n", - "17:01:02 - cmdstanpy - INFO - Chain [1] done processing\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "pathfinder_fit = model.pathfinder(data=data_file, seed=123)" ] @@ -110,133 +83,61 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Posteriordb provides reference posteriors for all models. For the blr model, conditioned on the dataset `sblri.json`, the reference posteriors are in file [sblri-blr.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/reference_posteriors/summary_statistics/mean/mean/sblri-blr.json)\n", + "Posteriordb provides reference posteriors for all models. For the blr model, conditioned on the dataset `sblri.json`, the reference posteriors can be found in the [sblri-blr.json](https://github.com/stan-dev/posteriordb/blob/master/posterior_database/reference_posteriors/summary_statistics/mean/mean/sblri-blr.json) file.\n", "\n", - "The reference posteriors for all elements of `beta` and `sigma` are all very close to $1.0$.\n", + "The reference posteriors for all elements of `beta` and `sigma` are all very close to 1.0.\n", "\n", - "The experiments reported in Figure 3 of the paper [Pathfinder: Parallel quasi-Newton variational inference](https://arxiv.org/abs/2108.03782) by Zhang et al. show that Pathfinder provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the warmup Phase I algorithm used by the NUTS-HMC sampler.\n", - "furthermore, Pathfinder is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the estimates from Pathfinder to initialize the parameter values for the NUTS-HMC sampler will allow the sampler to do a better job of adapting the stepsize and metric during warmup, resulting in better performance and estimation." + "The experiments reported in Figure 3 of the paper [Pathfinder: Parallel quasi-Newton variational inference](https://arxiv.org/abs/2108.03782) by Zhang et al. show that Pathfinder provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the warmup Phase I algorithm used by the NUTS-HMC sampler.\n", + "Furthermore, Pathfinder is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the Pathfinder estimates to initialize the parameter values for the NUTS-HMC sampler can allow the sampler to do a better job of adapting the stepsize and metric during warmup, resulting in better performance and estimation.\n", + "\n", + "We construct the parameter inits for full MCMC sampling below. The `.create_inits()` default behavior is to create inits for four chains to correspond with the sampling defaults. You can requests more or less by modifying the `chains` keyword argument." 
] }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[{'beta': array([0.996649, 0.999455, 1.00093 , 0.99873 , 1.00207 ]), 'sigma': array(0.934232)}, {'beta': array([1.00016 , 0.998764, 1.00055 , 1.00212 , 1.00047 ]), 'sigma': array(1.04441)}, {'beta': array([1.00139 , 0.997917, 1.00134 , 1.00123 , 1.00116 ]), 'sigma': array(0.946814)}, {'beta': array([0.999491, 0.999225, 1.00114 , 0.999147, 0.998943]), 'sigma': array(0.977812)}]\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "pathfinder_inits = pathfinder_fit.create_inits()\n", - "print(pathfinder_inits)" + "for chain_init in pathfinder_inits:\n", + " print(chain_init)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the Pathfinder inits are close the reference posteriors for the parameters. To use these inits, we pass the `pathfinder_inits` object to the `inits` kwarg:" ] }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "17:01:05 - cmdstanpy - INFO - CmdStan start processing\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d8a75128e05e4cf88f037897a38d0173", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 1 | | 00:00 Status" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3dbf5f498c5a47a889b0b5229d200ac4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 2 | | 00:00 Status" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "82aa8eb3e89a4d55852aaefd0cbe856e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 3 | | 00:00 Status" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5b1e5ff5b1914fefa8aed58b19dff966", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "chain 4 | | 00:00 Status" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " " - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "17:01:05 - cmdstanpy - INFO - CmdStan done processing.\n", - "17:01:05 - cmdstanpy - WARNING - Non-fatal error during sampling:\n", - "Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/Users/mitzi/github/stan-dev/cmdstanpy/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)\n", - "Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/Users/mitzi/github/stan-dev/cmdstanpy/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)\n", - "Exception: normal_lpdf: Scale parameter is 0, but must be positive! 
(in '/Users/mitzi/github/stan-dev/cmdstanpy/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)\n", - "Consider re-running with show_console=True if the above output is unclear!\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "mcmc_pathfinder_inits_fit = model.sample(\n", " data=data_file, inits=pathfinder_inits, iter_warmup=75, seed=12345\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(mcmc_pathfinder_inits_fit.diagnose())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Despite only running 75 warmup iterations, all posterior diagnostics from the sampler look good." + ] + }, { "cell_type": "code", "execution_count": null, @@ -250,7 +151,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using the default random parameter initializations, we need to run more warmup iterations. If we only run 75 warmup iterations with random inits, the result fails to estimate `sigma` correctly. It is necessary to run the model with at least 150 warmup iterations to produce a good set of estimates." + "If we were to instead use the default random parameter initializations, we would need to run more warmup iterations to produce useful samples. For example, if we only run the same 75 warmup iterations with random inits, the result fails to estimate `sigma` correctly: " ] }, { "cell_type": "code", "execution_count": null, @@ -284,8 +185,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `CmdStanModel` method `variational` runs CmdStan's ADVI algorithm.\n", - "Because this algorithm is unstable and may fail to converge, we run it with argument `require_converged` set to `False`. We also specify a seed, to avoid instabilities as well as for reproducibility." + "The diagnostics clearly indicate problems with estimating `sigma`. In this case, it is necessary to run the model with at least 150 warmup iterations to produce a good set of estimates when starting from the default initialization." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Other inference algorithms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can follow the same pattern with Stan's ADVI algorithm by first using the `CmdStanModel.variational` method. Because this algorithm is unstable and may fail to converge, we run it with argument `require_converged` set to `False`. We also specify a seed, to avoid instabilities as well as for reproducibility." ] }, { "cell_type": "code", @@ -303,9 +217,7 @@ "source": [ "The ADVI algorithm provides estimates of all model parameters.\n", "\n", - "The `variational` method returns a `CmdStanVB` object, with method `stan_variables`, which\n", - "returns the approximat posterior samples of all model parameters as a Python dictionary. \n", - "Here, we report the approximate posterior mean."
+ "The `variational` method returns a `CmdStanVB` object, which similarly can construct a set of inits with the `.create_inits()` method:" ] }, { @@ -314,8 +226,16 @@ "metadata": {}, "outputs": [], "source": [ - "vb_mean = {var: samples.mean(axis=0) for var, samples in vb_fit.stan_variables(mean=False).items()}\n", - "print(vb_mean)" + "vb_inits = vb_fit.create_inits()\n", + "for chain_init in vb_inits:\n", + " print(chain_init)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Which can be passed to the `inits` keyword argument of the sample method." ] }, { @@ -325,10 +245,19 @@ "outputs": [], "source": [ "mcmc_vb_inits_fit = model.sample(\n", - " data=data_file, inits=vb_mean, iter_warmup=75, seed=12345\n", + " data=data_file, inits=vb_inits, iter_warmup=75, seed=12345\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(mcmc_vb_inits_fit.diagnose())" + ] + }, { "cell_type": "code", "execution_count": null, @@ -342,7 +271,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The sampler estimates match the reference posterior." + "The sampler estimates match the reference posterior with no diagnostic issues." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inits can also be constructed from the `laplace` method:" ] }, { @@ -351,7 +287,106 @@ "metadata": {}, "outputs": [], "source": [ - "print(mcmc_vb_inits_fit.diagnose())" + "laplace_inits = model.laplace_sample(data=data_file, seed=123).create_inits()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mcmc_laplace_inits_fit = model.sample(\n", + " data=data_file, inits=laplace_inits, iter_warmup=75, seed=12345\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(mcmc_laplace_inits_fit.diagnose())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And the `optimize` method. Since optimizations attempts to return a posterior mode, this will initialize all chains at the same point, which is not typically ideal." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimized_inits = model.optimize(data=data_file, seed=123).create_inits()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mcmc_optimize_inits_fit = model.sample(\n", + " data=data_file, inits=optimized_inits, iter_warmup=75, seed=12345\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(mcmc_optimize_inits_fit.diagnose())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is also possible to use the output of the `sample()` method itself to construct inits to be fed into a future sampling run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "first_mcmc_inits = model.sample(\n", + " data=data_file, iter_warmup=75, seed=12345\n", + ").create_inits()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "second_mcmc_fit = model.sample(data=data_file, inits=first_mcmc_inits, iter_warmup=75, seed=12345)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(second_mcmc_fit.diagnose())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that despite the initial sampling issues in the first MCMC run, the inits sourced from that run result in reasonable sampling in the second run." ] } ], @@ -371,7 +406,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.12" } }, "nbformat": 4, diff --git a/test/test_laplace.py b/test/test_laplace.py index 14f00eb6..14729407 100644 --- a/test/test_laplace.py +++ b/test/test_laplace.py @@ -105,3 +105,47 @@ def test_laplace_outputs(): assert 'x' in fit_pd.columns assert 'y' in fit_pd.columns assert fit_pd['x'].shape == (123,) + + +def test_laplace_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + laplace = bern_model.laplace_sample(data=jdata) + + inits = laplace.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = laplace.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = laplace.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = laplace.create_inits(seed=1234) + seeded2 = laplace.create_inits(seed=1234) + assert all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) + + +def test_laplace_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = cmdstanpy.CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + laplace = logistic_model.laplace_sample(data=logistic_data) + inits = laplace.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_optimize.py b/test/test_optimize.py index d5024ff5..034aca87 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -671,3 +671,30 @@ def 
test_serialization() -> None: np.testing.assert_array_equal( mle1.optimized_params_np, mle2.optimized_params_np ) + + +def test_optimize_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + mle = bern_model.optimize(data=jdata) + + inits = mle.create_inits() + assert isinstance(inits, dict) + assert 'theta' in inits + assert len(inits) == 1 + + +def test_optimize_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + mle = logistic_model.optimize(data=logistic_data) + inits = mle.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_sample.py b/test/test_sample.py index 3c987ddc..8ec155d8 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2083,3 +2083,47 @@ def test_serialization(stanfile='bernoulli.stan'): assert set(variables1) == set(variables2) for key, value1 in variables1.items(): np.testing.assert_array_equal(value1, variables2[key]) + + +def test_mcmc_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + mcmc = bern_model.sample(data=jdata) + + inits = mcmc.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = mcmc.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = mcmc.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = mcmc.create_inits(seed=1234) + seeded2 = mcmc.create_inits(seed=1234) + assert all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) + + +def test_mcmc_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = cmdstanpy.CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + initial_mcmc = logistic_model.sample(data=logistic_data) + inits = initial_mcmc.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9) diff --git a/test/test_variational.py b/test/test_variational.py index 0b859e2f..3070d48c 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -348,3 +348,47 @@ def test_serialization() -> None: variational1.variational_params_dict == variational2.variational_params_dict ) + + +def test_variational_create_inits(): + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + bern_model = CmdStanModel(stan_file=stan) + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + + vb = bern_model.variational(data=jdata, seed=11235) + + inits = vb.create_inits() + assert isinstance(inits, list) + assert len(inits) == 4 + assert isinstance(inits[0], dict) + assert 'theta' in inits[0] + + inits_10 = vb.create_inits(chains=10) + assert isinstance(inits_10, list) + assert len(inits_10) == 10 + + inits_1 = vb.create_inits(chains=1) + assert isinstance(inits_1, dict) + assert 'theta' in inits_1 + assert len(inits_1) == 1 + + seeded = vb.create_inits(seed=1234) + seeded2 = vb.create_inits(seed=1234) + assert 
all( + init1['theta'] == init2['theta'] + for init1, init2 in zip(seeded, seeded2) + ) + + +def test_variational_init_sampling(): + stan = os.path.join(DATAFILES_PATH, 'logistic.stan') + logistic_model = CmdStanModel(stan_file=stan) + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') + + vb = logistic_model.variational(data=logistic_data, seed=11235) + inits = vb.create_inits() + + fit = logistic_model.sample(data=logistic_data, inits=inits) + + assert fit.chains == 4 + assert fit.draws().shape == (1000, 4, 9)
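A minimal end-to-end sketch of the workflow these changes enable, assuming the `bernoulli.stan` model and `bernoulli.data.json` data used in the tests above; `create_inits` is the method introduced in this diff, and all other calls are existing CmdStanPy API:

# Sketch: seed the NUTS-HMC sampler with draws from a cheap approximation,
# assuming bernoulli.stan and bernoulli.data.json from the test suite.
import cmdstanpy

model = cmdstanpy.CmdStanModel(stan_file='bernoulli.stan')

# Any of pathfinder(), variational(), laplace_sample(), optimize(), or a
# previous sample() run returns a fit object that now exposes create_inits().
approx_fit = model.pathfinder(data='bernoulli.data.json', seed=123)

# One {parameter name: np.ndarray} dict per requested chain (a single dict
# when chains=1), with draws selected at random from the approximation.
inits = approx_fit.create_inits(seed=123, chains=4)

# Pass the per-chain inits straight to the sampler and check diagnostics.
fit = model.sample(data='bernoulli.data.json', inits=inits, chains=4)
print(fit.diagnose())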