From 0311f0b466604c63c5d3550561efd738ddf01702 Mon Sep 17 00:00:00 2001 From: Tom McClintock Date: Wed, 23 Jun 2021 09:55:05 -0700 Subject: [PATCH] Support for named parameters (#386) * Added named parameter functionality to the ensemble sampler and added tests * Made an error message more informative * Made a test for run_mcmc * Removed scipy from the test_ensemble.py unittests * Docstring and updated typing for parameter_names * Working on separate logic for dict vs list parameter names * Added functionality for the parameter_names to be a dictionary of either integers or lists of integers * Added functionality for the parameter_names to be a dictionary of either integers or lists of integers. Also ran isort and black * Test of the dictionary parameter name functionality --- src/emcee/ensemble.py | 95 +++++++++++-- src/emcee/tests/unit/test_ensemble.py | 184 ++++++++++++++++++++++++++ 2 files changed, 268 insertions(+), 11 deletions(-) create mode 100644 src/emcee/tests/unit/test_ensemble.py diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 45cc4ba9..24faba3c 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- import warnings +from itertools import count +from typing import Dict, List, Optional, Union import numpy as np -from itertools import count from .backends import Backend from .model import Model @@ -61,6 +62,10 @@ class EnsembleSampler(object): to accept a list of position vectors instead of just one. Note that ``pool`` will be ignored if this is ``True``. (default: ``False``) + parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]): + names of individual parameters or groups of parameters. If + specified, the ``log_prob_fn`` will receive a dictionary of + parameters, rather than a ``np.ndarray``. """ @@ -76,6 +81,7 @@ def __init__( backend=None, vectorize=False, blobs_dtype=None, + parameter_names: Optional[Union[Dict[str, int], List[str]]] = None, # Deprecated... 
a=None, postargs=None, @@ -157,6 +163,49 @@ def __init__( # ``args`` and ``kwargs`` pickleable. self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs) + # Save the parameter names + self.params_are_named: bool = parameter_names is not None + if self.params_are_named: + assert isinstance(parameter_names, (list, dict)) + + # Don't support vectorizing yet + msg = "named parameters with vectorization unsupported for now" + assert not self.vectorize, msg + + # Check for duplicate names + dupes = set() + uniq = [] + for name in parameter_names: + if name not in dupes: + uniq.append(name) + dupes.add(name) + msg = f"duplicate parameters: {dupes}" + assert len(uniq) == len(parameter_names), msg + + if isinstance(parameter_names, list): + # Check for all named + msg = "name all parameters or set `parameter_names` to `None`" + assert len(parameter_names) == ndim, msg + # Convert a list to a dict + parameter_names: Dict[str, int] = { + name: i for i, name in enumerate(parameter_names) + } + + # Check not too many names + msg = "too many names" + assert len(parameter_names) <= ndim, msg + + # Check all indices appear + values = [ + v if isinstance(v, list) else [v] + for v in parameter_names.values() + ] + values = [item for sublist in values for item in sublist] + values = set(values) + msg = f"not all values appear -- set should be 0 to {ndim-1}" + assert values == set(np.arange(ndim)), msg + self.parameter_names = parameter_names + @property def random_state(self): """ @@ -251,8 +300,9 @@ def sample( raise ValueError("'store' must be False when 'iterations' is None") # Interpret the input as a walker state and check the dimensions. 
state = State(initial_state, copy=True) - if np.shape(state.coords) != (self.nwalkers, self.ndim): - raise ValueError("incompatible input dimensions") + state_shape = np.shape(state.coords) + if state_shape != (self.nwalkers, self.ndim): + raise ValueError(f"incompatible input dimensions {state_shape}") if (not skip_initial_state_check) and ( not walkers_independent(state.coords) ): @@ -416,6 +466,10 @@ def compute_log_prob(self, coords): if np.any(np.isnan(p)): raise ValueError("At least one parameter value was NaN") + # If the parameters are named, then switch to dictionaries + if self.params_are_named: + p = ndarray_to_list_of_dicts(p, self.parameter_names) + # Run the log-probability calculations (optionally in parallel). if self.vectorize: results = self.log_prob_fn(p) @@ -427,9 +481,7 @@ def compute_log_prob(self, coords): map_func = self.pool.map else: map_func = map - results = list( - map_func(self.log_prob_fn, (p[i] for i in range(len(p)))) - ) + results = list(map_func(self.log_prob_fn, p)) try: log_prob = np.array([float(l[0]) for l in results]) @@ -444,8 +496,9 @@ def compute_log_prob(self, coords): else: try: with warnings.catch_warnings(record=True): - warnings.simplefilter("error", - np.VisibleDeprecationWarning) + warnings.simplefilter( + "error", np.VisibleDeprecationWarning + ) try: dt = np.atleast_1d(blob[0]).dtype except Warning: @@ -455,7 +508,8 @@ def compute_log_prob(self, coords): "placed in an object array. 
Numpy has " "deprecated this automatic detection, so " "please specify " - "blobs_dtype=np.dtype('object')") + "blobs_dtype=np.dtype('object')" + ) dt = np.dtype("object") except ValueError: dt = np.dtype("object") @@ -557,8 +611,8 @@ class _FunctionWrapper(object): def __init__(self, f, args, kwargs): self.f = f - self.args = [] if args is None else args - self.kwargs = {} if kwargs is None else kwargs + self.args = args or [] + self.kwargs = kwargs or {} def __call__(self, x): try: @@ -605,3 +659,22 @@ def _scaled_cond(a): return np.inf c = b / bsum return np.linalg.cond(c.astype(float)) + + +def ndarray_to_list_of_dicts( + x: np.ndarray, + key_map: Dict[str, Union[int, List[int]]], +) -> List[Dict[str, Union[np.number, np.ndarray]]]: + """ + A helper function to convert a ``np.ndarray`` into a list + of dictionaries of parameters. Used when parameters are named. + + Args: + x (np.ndarray): parameter array of shape ``(N, n_dim)``, where + ``N`` is an integer + key_map (Dict[str, Union[int, List[int]]): + + Returns: + list of dictionaries of parameters + """ + return [{key: xi[val] for key, val in key_map.items()} for xi in x] diff --git a/src/emcee/tests/unit/test_ensemble.py b/src/emcee/tests/unit/test_ensemble.py new file mode 100644 index 00000000..f6f5ad27 --- /dev/null +++ b/src/emcee/tests/unit/test_ensemble.py @@ -0,0 +1,184 @@ +""" +Unit tests of some functionality in ensemble.py when the parameters are named +""" +import string +from unittest import TestCase + +import numpy as np +import pytest + +from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts + + +class TestNP2ListOfDicts(TestCase): + def test_ndarray_to_list_of_dicts(self): + # Try different numbers of keys + for n_keys in [1, 2, 10, 26]: + keys = list(string.ascii_lowercase[:n_keys]) + key_set = set(keys) + key_dict = {key: i for i, key in enumerate(keys)} + # Try different number of walker/procs + for N in [1, 2, 3, 10, 100]: + x = np.random.rand(N, n_keys) + + LOD = 
ndarray_to_list_of_dicts(x, key_dict) + assert len(LOD) == N, "need 1 dict per row" + for i, dct in enumerate(LOD): + assert dct.keys() == key_set, "keys are missing" + for j, key in enumerate(keys): + assert dct[key] == x[i, j], f"wrong value at {(i, j)}" + + +class TestNamedParameters(TestCase): + """ + Test that a keyword-based log-probability function can be used instead + of a positional one. + """ + + # Keyword based lnpdf + def lnpdf(self, pars) -> np.float64: + mean = pars["mean"] + var = pars["var"] + if var <= 0: + return -np.inf + return ( + -0.5 * ((mean - self.x) ** 2 / var + np.log(2 * np.pi * var)).sum() + ) + + def lnpdf_mixture(self, pars) -> np.float64: + mean1 = pars["mean1"] + var1 = pars["var1"] + mean2 = pars["mean2"] + var2 = pars["var2"] + if var1 <= 0 or var2 <= 0: + return -np.inf + return ( + -0.5 + * ( + (mean1 - self.x) ** 2 / var1 + + np.log(2 * np.pi * var1) + + (mean2 - self.x - 3) ** 2 / var2 + + np.log(2 * np.pi * var2) + ).sum() + ) + + def lnpdf_mixture_grouped(self, pars) -> np.float64: + mean1, mean2 = pars["means"] + var1, var2 = pars["vars"] + const = pars["constant"] + if var1 <= 0 or var2 <= 0: + return -np.inf + return ( + -0.5 + * ( + (mean1 - self.x) ** 2 / var1 + + np.log(2 * np.pi * var1) + + (mean2 - self.x - 3) ** 2 / var2 + + np.log(2 * np.pi * var2) + ).sum() + + const + ) + + def setUp(self): + # Draw some data from a unit Gaussian + self.x = np.random.randn(100) + self.names = ["mean", "var"] + + def test_named_parameters(self): + sampler = EnsembleSampler( + nwalkers=10, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + assert sampler.params_are_named + assert list(sampler.parameter_names.keys()) == self.names + + def test_asserts(self): + # ndim name mismatch + with pytest.raises(AssertionError): + _ = EnsembleSampler( + nwalkers=10, + ndim=len(self.names) - 1, + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + + # duplicate names + with pytest.raises(AssertionError): + _ = 
EnsembleSampler( + nwalkers=10, + ndim=3, + log_prob_fn=self.lnpdf, + parameter_names=["a", "b", "a"], + ) + + # vectorize turned on + with pytest.raises(AssertionError): + _ = EnsembleSampler( + nwalkers=10, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + vectorize=True, + ) + + def test_compute_log_prob(self): + # Try different numbers of walkers + for N in [4, 8, 10]: + sampler = EnsembleSampler( + nwalkers=N, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + coords = np.random.rand(N, len(self.names)) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 + + def test_compute_log_prob_mixture(self): + names = ["mean1", "var1", "mean2", "var2"] + # Try different numbers of walkers + for N in [8, 10, 20]: + sampler = EnsembleSampler( + nwalkers=N, + ndim=len(names), + log_prob_fn=self.lnpdf_mixture, + parameter_names=names, + ) + coords = np.random.rand(N, len(names)) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 + + def test_compute_log_prob_mixture_grouped(self): + names = {"means": [0, 1], "vars": [2, 3], "constant": 4} + # Try different numbers of walkers + for N in [8, 10, 20]: + sampler = EnsembleSampler( + nwalkers=N, + ndim=5, + log_prob_fn=self.lnpdf_mixture_grouped, + parameter_names=names, + ) + coords = np.random.rand(N, 5) + lnps, _ = sampler.compute_log_prob(coords) + assert len(lnps) == N + assert lnps.dtype == np.float64 + + def test_run_mcmc(self): + # Sort of an integration test + n_walkers = 4 + sampler = EnsembleSampler( + nwalkers=n_walkers, + ndim=len(self.names), + log_prob_fn=self.lnpdf, + parameter_names=self.names, + ) + guess = np.random.rand(n_walkers, len(self.names)) + n_steps = 50 + results = sampler.run_mcmc(guess, n_steps) + assert results.coords.shape == (n_walkers, len(self.names)) + chain = sampler.chain + assert chain.shape == (n_walkers, n_steps, 
len(self.names))