Skip to content

Add new create_inits() methods to other stanfit classes #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 14, 2025
Merged
35 changes: 34 additions & 1 deletion cmdstanpy/stanfit/laplace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Container for the result of running a laplace approximation.
Container for the result of running a laplace approximation.
"""

from typing import (
Expand Down Expand Up @@ -52,6 +52,39 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None:
config = scan_generic_csv(runset.csv_files[0])
self._metadata = InferenceMetadata(config)

def create_inits(
self, seed: Optional[int] = None, chains: int = 4
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
"""
Create initial values for the parameters of the model
by randomly selecting draws from the Laplace approximation.

:param seed: Used for random selection, defaults to None
:param chains: Number of initial values to return, defaults to 4
:return: The initial values for the parameters of the model.

If ``chains`` is 1, a dictionary is returned, otherwise a list
of dictionaries is returned, in the format expected for the
``inits`` argument of :meth:`CmdStanModel.sample`.
"""
self._assemble_draws()
rng = np.random.default_rng(seed)
idxs = rng.choice(self._draws.shape[0], size=chains, replace=False)
if chains == 1:
draw = self._draws[idxs[0]]
return {
name: var.extract_reshape(draw)
for name, var in self._metadata.stan_vars.items()
}
else:
return [
{
name: var.extract_reshape(self._draws[idx])
for name, var in self._metadata.stan_vars.items()
}
for idx in idxs
]

def _assemble_draws(self) -> None:
if self._draws.shape != (0,):
return
Expand Down
42 changes: 41 additions & 1 deletion cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,46 @@ def __init__(
if not self._is_fixed_param:
self._check_sampler_diagnostics()

def create_inits(
self, seed: Optional[int] = None, chains: int = 4
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
"""
Create initial values for the parameters of the model by
randomly selecting draws from the MCMC samples. If the samples
contain draws from multiple chains, each draw will be from
a different chain, if possible. Otherwise the chain is randomly
selected.

:param seed: Used for random selection, defaults to None
:param chains: Number of initial values to return, defaults to 4
:return: The initial values for the parameters of the model.

If ``chains`` is 1, a dictionary is returned, otherwise a list
of dictionaries is returned, in the format expected for the
``inits`` argument of :meth:`CmdStanModel.sample`.
"""
self._assemble_draws()
rng = np.random.default_rng(seed)
n_draws, n_chains = self._draws.shape[:2]
draw_idxs = rng.choice(n_draws, size=chains, replace=False)
chain_idxs = rng.choice(
n_chains, size=chains, replace=(n_chains <= chains)
)
if chains == 1:
draw = self._draws[draw_idxs[0], chain_idxs[0]]
return {
name: var.extract_reshape(draw)
for name, var in self._metadata.stan_vars.items()
}
else:
return [
{
name: var.extract_reshape(self._draws[d, i])
for name, var in self._metadata.stan_vars.items()
}
for d, i in zip(draw_idxs, chain_idxs)
]

def __repr__(self) -> str:
repr = 'CmdStanMCMC: model={} chains={}{}'.format(
self.runset.model,
Expand Down Expand Up @@ -685,7 +725,7 @@ def draws_xr(
)
if inc_warmup and not self._save_warmup:
get_logger().warning(
"Draws from warmup iterations not available,"
'Draws from warmup iterations not available,'
' must run sampler with "save_warmup=True".'
)
if vars is None:
Expand Down
24 changes: 24 additions & 0 deletions cmdstanpy/stanfit/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ def __init__(self, runset: RunSet) -> None:
self._save_iterations: bool = optimize_args.save_iterations
self._set_mle_attrs(runset.csv_files[0])

def create_inits(
self, seed: Optional[int] = None, chains: int = 4
) -> Dict[str, np.ndarray]:
"""
Create initial values for the parameters of the model
from the MLE.

:param seed: Unused. Kept for compatibility with other
create_inits methods.
:param chains: Unused. Kept for compatibility with other
create_inits methods.
:return: The initial values for the parameters of the model.

Returns a dictionary of MLE estimates in the format expected
for the ``inits`` argument of :meth:`CmdStanModel.sample`.
When running multi-chain sampling, all chains will be initialized
at the same points.
"""
# pylint: disable=unused-argument

return {
name: np.array(val) for name, val in self.stan_variables().items()
}

def __repr__(self) -> str:
repr = 'CmdStanMLE: model={}{}'.format(
self.runset.model, self.runset._args.method_args.compose(0, cmd=[])
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/stanfit/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_inits(

If ``chains`` is 1, a dictionary is returned, otherwise a list
of dictionaries is returned, in the format expected for the
``inits`` argument. of :meth:`CmdStanModel.sample`.
``inits`` argument of :meth:`CmdStanModel.sample`.
"""
self._assemble_draws()
rng = np.random.default_rng(seed)
Expand Down
37 changes: 36 additions & 1 deletion cmdstanpy/stanfit/vb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Container for the results of running autodiff variational inference"""

from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -30,6 +30,41 @@ def __init__(self, runset: RunSet) -> None:
self.runset = runset
self._set_variational_attrs(runset.csv_files[0])

def create_inits(
self, seed: Optional[int] = None, chains: int = 4
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
"""
Create initial values for the parameters of the model
by randomly selecting draws from the variational approximation
draws.

:param seed: Used for random selection, defaults to None
:param chains: Number of initial values to return, defaults to 4
:return: The initial values for the parameters of the model.

If ``chains`` is 1, a dictionary is returned, otherwise a list
of dictionaries is returned, in the format expected for the
``inits`` argument of :meth:`CmdStanModel.sample`.
"""
rng = np.random.default_rng(seed)
idxs = rng.choice(
self.variational_sample.shape[0], size=chains, replace=False
)
if chains == 1:
draw = self.variational_sample[idxs[0]]
return {
name: var.extract_reshape(draw)
for name, var in self._metadata.stan_vars.items()
}
else:
return [
{
name: var.extract_reshape(self.variational_sample[idx])
for name, var in self._metadata.stan_vars.items()
}
for idx in idxs
]

def __repr__(self) -> str:
repr = 'CmdStanVB: model={}{}'.format(
self.runset.model, self.runset._args.method_args.compose(0, cmd=[])
Expand Down
Loading
Loading