|
2 | 2 | import pytest
|
3 | 3 |
|
4 | 4 | from .test_architectures import generate_data
|
5 |
| -from .util import nps, approx # noqa |
| 5 | +from .util import approx, nps # noqa |
6 | 6 |
|
7 | 7 |
|
8 | 8 | @pytest.mark.flaky(reruns=3)
|
@@ -31,3 +31,23 @@ def test_convgnp_mask(nps):
|
31 | 31 | # Check that the two ways of doing it coincide.
|
32 | 32 | approx(pred.mean, pred_masked.mean)
|
33 | 33 | approx(pred.var, pred_masked.var)
|
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.parametrize("ns", [(10,), (0,), (10, 5), (10, 0), (0, 10), (15, 5, 10)]) |
| 37 | +@pytest.mark.parametrize("multiple", [1, 2, 3, 5]) |
| 38 | +def test_mask_contexts(nps, ns, multiple): |
| 39 | + x, y = nps.merge_contexts( |
| 40 | + *((B.randn(nps.dtype, 2, 3, n), B.randn(nps.dtype, 2, 4, n)) for n in ns), |
| 41 | + multiple=multiple |
| 42 | + ) |
| 43 | + |
| 44 | + # Test that the output is of the right shape. |
| 45 | + if max(ns) == 0: |
| 46 | + assert B.shape(y.y, 2) == multiple |
| 47 | + else: |
| 48 | + assert B.shape(y.y, 2) == ((max(ns) - 1) // multiple + 1) * multiple |
| 49 | + |
| 50 | + # Test that the mask is right. |
| 51 | + mask = y.mask == 1 # Convert mask to booleans. |
| 52 | + assert B.all(B.take(B.flatten(y.y), B.flatten(mask)) != 0) |
| 53 | + assert B.all(B.take(B.flatten(y.y), B.flatten(~mask)) == 0) |
0 commit comments