|
4 | 4 | import numpy as np
|
5 | 5 | import pytest
|
6 | 6 | from neuralprocesses.aggregate import Aggregate, AggregateInput
|
| 7 | +from neuralprocesses.parallel import Parallel |
7 | 8 | from plum import isinstance
|
8 | 9 |
|
9 | 10 | from .util import approx, generate_data
|
@@ -420,3 +421,42 @@ def test_data_eq(nps, dim_x, dim_y, constructor, config, dim_lv):
|
420 | 421 | batch = gen.generate_batch()
|
421 | 422 | pred = model(batch["contexts"], batch["xt"])
|
422 | 423 | check_prediction(nps, pred, batch["yt"])
|
| 424 | + |
| 425 | + |
| 426 | +@pytest.mark.parametrize( |
| 427 | + "constructor_kw_args", |
| 428 | + [ |
| 429 | + ("construct_gnp", {}), |
| 430 | + ("construct_agnp", {}), |
| 431 | + ("construct_convgnp", {"points_per_unit": 4}), |
| 432 | + ("construct_fullconvgnp", {"points_per_unit": 4}), |
| 433 | + ], |
| 434 | +) |
| 435 | +def test_context_verification(nps, constructor_kw_args): |
| 436 | + constructor, kw_args = constructor_kw_args |
| 437 | + |
| 438 | + c = (B.randn(nps.dtype, 4, 1, 5), B.randn(nps.dtype, 4, 1, 5)) |
| 439 | + xt = B.randn(nps.dtype, 4, 1, 15) |
| 440 | + |
| 441 | + # Test one context set. |
| 442 | + model = getattr(nps, constructor)(dim_yc=1, dtype=nps.dtype, **kw_args) |
| 443 | + pred1 = model(*c, xt) |
| 444 | + pred2 = model([c], xt) |
| 445 | + approx(pred1.mean, pred2.mean) |
| 446 | + with pytest.raises(AssertionError, match="(?i)got inputs and outputs in parallel"): |
| 447 | + model([c, c], xt) |
| 448 | + with pytest.raises(AssertionError, match="(?i) got inputs in parallel"): |
| 449 | + model(Parallel(c[0], c[0]), c[1], xt) |
| 450 | + with pytest.raises(AssertionError, match="(?i) got outputs in parallel"): |
| 451 | + model(c[0], Parallel(c[1], c[1]), xt) |
| 452 | + |
| 453 | + # Test two context sets. |
| 454 | + model = getattr(nps, constructor)(dim_yc=(1, 1), dtype=nps.dtype, **kw_args) |
| 455 | + model([c, c], xt) |
| 456 | + with pytest.raises(AssertionError, match="(?i)expected a parallel of elements"): |
| 457 | + model([c], xt) |
| 458 | + with pytest.raises( |
| 459 | + AssertionError, |
| 460 | + match="(?i)expected a parallel of 2 elements, but got 4 inputs and 4 outputs", |
| 461 | + ): |
| 462 | + model([c, c, c, c], xt) |
0 commit comments