diff --git a/setup.py b/setup.py index 1f4f42b609..df62492726 100644 --- a/setup.py +++ b/setup.py @@ -17,13 +17,11 @@ REQUIRED_MINOR = 7 INSTALL_REQUIRES = [ - "arviz>=0.12.1", "astor>=0.7.1", "black==22.3.0", "botorch>=0.5.1", "gpytorch>=1.3.0, <1.9.0", "graphviz>=0.17", - "netCDF4<=1.5.8; python_version<'3.8'", "numpy>=1.18.1", "pandas>=0.24.2", "plotly>=2.2.1", @@ -36,6 +34,7 @@ ] TEST_REQUIRES = ["pytest>=7.0.0", "pytest-cov"] TUTORIALS_REQUIRES = [ + "arviz>=0.12.1", "bokeh", "cma", "ipywidgets", @@ -44,6 +43,7 @@ "matplotlib", "mdformat", "mdformat-myst", + "netCDF4<=1.5.8; python_version<'3.8'", "scikit-learn>=1.0.0", "seaborn", "tabulate", diff --git a/src/beanmachine/ppl/inference/monte_carlo_samples.py b/src/beanmachine/ppl/inference/monte_carlo_samples.py index 8345bd6e79..3bb64ccc52 100644 --- a/src/beanmachine/ppl/inference/monte_carlo_samples.py +++ b/src/beanmachine/ppl/inference/monte_carlo_samples.py @@ -10,6 +10,7 @@ import xarray as xr from beanmachine.ppl.inference.utils import detach_samples, merge_dicts from beanmachine.ppl.model.rv_identifier import RVIdentifier +from beanmachine.ppl.diagnostics.tools.viz import _requires_dev_packages RVDict = Dict[RVIdentifier, torch.Tensor] @@ -266,10 +267,12 @@ def add_groups(self, mcs: "MonteCarloSamples"): if n not in self.namespaces: self.namespaces[n] = mcs.namespaces[n] - def to_inference_data(self, include_adapt_steps: bool = False) -> az.InferenceData: + @_requires_dev_packages + def to_inference_data(self, include_adapt_steps: bool = False): """ Return an az.InferenceData from MonteCarloSamples. """ + import arviz as az if "posterior" in self.namespaces: posterior = detach_samples(self.namespaces["posterior"].samples)