Skip to content

Commit ca8fb39

Browse files
authored
Fix DiscreteUniform.enumerate_support with non-trivial batch shape (#1859)
* fix DiscreteUniform enumerate_support * make sure that low and high are concrete values
1 parent a7a2f31 commit ca8fb39

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

numpyro/distributions/discrete.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
469469
raise NotImplementedError(
470470
"Inhomogeneous `high` not supported by `enumerate_support`."
471471
)
472-
values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
473-
(-1,) + (1,) * len(self.batch_shape)
474-
)
472+
low = np.reshape(self.low, -1)[0]
473+
high = np.reshape(self.high, -1)[0]
474+
values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
475475
if expand:
476476
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
477477
return values

test/test_distributions.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,13 +2742,14 @@ def test_generated_sample_distribution(
27422742
@pytest.mark.parametrize(
27432743
"jax_dist, params, support",
27442744
[
2745-
(dist.BernoulliLogits, (5.0,), jnp.arange(2)),
2746-
(dist.BernoulliProbs, (0.5,), jnp.arange(2)),
2747-
(dist.BinomialLogits, (4.5, 10), jnp.arange(11)),
2748-
(dist.BinomialProbs, (0.5, 11), jnp.arange(12)),
2749-
(dist.BetaBinomial, (2.0, 0.5, 12), jnp.arange(13)),
2750-
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), jnp.arange(3)),
2751-
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), jnp.arange(3)),
2745+
(dist.BernoulliLogits, (5.0,), np.arange(2)),
2746+
(dist.BernoulliProbs, (0.5,), np.arange(2)),
2747+
(dist.BinomialLogits, (4.5, 10), np.arange(11)),
2748+
(dist.BinomialProbs, (0.5, 11), np.arange(12)),
2749+
(dist.BetaBinomial, (2.0, 0.5, 12), np.arange(13)),
2750+
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), np.arange(3)),
2751+
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), np.arange(3)),
2752+
(dist.DiscreteUniform, (2, 4), np.arange(2, 5)),
27522753
],
27532754
)
27542755
@pytest.mark.parametrize("batch_shape", [(5,), ()])
@@ -3333,8 +3334,8 @@ def test_normal_log_cdf():
33333334
"value",
33343335
[
33353336
-15.0,
3336-
jnp.array([[-15.0], [-10.0], [-5.0]]),
3337-
jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
3337+
np.array([[-15.0], [-10.0], [-5.0]]),
3338+
np.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
33383339
],
33393340
)
33403341
def test_truncated_normal_log_prob_in_tail(value):

0 commit comments

Comments
 (0)