Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 60 additions & 24 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"""

from collections import OrderedDict
from functools import reduce
import warnings

import numpy as np
Expand Down Expand Up @@ -267,6 +268,20 @@ def process_message(self, msg):
msg["stop"] = True


def _eager_expand_fn(fn):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that this function eliminates lazy ExpandedDistributions and replaces them with eagerly tree_map-expanded distributions, and that this is needed because Funsor distributions do not support Expanded distributions? If so could you add a commend explaining that?

I guess an alternative would be to support ExpandedDistributions in Funsor, or to do this conversion in the funsor layer. @eb8680 does that seem possible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is. I'll add a comment for clarity. I thought from our last discussion, this stuff should be treated in Pyro/NumPyro?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess an alternative would be to support ExpandedDistributions in Funsor, or to do this conversion in the funsor layer.

We could certainly implement an equivalent of .expand. I think the easiest thing to do to address this case would be to add a to_funsor pattern for ExpandedDistribution.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a to_funsor pattern for ExpandedDistribution

Yeah, it seems that you both agree to do this. I can take an effort for it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would basically be the same as what you have here, plus a final step that calls to_funsor on the eagerly expanded result.

if isinstance(fn, Independent):
reinterpreted_batch_ndims = fn.reinterpreted_batch_ndims
fn = fn.base_dist
else:
reinterpreted_batch_ndims = 0 # no-op for to_event method
if isinstance(fn, ExpandedDistribution):
batch_shape = fn.batch_shape
base_batch_shape = fn.base_dist.batch_shape
appended_shape = batch_shape[:len(batch_shape) - len(base_batch_shape)]
fn = tree_map(lambda x: jnp.broadcast_to(x, appended_shape + jnp.shape(x)), fn.base_dist)
return fn.to_event(reinterpreted_batch_ndims)


class collapse(trace):
"""
EXPERIMENTAL Collapses all sites in the context by lazily sampling and
Expand All @@ -287,14 +302,24 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def process_message(self, msg):
from funsor.terms import Funsor
if msg["type"] != "sample":
return

if msg["type"] == "sample":
if msg["value"] is None:
msg["value"] = msg["name"]
import funsor

# Eagerly convert fn and value to Funsor.
dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
dim_to_name.update(self.preserved_plates)
if isinstance(msg["fn"], (Independent, ExpandedDistribution)):
msg["fn"] = _eager_expand_fn(msg["fn"])
msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
domain = msg["fn"].inputs["value"]
if msg["value"] is None:
msg["value"] = funsor.Variable(msg["name"], domain)
else:
msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)

if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
msg["stop"] = True
msg["stop"] = True

def __enter__(self):
self.preserved_plates = frozenset(
Expand All @@ -304,15 +329,22 @@ def __enter__(self):
return super().__enter__()

def __exit__(self, exc_type, exc_value, traceback):
import funsor

_coerce = COERCIONS.pop()
assert _coerce is self._coerce
super().__exit__(exc_type, exc_value, traceback)

if exc_type is not None:
self.trace.clear()
self.preserved_plates.clear()
return

if any(site["type"] == "sample" for site in self.trace.values()):
name, log_prob, _, _ = self._get_log_prob()
numpyro.factor(name, log_prob.data)

def _get_log_prob(self):
import funsor

# Convert delayed statements to pyro.factor()
reduced_vars = []
log_prob_terms = []
Expand All @@ -322,24 +354,28 @@ def __exit__(self, exc_type, exc_value, traceback):
continue
if not site["is_observed"]:
reduced_vars.append(name)
dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
value = site["value"]
if not isinstance(value, str):
value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
log_prob_terms.append(fn(value=value))
log_prob_terms.append(site["fn"](value=site["value"]))
plates |= frozenset(f.name for f in site["cond_indep_stack"])
assert log_prob_terms, "nothing to collapse"
reduced_plates = plates - self.preserved_plates
log_prob = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
log_prob_terms,
eliminate=frozenset(reduced_vars) | reduced_plates,
plates=plates,
)
name = reduced_vars[0]
numpyro.factor(name, log_prob.data)
reduced_vars = frozenset(reduced_vars)
assert log_prob_terms, "nothing to collapse"
reduced_plates = plates - frozenset(self.preserved_plates.values())
self.trace.clear()
self.preserved_plates.clear()
if reduced_plates:
log_prob = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
log_prob_terms,
eliminate=frozenset(reduced_vars) | reduced_plates,
plates=plates,
)
log_joint = NotImplemented
else:
log_joint = reduce(funsor.ops.add, log_prob_terms)
log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)

return name, log_prob, log_joint, reduced_vars


class condition(Messenger):
Expand Down
4 changes: 3 additions & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import Distribution
from numpyro.util import identity

_PYRO_STACK = []
Expand Down Expand Up @@ -501,7 +502,8 @@ def process_message(self, msg):
cond_indep_stack = msg["cond_indep_stack"]
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
if msg["type"] == "sample":
# only expand if fn is Distribution, not a Funsor
if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution):
expected_shape = self._get_batch_shape(cond_indep_stack)
dist_batch_shape = msg["fn"].batch_shape
if "sample_shape" in msg["kwargs"]:
Expand Down
110 changes: 109 additions & 1 deletion test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,6 @@ def guide():
svi.update(svi_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_beta_binomial_plate():
data = np.array([0.0, 1.0, 5.0, 5.0])

Expand All @@ -734,6 +733,115 @@ def guide():
svi.update(svi_state)


def test_collapse_normal_normal():
data = np.array(0.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_normal_plate():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
with handlers.plate("data", len(data)):
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_plate_normal():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with handlers.plate("data", len(data)):
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_diag_normal_plate_normal():
d = 3
data = np.ones((5, d))

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with handlers.plate("data", len(data)):
y = numpyro.sample("y", dist.Normal(x, 1.).expand([d]).to_event(1))
numpyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_normal_mvn_mvn():
T, d, S = 5, 2, 3
data = jnp.ones((T, S))

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with numpyro.plate("d", d, dim=-1):
beta0 = numpyro.sample("beta0", dist.Normal(x, 1.).expand([d, S]).to_event(1))
beta = numpyro.sample(
"beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S)))

# this fails because beta shape is (3,) while it should be (2, 3)
mean = jnp.ones((T, d)) @ beta
with numpyro.plate("data", T, dim=-1):
numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_prng_key():
assert numpyro.prng_key() is None

Expand Down