
Commit 33a6a49

Check number of context sets and check aux variable

1 parent 35b51de · commit 33a6a49

9 files changed: +198 -16 lines changed

neuralprocesses/architectures/convgnp.py (+16 -5)
@@ -1,6 +1,4 @@
-import lab as B
 import neuralprocesses as nps  # This fixes inspection below.
-import wbml.out as out
 from plum import convert
 
 from ..util import register_model
@@ -30,7 +28,7 @@ def _convgnp_resolve_architecture(
     elif "conv" in conv_arch:
         conv_out_channels = conv_channels
         if conv_receptive_field is None:
-            raise ValueError(f"Must specify `conv_receptive_field`.")
+            raise ValueError("Must specify `conv_receptive_field`.")
     else:
         raise ValueError(f'Architecture "{conv_arch}" invalid.')
     return conv_out_channels
@@ -62,6 +60,16 @@ def _convgnp_construct_encoder_setconvs(
     )
 
 
+def _convgnp_assert_form_contexts(nps, dim_yc):
+    if len(dim_yc) == 1:
+        return nps.Chain(
+            nps.SqueezeParallel(),
+            nps.AssertNoParallel(),
+        )
+    else:
+        return nps.AssertParallel(len(dim_yc))
+
+
 def _convgnp_construct_decoder_setconv(
     nps,
     decoder_scale,
@@ -159,8 +167,8 @@ def construct_convgnp(
         conv_receptive_field (float, optional): Receptive field of the standard
             architecture. Must be specified if `conv_arch` is set to `"conv"`.
         conv_layers (int, optional): Layers of the standard architecture. Defaults to 8.
-        conv_channels (int, optional): Channels of the standard architecture. Defaults to
-            64.
+        conv_channels (int, optional): Channels of the standard architecture. Defaults
+            to 64.
         num_basis_functions (int, optional): Number of basis functions for the
             low-rank likelihood. Defaults to `512`.
         dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
@@ -255,6 +263,8 @@ def construct_convgnp(
         # Not necessary. Just let the CNN produce the right number of channels.
        conv_out_channels = likelihood_in_channels
        linear_after_set_conv = lambda x: x
+        # Also assert that there is no augmentation given.
+        likelihood = nps.Chain(nps.AssertNoAugmentation(), likelihood)
 
     # Construct the core CNN architectures for the encoder, which is only necessary
     # if we're using a latent variable, and for the decoder. First, we determine
@@ -348,6 +358,7 @@ def construct_convgnp(
         nps.FunctionalCoder(
             disc,
             nps.Chain(
+                _convgnp_assert_form_contexts(nps, dim_yc),
                 nps.PrependDensityChannel(),
                 _convgnp_construct_encoder_setconvs(
                     nps,
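
A note on the new helper: `_convgnp_assert_form_contexts` picks the verification coder from the number of context sets the model is constructed with. The following stand-alone sketch is hypothetical and dependency-free, with a plain tuple subclass standing in for `neuralprocesses.parallel.Parallel`; it mirrors the checks that `SqueezeParallel`, `AssertNoParallel`, and `AssertParallel` perform in `shaping.py` below:

    # Hypothetical sketch, not part of the commit.
    class Parallel(tuple):
        """Stand-in for neuralprocesses.parallel.Parallel."""

    def assert_form_contexts(xz, z, num_context_sets):
        if num_context_sets == 1:
            # SqueezeParallel: unwrap a parallel of exactly one element.
            if (
                isinstance(xz, Parallel)
                and isinstance(z, Parallel)
                and len(xz) == len(z) == 1
            ):
                xz, z = xz[0], z[0]
            # AssertNoParallel: anything still in parallel is an error.
            if isinstance(xz, Parallel) or isinstance(z, Parallel):
                raise AssertionError("Expected not a parallel of elements.")
        else:
            # AssertParallel(n): require exactly `n` context sets in parallel.
            if not (isinstance(xz, Parallel) and isinstance(z, Parallel)):
                raise AssertionError("Expected a parallel of elements.")
            if not len(xz) == len(z) == num_context_sets:
                raise AssertionError(
                    f"Expected a parallel of {num_context_sets} elements, "
                    f"but got {len(xz)} inputs and {len(z)} outputs."
                )
        return xz, z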

neuralprocesses/architectures/fullconvgnp.py (+8 -5)
@@ -1,15 +1,16 @@
+import neuralprocesses as nps  # This fixes inspection below.
 import wbml.out as out
 
-import neuralprocesses as nps  # This fixes inspection below.
+from ..util import register_model
 from .convgnp import (
-    _convgnp_init_dims,
-    _convgnp_resolve_architecture,
+    _convgnp_assert_form_contexts,
+    _convgnp_construct_decoder_setconv,
     _convgnp_construct_encoder_setconvs,
+    _convgnp_init_dims,
     _convgnp_optional_division_by_density,
-    _convgnp_construct_decoder_setconv,
+    _convgnp_resolve_architecture,
 )
 from .util import parse_transform
-from ..util import register_model
 
 __all__ = ["construct_fullconvgnp"]
 
@@ -211,6 +212,7 @@ def construct_fullconvgnp(
         nps.FunctionalCoder(
             disc_mean,
             nps.Chain(
+                _convgnp_assert_form_contexts(nps, dim_yc),
                 nps.PrependDensityChannel(),
                 _convgnp_construct_encoder_setconvs(
                     nps,
@@ -232,6 +234,7 @@ def construct_fullconvgnp(
             disc_kernel,
             nps.MapDiagonal(  # Map to diagonal of squared space.
                 nps.Chain(
+                    _convgnp_assert_form_contexts(nps, dim_yc),
                     nps.PrependDensityChannel(),
                     _convgnp_construct_encoder_setconvs(
                         nps,

neuralprocesses/architectures/gnp.py (+2)
@@ -2,6 +2,7 @@
 from plum import convert
 
 from ..util import register_model
+from .convgnp import _convgnp_assert_form_contexts
 from .util import construct_likelihood, parse_transform
 
 __all__ = ["construct_gnp"]
@@ -185,6 +186,7 @@ def construct_mlp(dim_yci):
     encoder = nps.Chain(
         # We need to explicitly copy, because there will be multiple context sets in
         # parallel, which will otherwise dispatch to the wrong method.
+        _convgnp_assert_form_contexts(nps, dim_yc),
         nps.Copy(2 + (dim_lv > 0)),
         nps.Parallel(
             nps.Chain(

neuralprocesses/coders/augment.py (+17 -2)
@@ -4,9 +4,9 @@
 from ..augment import AugmentedInput
 from ..datadims import data_dims
 from ..materialise import _repeat_concat
-from ..util import register_module, register_composite_coder
+from ..util import register_composite_coder, register_module
 
-__all__ = ["Augment"]
+__all__ = ["Augment", "AssertNoAugmentation"]
 
 
 @register_composite_coder
@@ -47,3 +47,18 @@ def _augment(xz: AugmentedInput, z: B.Numeric):
 @_dispatch
 def _augment(xz: AugmentedInput):
     return xz.x
+
+
+@register_module
+class AssertNoAugmentation:
+    """Assert no augmentation of the target inputs."""
+
+
+@_dispatch
+def code(coder: AssertNoAugmentation, xz, z, x, **kw_args):
+    return xz, z
+
+
+@_dispatch
+def code(coder: AssertNoAugmentation, xz, z, x: AugmentedInput, **kw_args):
+    raise AssertionError("Did not expect augmentation of the target inputs.")
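
The two `code` overloads above rely on plum's runtime multiple dispatch: the raising overload is selected only when the target inputs arrive wrapped in an `AugmentedInput`. A minimal sketch of the same pattern, assuming only plum's public `dispatch` decorator:

    from plum import dispatch

    class AugmentedInput:
        """Toy stand-in for neuralprocesses.augment.AugmentedInput."""

    @dispatch
    def code(x):
        # Fallback: unannotated parameters dispatch as `object`.
        return "plain inputs pass through"

    @dispatch
    def code(x: AugmentedInput):
        raise AssertionError("Did not expect augmentation of the target inputs.")

    code(42)                # -> "plain inputs pass through"
    code(AugmentedInput())  # raises AssertionError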

neuralprocesses/coders/shaping.py (+87 -1)
@@ -5,7 +5,14 @@
 from ..parallel import Parallel
 from ..util import register_module, split
 
-__all__ = ["Identity", "Splitter", "RestructureParallel"]
+__all__ = [
+    "Identity",
+    "Splitter",
+    "RestructureParallel",
+    "AssertParallel",
+    "SqueezeParallel",
+    "AssertNoParallel",
+]
 
 
 @register_module
@@ -93,3 +100,82 @@ def _restructure_create(element_map, i):
 @_dispatch
 def _restructure_create(element_map, x: tuple):
     return Parallel(*(_restructure_create(element_map, xi) for xi in x))
+
+
+@register_module
+class AssertParallel:
+    """Assert a parallel of `n` elements.
+
+    Args:
+        n (int): Number of elements asserted in parallel.
+
+    Attributes:
+        n (int): Number of elements asserted in parallel.
+    """
+
+    def __init__(self, n):
+        self.n = n
+
+
+@_dispatch
+def code(p: AssertParallel, xz, z, x, **kw_args):
+    raise AssertionError(f"Expected a parallel of elements, but got `{xz}` and `{z}`.")
+
+
+@_dispatch
+def code(p: AssertParallel, xz: Parallel, z: Parallel, x, **kw_args):
+    if not len(xz) == len(z) == p.n:
+        raise AssertionError(
+            f"Expected a parallel of {p.n} elements, "
+            f"but got {len(xz)} inputs and {len(z)} outputs."
+        )
+    return xz, z
+
+
+@register_module
+class SqueezeParallel:
+    """If there is a parallel of exactly one element, remove the parallel."""
+
+
+@_dispatch
+def code(p: SqueezeParallel, xz, z, x, **kw_args):
+    return xz, z
+
+
+@_dispatch
+def code(p: SqueezeParallel, xz: Parallel, z: Parallel, x, **kw_args):
+    if len(xz) == len(z) == 1:
+        return xz[0], z[0]
+    else:
+        return xz, z
+
+
+@register_module
+class AssertNoParallel:
+    """Assert exactly one element in parallel or not a parallel of elements."""
+
+
+@_dispatch
+def code(p: AssertNoParallel, xz, z, x, **kw_args):
+    return xz, z
+
+
+@_dispatch
+def code(p: AssertNoParallel, xz: Parallel, z, x, **kw_args):
+    raise AssertionError(
+        "Expected not a parallel of elements, but got inputs in parallel."
+    )
+
+
+@_dispatch
+def code(p: AssertNoParallel, xz, z: Parallel, x, **kw_args):
+    raise AssertionError(
+        "Expected not a parallel of elements, but got outputs in parallel."
+    )
+
+
+@_dispatch
+def code(p: AssertNoParallel, xz: Parallel, z: Parallel, x, **kw_args):
+    raise AssertionError(
+        "Expected not a parallel of elements, but got inputs and outputs in parallel."
    )
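
In terms of the public API, these coders make the single-context call signatures interchangeable and turn a wrong number of context sets into an immediate `AssertionError`, which is exactly what the new `test_context_verification` below exercises. A hedged sketch of the resulting call patterns, assuming the torch backend (`neuralprocesses.torch`):

    import torch
    import lab as B
    import neuralprocesses.torch as nps

    xc, yc = B.randn(torch.float32, 4, 1, 5), B.randn(torch.float32, 4, 1, 5)
    xt = B.randn(torch.float32, 4, 1, 15)

    # Built for one context set: bare and singleton-list contexts are equivalent,
    # because SqueezeParallel unwraps the singleton parallel.
    model = nps.construct_convgnp(dim_yc=1, points_per_unit=4, dtype=torch.float32)
    pred1 = model(xc, yc, xt)
    pred2 = model([(xc, yc)], xt)

    # Passing two context sets where one was declared now fails loudly:
    # model([(xc, yc), (xc, yc)], xt)  # AssertionError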

neuralprocesses/mask.py (+2 -1)
@@ -12,7 +12,8 @@ class Masked:
     """A masked output.
 
     Args:
-        y (tensor): Output to mask.
+        y (tensor): Output to mask. The masked values can have any non-NaN value, but
+            they cannot be NaN!
         mask (tensor): A mask consisting of zeros and ones and just one channel.
 
     Attributes:
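
The clarified docstring deserves a concrete illustration: padded-out values must be finite, because NaNs propagate through the computation even at masked positions. A hypothetical sketch of constructing a `Masked` output with zero padding:

    import numpy as np
    import lab as B
    from neuralprocesses.mask import Masked

    y = B.randn(np.float32, 16, 3, 10)  # output to mask
    # One channel of zeros and ones, as the docstring requires.
    mask = B.cast(np.float32, B.rand(np.float32, 16, 1, 10) > 0.5)
    y = y * mask  # fill masked-out entries with zeros, never with NaN
    masked = Masked(y, mask)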

neuralprocesses/parallel.py (+5 -2)
@@ -2,9 +2,12 @@
 from matrix.util import indent
 
 from . import _dispatch
-from .util import register_module, is_framework_module
+from .util import is_framework_module, register_module
 
-__all__ = ["Parallel", "broadcast_coder_over_parallel"]
+__all__ = [
+    "Parallel",
+    "broadcast_coder_over_parallel",
+]
 
 
 @register_module

tests/test_architectures.py (+40)
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 from neuralprocesses.aggregate import Aggregate, AggregateInput
+from neuralprocesses.parallel import Parallel
 from plum import isinstance
 
 from .util import approx, generate_data
@@ -420,3 +421,42 @@ def test_data_eq(nps, dim_x, dim_y, constructor, config, dim_lv):
     batch = gen.generate_batch()
     pred = model(batch["contexts"], batch["xt"])
     check_prediction(nps, pred, batch["yt"])
+
+
+@pytest.mark.parametrize(
+    "constructor_kw_args",
+    [
+        ("construct_gnp", {}),
+        ("construct_agnp", {}),
+        ("construct_convgnp", {"points_per_unit": 4}),
+        ("construct_fullconvgnp", {"points_per_unit": 4}),
+    ],
+)
+def test_context_verification(nps, constructor_kw_args):
+    constructor, kw_args = constructor_kw_args
+
+    c = (B.randn(nps.dtype, 4, 1, 5), B.randn(nps.dtype, 4, 1, 5))
+    xt = B.randn(nps.dtype, 4, 1, 15)
+
+    # Test one context set.
+    model = getattr(nps, constructor)(dim_yc=1, dtype=nps.dtype, **kw_args)
+    pred1 = model(*c, xt)
+    pred2 = model([c], xt)
+    approx(pred1.mean, pred2.mean)
+    with pytest.raises(AssertionError, match="(?i)got inputs and outputs in parallel"):
+        model([c, c], xt)
+    with pytest.raises(AssertionError, match="(?i) got inputs in parallel"):
+        model(Parallel(c[0], c[0]), c[1], xt)
+    with pytest.raises(AssertionError, match="(?i) got outputs in parallel"):
+        model(c[0], Parallel(c[1], c[1]), xt)
+
+    # Test two context sets.
+    model = getattr(nps, constructor)(dim_yc=(1, 1), dtype=nps.dtype, **kw_args)
+    model([c, c], xt)
+    with pytest.raises(AssertionError, match="(?i)expected a parallel of elements"):
+        model([c], xt)
+    with pytest.raises(
+        AssertionError,
+        match="(?i)expected a parallel of 2 elements, but got 4 inputs and 4 outputs",
+    ):
+        model([c, c, c, c], xt)

tests/test_augment.py (+21)
@@ -1,5 +1,6 @@
 import lab as B
 import pytest
+from plum import NotFoundLookupError
 
 from .test_architectures import check_prediction
 from .util import nps  # noqa
@@ -37,3 +38,23 @@ def test_convgnp_auxiliary_variable(nps):
     )
 
     check_prediction(nps, pred, B.randn(nps.dtype, 16, 3, 15))
+
+    # Check that the model cannot be run forward without the auxiliary variable.
+    with pytest.raises(NotFoundLookupError):
+        model(
+            [observed_data, aux_var1, aux_var2],
+            B.randn(nps.dtype, 16, 2, 15),
+        )
+
+
+def test_convgnp_auxiliary_variable_given_but_not_specified(nps):
+    """Test that giving the auxiliary variable without specifying `dim_aux_t` raises
+    an error."""
+    model = nps.construct_convgnp(points_per_unit=4)
+    with pytest.raises(AssertionError, match="(?i)did not expect augmentation"):
+        model(
+            B.randn(nps.dtype, 4, 1, 15),
+            B.randn(nps.dtype, 4, 1, 15),
+            B.randn(nps.dtype, 4, 1, 10),
+            aux_t=B.randn(nps.dtype, 4, 2, 10),
+        )
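
For contrast with the failing call above: when a target-side auxiliary variable is wanted, it must be declared at construction time via `construct_convgnp`'s `dim_aux_t` argument. A hedged sketch mirroring the tensors used in the test:

    # Declaring `dim_aux_t` makes the decoder expect the augmentation, so this runs.
    model = nps.construct_convgnp(points_per_unit=4, dim_aux_t=2)
    pred = model(
        B.randn(nps.dtype, 4, 1, 15),        # context inputs
        B.randn(nps.dtype, 4, 1, 15),        # context outputs
        B.randn(nps.dtype, 4, 1, 10),        # target inputs
        aux_t=B.randn(nps.dtype, 4, 2, 10),  # auxiliary target information
    )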
