6 changes: 3 additions & 3 deletions numpyro/distributions/discrete.py
@@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
             raise NotImplementedError(
                 "Inhomogeneous `high` not supported by `enumerate_support`."
             )
-        values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
-            (-1,) + (1,) * len(self.batch_shape)
-        )
+        low = jnp.reshape(self.low, -1)[0]
+        high = jnp.reshape(self.high, -1)[0]
+        values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
         if expand:
             values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
         return values
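Aside (an editor's sketch, not part of the diff): with homogeneous bounds, `enumerate_support` can now read the scalar bounds directly, so the enumerated values are the actual support values rather than an index range offset by `low`. The bounds below are made up for illustration, and the `(-1, 1)` reshape assumes a batch shape of length one.

```python
import jax.numpy as jnp

# Hypothetical homogeneous bounds for a batch of two DiscreteUniform sites.
low = jnp.array([10, 10])
high = jnp.array([12, 12])

lo = jnp.reshape(low, -1)[0]                      # 10
hi = jnp.reshape(high, -1)[0]                     # 12
values = jnp.arange(lo, hi + 1).reshape((-1, 1))  # shape (3, 1)
# values.squeeze(-1) -> [10, 11, 12]: the actual support, not [0, 1, 2]
```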
82 changes: 72 additions & 10 deletions numpyro/infer/hmc_gibbs.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 
+import jax
 from jax import device_put, grad, jacfwd, random, value_and_grad
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
@@ -192,12 +193,22 @@ def __getstate__(self):
 
 
 def _discrete_gibbs_proposal_body_fn(
-    z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
+    z_init_flat,
+    unravel_fn,
+    pe_init,
+    potential_fn,
+    idx,
+    i,
+    val,
+    support_size,
+    support_enumerate,
 ):
     rng_key, z, pe, log_weight_sum = val
     rng_key, rng_transition = random.split(rng_key)
-    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
-    z_new_flat = z_init_flat.at[idx].set(proposal)
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_init_flat[idx], support_size - 1, i
+    )
+    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_weight_new = pe_init - pe_new
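A hedged illustration of the index remapping above, assuming a scalar site with support {10, 11, 12}: looping `i` over `0 .. support_size - 2` visits every support value except the current one, because the slot holding the current value is redirected to the last index.

```python
import jax.numpy as jnp

support_enumerate = jnp.array([10, 11, 12])  # assumed enumerated support
support_size = 3
current = 11                                 # assumed current value

for i in range(support_size - 1):            # i in {0, 1}
    j = jnp.where(support_enumerate[i] == current, support_size - 1, i)
    print(int(support_enumerate[j]))         # prints 10, then 12 -- never 11
```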
@@ -216,7 +227,9 @@ def _discrete_gibbs_proposal_body_fn(
     return rng_key, z, pe, log_weight_sum
 
 
-def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_gibbs_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     # idx: current index of `z_discrete_flat` to update
     # support_size: support size of z_discrete at the index idx
 
@@ -234,6 +247,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
     rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
@@ -242,7 +257,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
 
 
 def _discrete_modified_gibbs_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
@@ -253,6 +275,8 @@ def _discrete_modified_gibbs_proposal(
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     # like gibbs_step but here, weight of the current value is 0
     init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
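A small sketch (mine, not from the PR) of the "weight of the current value is 0" idea: the modified Gibbs proposal effectively renormalizes the conditional over the remaining support values, which, on my reading of the loop above, is what initializing the running log-weight sum at `-inf` achieves in log space.

```python
import jax.numpy as jnp

probs = jnp.array([0.5, 0.3, 0.2])      # assumed conditional over 3 values
current = 0                             # assumed index of the current value
masked = probs.at[current].set(0.0)     # current value gets zero weight
proposal_probs = masked / masked.sum()  # [0.0, 0.6, 0.4]
```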
@@ -276,28 +300,41 @@ def _discrete_modified_gibbs_proposal(
     return rng_key, z_new, pe_new, log_accept_ratio
 
 
-def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_rw_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     rng_key, rng_proposal = random.split(rng_key, 2)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
 
     proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
-    z_new_flat = z_discrete_flat.at[idx].set(proposal)
+    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_accept_ratio = pe - pe_new
     return rng_key, z_new, pe_new, log_accept_ratio
 
 
 def _discrete_modified_rw_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
 
     i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
-    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
-    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
+    )
+    proposal = jnp.where(
+        random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
+    )
     z_new_flat = z_discrete_flat.at[idx].set(proposal)
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
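Another hedged sketch, with assumed values: the random-walk proposal now draws an index and looks it up in the enumerated support, so a site like `DiscreteUniform(10, 12)` proposes values in {10, 11, 12} rather than raw indices.

```python
from jax import random
import jax.numpy as jnp

support_enumerate = jnp.array([10, 11, 12])  # assumed enumerated support
rng = random.PRNGKey(0)
i = random.randint(rng, (), minval=0, maxval=support_enumerate.shape[0])
proposal = support_enumerate[i]              # a support value, not an index
```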
@@ -434,6 +471,31 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
             and site["fn"].has_enumerate_support
             and not site["is_observed"]
         }
+
+        # All support_enumerates must have the same length to be used in the loop;
+        # each support is right-padded with zeros up to the longest one.
+        # ravel is used to keep the behavior consistent with `support_sizes`.
+
+        max_length_support_enumerates = np.max(
+            [size for size in self._support_sizes.values()]
+        )
+
+        support_enumerates = {}
+        for name, support_size in self._support_sizes.items():
+            site = self._prototype_trace[name]
+            enumerate_support = site["fn"].enumerate_support(True).T
+            # Only the last dimension, which corresponds to the support size, is padded
+            pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
+                (0, max_length_support_enumerates - enumerate_support.shape[-1])
+            ]
+            padded_enumerate_support = np.pad(enumerate_support, pad_width)
+
+            support_enumerates[name] = padded_enumerate_support
+
+        self._support_enumerates = jax.vmap(
+            lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
+        )(support_enumerates)
+
         self._gibbs_sites = [
             name
             for name, site in self._prototype_trace.items()
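To make the padding step above concrete, a standalone sketch under assumed shapes (two scalar sites with supports of length 3 and 10): each enumerated support is right-padded with zeros up to the longest one, so a single loop can index into any row.

```python
import numpy as np

supports = {
    "x0": np.arange(10, 13),  # e.g. DiscreteUniform(10, 12) -> length 3
    "x1": np.arange(10),      # e.g. Categorical over 10 classes
}
max_len = max(s.shape[-1] for s in supports.values())
padded = {
    name: np.pad(s, [(0, max_len - s.shape[-1])]) for name, s in supports.items()
}
# padded["x0"] -> [10 11 12  0  0  0  0  0  0  0]
# The zero padding is never proposed: the proposal loops draw indices
# strictly below each site's own support_size.
```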
1 change: 1 addition & 0 deletions numpyro/infer/mixed_hmc.py
@@ -138,6 +138,7 @@ def update_discrete(
                 partial(potential_fn, z_hmc=hmc_state.z),
                 idx,
                 self._support_sizes_flat[idx],
+                self._support_enumerates[idx],
             )
             # Algo 1, line 20: depending on reject or refract, we will update
             # the discrete variable and its corresponding kinetic energy. In case of
89 changes: 89 additions & 0 deletions test/test_distributions.py
@@ -3423,3 +3423,92 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
     x2 = dist2.sample(random.PRNGKey(7))
     assert jnp.allclose(x1, x2.squeeze())
     assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))
+
+
+def test_discrete_uniform_with_mixedhmc():
+    import numpyro
+    import numpyro.distributions as dist
+    from numpyro.infer import HMC, MCMC, MixedHMC
+
+    def sample_mixedhmc(model_fn, num_samples, **kwargs):
+        kernel = HMC(model_fn, trajectory_length=1.2)
+        kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
+        mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
+        key = jax.random.PRNGKey(0)
+        mcmc.run(key)
+        samples = mcmc.get_samples()
+        return samples
+
+    num_samples = 1000
+    mixed_hmc_kwargs = [
+        {"random_walk": False, "modified": False},
+        {"random_walk": True, "modified": False},
+        {"random_walk": True, "modified": True},
+        {"random_walk": False, "modified": True},
+    ]
+
+    # Case 1: one discrete uniform and one categorical
+    def model_1():
+        numpyro.sample("x0", dist.DiscreteUniform(10, 12))
+        numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))
+
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_1, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 10) & (samples["x0"] <= 12)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 3)
+        ), f"Failed with {kwargs=}"
+
+    def model_2():
+        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
+        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))
+
+    # Case 2: two categoricals with different support sizes
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_2, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 9)
+        ), f"Failed with {kwargs=}"
+
+    def model_3():
+        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
+        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))
+
+    # Case 3: two categoricals with different support sizes, batched by 3
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_3, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 9)
+        ), f"Failed with {kwargs=}"
+
+    def model_4():
+        dist0 = dist.Categorical(0.25 * jnp.ones((3, 4)))
+        numpyro.sample("x0", dist0)
+        dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
+        numpyro.sample("x1", dist1)
+
+    # Case 4: one categorical and one discrete uniform, both batched by 3
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_4, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 10) & (samples["x1"] <= 19)
+        ), f"Failed with {kwargs=}"
+
+
+if __name__ == "__main__":
+    test_discrete_uniform_with_mixedhmc()