@@ -249,23 +249,25 @@ def test_forward(nps, model_sample):
249
249
check_prediction (nps , pred , yt )
250
250
251
251
252
+ @pytest .mark .parametrize ("normalise" , [False , True ])
252
253
@pytest .mark .flaky (reruns = 3 )
253
- def test_elbo (nps , model_sample ):
254
+ def test_elbo (nps , model_sample , normalise ):
254
255
model , sample = model_sample
255
256
model = model ()
256
257
xc , yc , xt , yt = sample ()
257
- elbos = nps .elbo (model , xc , yc , xt , yt , num_samples = 2 )
258
+ elbos = nps .elbo (model , xc , yc , xt , yt , num_samples = 2 , normalise = normalise )
258
259
assert B .rank (elbos ) == 1
259
260
assert np .isfinite (B .to_numpy (B .sum (elbos )))
260
261
assert B .dtype (elbos ) == nps .dtype64
261
262
262
263
264
+ @pytest .mark .parametrize ("normalise" , [False , True ])
263
265
@pytest .mark .flaky (reruns = 3 )
264
- def test_loglik (nps , model_sample ):
266
+ def test_loglik (nps , model_sample , normalise ):
265
267
model , sample = model_sample
266
268
model = model ()
267
269
xc , yc , xt , yt = sample ()
268
- logpdfs = nps .loglik (model , xc , yc , xt , yt , num_samples = 2 )
270
+ logpdfs = nps .loglik (model , xc , yc , xt , yt , num_samples = 2 , normalise = normalise )
269
271
assert B .rank (logpdfs ) == 1
270
272
assert np .isfinite (B .to_numpy (B .sum (logpdfs )))
271
273
assert B .dtype (logpdfs ) == nps .dtype64
0 commit comments