@@ -2742,13 +2742,14 @@ def test_generated_sample_distribution(
2742
2742
@pytest .mark .parametrize (
2743
2743
"jax_dist, params, support" ,
2744
2744
[
2745
- (dist .BernoulliLogits , (5.0 ,), jnp .arange (2 )),
2746
- (dist .BernoulliProbs , (0.5 ,), jnp .arange (2 )),
2747
- (dist .BinomialLogits , (4.5 , 10 ), jnp .arange (11 )),
2748
- (dist .BinomialProbs , (0.5 , 11 ), jnp .arange (12 )),
2749
- (dist .BetaBinomial , (2.0 , 0.5 , 12 ), jnp .arange (13 )),
2750
- (dist .CategoricalLogits , (np .array ([3.0 , 4.0 , 5.0 ]),), jnp .arange (3 )),
2751
- (dist .CategoricalProbs , (np .array ([0.1 , 0.5 , 0.4 ]),), jnp .arange (3 )),
2745
+ (dist .BernoulliLogits , (5.0 ,), np .arange (2 )),
2746
+ (dist .BernoulliProbs , (0.5 ,), np .arange (2 )),
2747
+ (dist .BinomialLogits , (4.5 , 10 ), np .arange (11 )),
2748
+ (dist .BinomialProbs , (0.5 , 11 ), np .arange (12 )),
2749
+ (dist .BetaBinomial , (2.0 , 0.5 , 12 ), np .arange (13 )),
2750
+ (dist .CategoricalLogits , (np .array ([3.0 , 4.0 , 5.0 ]),), np .arange (3 )),
2751
+ (dist .CategoricalProbs , (np .array ([0.1 , 0.5 , 0.4 ]),), np .arange (3 )),
2752
+ (dist .DiscreteUniform , (2 , 4 ), np .arange (2 , 5 )),
2752
2753
],
2753
2754
)
2754
2755
@pytest .mark .parametrize ("batch_shape" , [(5 ,), ()])
@@ -3333,8 +3334,8 @@ def test_normal_log_cdf():
3333
3334
"value" ,
3334
3335
[
3335
3336
- 15.0 ,
3336
- jnp .array ([[- 15.0 ], [- 10.0 ], [- 5.0 ]]),
3337
- jnp .array ([[[- 15.0 ], [- 10.0 ], [- 5.0 ]], [[- 14.0 ], [- 9.0 ], [- 4.0 ]]]),
3337
+ np .array ([[- 15.0 ], [- 10.0 ], [- 5.0 ]]),
3338
+ np .array ([[[- 15.0 ], [- 10.0 ], [- 5.0 ]], [[- 14.0 ], [- 9.0 ], [- 4.0 ]]]),
3338
3339
],
3339
3340
)
3340
3341
def test_truncated_normal_log_prob_in_tail (value ):
0 commit comments