Skip to content

Commit adc842c

Browse files
committed
Fix data type bug
1 parent 4a21213 commit adc842c

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

neuralprocesses/model/elbo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def elbo(
9898

9999
if normalise:
100100
# Normalise by the number of targets.
101-
elbos = elbos / num_data(xt, yt)
101+
elbos = elbos / B.cast(float64, num_data(xt, yt))
102102

103103
return state, elbos
104104

neuralprocesses/model/loglik.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def loglik(
8989

9090
if normalise:
9191
# Normalise by the number of targets.
92-
logpdfs = logpdfs / num_data(xt, yt)
92+
logpdfs = logpdfs / B.cast(float64, num_data(xt, yt))
9393

9494
return state, logpdfs
9595

tests/test_architectures.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -249,23 +249,25 @@ def test_forward(nps, model_sample):
249249
check_prediction(nps, pred, yt)
250250

251251

252+
@pytest.mark.parametrize("normalise", [False, True])
252253
@pytest.mark.flaky(reruns=3)
253-
def test_elbo(nps, model_sample):
254+
def test_elbo(nps, model_sample, normalise):
254255
model, sample = model_sample
255256
model = model()
256257
xc, yc, xt, yt = sample()
257-
elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2)
258+
elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
258259
assert B.rank(elbos) == 1
259260
assert np.isfinite(B.to_numpy(B.sum(elbos)))
260261
assert B.dtype(elbos) == nps.dtype64
261262

262263

264+
@pytest.mark.parametrize("normalise", [False, True])
263265
@pytest.mark.flaky(reruns=3)
264-
def test_loglik(nps, model_sample):
266+
def test_loglik(nps, model_sample, normalise):
265267
model, sample = model_sample
266268
model = model()
267269
xc, yc, xt, yt = sample()
268-
logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2)
270+
logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
269271
assert B.rank(logpdfs) == 1
270272
assert np.isfinite(B.to_numpy(B.sum(logpdfs)))
271273
assert B.dtype(logpdfs) == nps.dtype64

0 commit comments

Comments
 (0)