Skip to content

Commit 46bd53e

Browse files
committed
Fix off by one bug (thanks @tom-andersson!)
1 parent 33a6a49 commit 46bd53e

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

neuralprocesses/mask.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ def _pad_zeros(x: B.Numeric, up_to: int, axis: int):
4040
return B.concat(x, B.zeros(B.dtype(x), *shape), axis=axis)
4141

4242

43+
def _ceil_to_closest_multiple(n, m):
44+
d, r = divmod(n, m)
45+
# If `n` is zero, then we must also round up.
46+
if n == 0 or r > 0:
47+
return (d + 1) * m
48+
else:
49+
return d * m
50+
51+
4352
@_dispatch
4453
def _determine_ns(xc: tuple, multiple: Union[int, tuple]):
4554
ns = [B.shape(xci, 2) for xci in xc]
@@ -48,7 +57,7 @@ def _determine_ns(xc: tuple, multiple: Union[int, tuple]):
4857
multiple = (multiple,) * len(ns)
4958

5059
# Ceil to the closest multiple of `multiple`.
51-
ns = [((n - 1) // m + 1) * m for n, m in zip(ns, multiple)]
60+
ns = [_ceil_to_closest_multiple(n, m) for n, m in zip(ns, multiple)]
5261

5362
return ns
5463

tests/test_mask.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
from .test_architectures import generate_data
5-
from .util import nps, approx # noqa
5+
from .util import approx, nps # noqa
66

77

88
@pytest.mark.flaky(reruns=3)
@@ -31,3 +31,23 @@ def test_convgnp_mask(nps):
3131
# Check that the two ways of doing it coincide.
3232
approx(pred.mean, pred_masked.mean)
3333
approx(pred.var, pred_masked.var)
34+
35+
36+
@pytest.mark.parametrize("ns", [(10,), (0,), (10, 5), (10, 0), (0, 10), (15, 5, 10)])
37+
@pytest.mark.parametrize("multiple", [1, 2, 3, 5])
38+
def test_mask_contexts(nps, ns, multiple):
39+
x, y = nps.merge_contexts(
40+
*((B.randn(nps.dtype, 2, 3, n), B.randn(nps.dtype, 2, 4, n)) for n in ns),
41+
multiple=multiple
42+
)
43+
44+
# Test that the output is of the right shape.
45+
if max(ns) == 0:
46+
assert B.shape(y.y, 2) == multiple
47+
else:
48+
assert B.shape(y.y, 2) == ((max(ns) - 1) // multiple + 1) * multiple
49+
50+
# Test that the mask is right.
51+
mask = y.mask == 1 # Convert mask to booleans.
52+
assert B.all(B.take(B.flatten(y.y), B.flatten(mask)) != 0)
53+
assert B.all(B.take(B.flatten(y.y), B.flatten(~mask)) == 0)

0 commit comments

Comments
 (0)