
Commit 35b51de

Merge pull request #14 from wesselb/struclik

Structured likelihoods

2 parents d8b16cc + 6a136b7

24 files changed: +846 -182 lines

Diff for: neuralprocesses/aggregate.py (+6 -1)

@@ -79,14 +79,19 @@ def f(*args: Aggregate, **kw_args):
             raise ValueError(f"Invalid number of arguments {num_args}.")
 
 
+_map_f("expand_dims", 1)
+_map_f("exp", 1)
+_map_f("one", 1)
+_map_f("zero", 1)
 _map_f("mean", 1)
+_map_f("sum", 1)
+_map_f("logsumexp", 1)
 
 _map_f("add", 2)
 _map_f("subtract", 2)
 _map_f("multiply", 2)
 _map_f("divide", 2)
 
-
 _map_f("stack", "*")
 _map_f("concat", "*")
 _map_f("squeeze", "*")

Diff for: neuralprocesses/architectures/agnp.py (+2 -1)

@@ -1,4 +1,5 @@
 import neuralprocesses as nps  # This fixes inspection below.
+
 from ..util import register_model
 
 __all__ = ["construct_agnp"]
@@ -34,7 +35,7 @@ def construct_agnp(*args, nps=nps, num_heads=8, **kw_args):
             low-rank likelihood. Defaults to 512.
         dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
         lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
-            `"het"` or `"dense"`. Defaults to `"het"`.
+            `"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
         transform (str or tuple[float, float]): Bijection applied to the
             output of the model. This can help deal with positive or bounded data.
             Must be either `"positive"`, `"exp"`, `"softplus"`, or
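
With this change, the AGNP's latent-variable likelihood can be set to the new option. A sketch of how one might select it; only `lv_likelihood` is taken from the docstring above, and the other arguments and their values are illustrative:

    import neuralprocesses.torch as nps

    # An AGNP with a 16-dimensional latent variable that uses the new
    # spikes-beta latent-variable likelihood.
    model = nps.construct_agnp(
        dim_x=1,
        dim_y=1,
        dim_lv=16,
        lv_likelihood="spikes-beta",
    )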

Diff for: neuralprocesses/architectures/convgnp.py (+4 -4)

@@ -1,10 +1,10 @@
 import lab as B
+import neuralprocesses as nps  # This fixes inspection below.
 import wbml.out as out
 from plum import convert
 
-import neuralprocesses as nps  # This fixes inspection below.
-from .util import construct_likelihood, parse_transform
 from ..util import register_model
+from .util import construct_likelihood, parse_transform
 
 __all__ = ["construct_convgnp"]
 
@@ -140,8 +140,8 @@ def construct_convgnp(
             Defaults to 64.
         margin (float, optional): Margin of the internal discretisation. Defaults to
             0.1.
-        likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank".
-            Defaults to `"lowrank"`.
+        likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
+            or `"spikes-beta"`. Defaults to `"lowrank"`.
         conv_arch (str, optional): Convolutional architecture to use. Must be one of
             `"unet[-res][-sep]"` or `"conv[-res][-sep]"`. Defaults to `"unet"`.
         unet_channels (tuple[int], optional): Channels of every layer of the UNet.
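
The same option is now available for the ConvGNP's output likelihood. A sketch of selecting it; arguments other than `likelihood` are illustrative, and the reading of "spikes-beta" as a Beta slab on (0, 1) mixed with point masses ("spikes") at 0 and 1 is an inference from the three log-probabilities reserved in `construct_likelihood` below:

    import neuralprocesses.torch as nps

    # A ConvGNP whose output likelihood is the new spikes-beta likelihood,
    # suited to data on [0, 1] with mass exactly at the endpoints.
    convgnp = nps.construct_convgnp(dim_x=1, dim_y=1, likelihood="spikes-beta")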

Diff for: neuralprocesses/architectures/gnp.py (+3 -3)

@@ -1,8 +1,8 @@
+import neuralprocesses as nps  # This fixes inspection below.
 from plum import convert
 
-import neuralprocesses as nps  # This fixes inspection below.
-from .util import construct_likelihood, parse_transform
 from ..util import register_model
+from .util import construct_likelihood, parse_transform
 
 __all__ = ["construct_gnp"]
 
@@ -58,7 +58,7 @@ def construct_gnp(
             low-rank likelihood. Defaults to 512.
         dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
         lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
-            `"het"` or `"dense"`. Defaults to `"het"`.
+            `"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
         transform (str or tuple[float, float]): Bijection applied to the
             output of the model. This can help deal with positive or bounded data.
             Must be either `"positive"`, `"exp"`, `"softplus"`, or

Diff for: neuralprocesses/architectures/util.py (+7 -2)

@@ -11,8 +11,9 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
 
     Args:
         nps (module): Appropriate backend-specific module.
-        spec (str, optional): Specification. Must be one of `"het"`, `"lowrank"`, or
-            `"dense"`. Defaults to `"lowrank"`. Must be given as a keyword argument.
+        spec (str, optional): Specification. Must be one of `"het"`, `"lowrank"`,
+            `"dense"`, or `"spikes-beta"`. Defaults to `"lowrank"`. Must be given as
+            a keyword argument.
         dim_y (int): Dimensionality of the outputs. Must be given as a keyword argument.
         num_basis_functions (int): Number of basis functions for the low-rank
             likelihood. Must be given as a keyword argument.
@@ -47,6 +48,10 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
             ),
             nps.DenseGaussianLikelihood(),
         )
+    elif spec == "spikes-beta":
+        num_channels = (2 + 3) * dim_y  # Alpha, beta, and three log-probabilities
+        selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y, dim_y)
+        lik = nps.SpikesBetaLikelihood()
 
    else:
        raise ValueError(f'Incorrect likelihood specification "{spec}".')
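
The new branch reserves `(2 + 3) * dim_y` channels: one block of `dim_y` channels each for `alpha`, `beta`, and three log-probabilities (presumably the spike at 0, the spike at 1, and the Beta slab). A sketch of the channel layout that `SelectFromChannels(dim_y, dim_y, dim_y, dim_y, dim_y)` is assumed to slice, using NumPy for illustration:

    import numpy as np

    dim_y = 2
    num_channels = (2 + 3) * dim_y  # Alpha, beta, and three log-probabilities.

    # Illustrative decoder output of shape (batch, channels, data points),
    # with the five blocks of `dim_y` channels concatenated.
    z = np.random.randn(8, num_channels, 16)
    alpha, beta, logp0, logp1, logp_slab = np.split(z, 5, axis=1)
    assert alpha.shape == (8, dim_y, 16)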

Diff for: neuralprocesses/coding.py (+4 -4)

@@ -2,7 +2,7 @@
 from plum import isinstance, issubclass
 
 from . import _dispatch
-from .dist import AbstractMultiOutputDistribution, Dirac
+from .dist import AbstractDistribution, Dirac
 from .parallel import Parallel
 from .util import is_composite_coder
 
@@ -70,7 +70,7 @@ def code_track(coder, xz, z, x, h, **kw_args):
     if is_composite_coder(coder):
         raise RuntimeError(
             f"Dispatched to fallback implementation of `code_track` for "
-            f"`{ptype(type(coder))}`, but the coder is composite."
+            f"`{type(coder)}`, but the coder is composite."
         )
     xz, z = code(coder, xz, z, x, **kw_args)
     return xz, z, h + [x]
@@ -95,7 +95,7 @@ def recode(coder, xz, z, h, **kw_args):
     if is_composite_coder(coder):
         raise RuntimeError(
             f"Dispatched to fallback implementation of `recode` for "
-            f"`{ptype(type(coder))}`, but the coder is composite."
+            f"`{type(coder)}`, but the coder is composite."
         )
     xz, z = code(coder, xz, z, h[0], **kw_args)
     return xz, z, h[1:]
@@ -151,6 +151,6 @@ def _choose(new: Dirac, old: Dirac):
 
 
 @_dispatch
-def _choose(new: AbstractMultiOutputDistribution, old: AbstractMultiOutputDistribution):
+def _choose(new: AbstractDistribution, old: AbstractDistribution):
     # Do recode other distributions.
     return new
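
The `_choose` pair encodes a precedence rule for recoding: the general case above recomputes ("do recode other distributions"), while the `Dirac`-versus-`Dirac` method, whose body is not shown in this diff, presumably reuses the cached value, since a `Dirac` is deterministic. A standalone sketch of this dispatch pattern with `plum`; the classes are stand-ins and the `Dirac`-case behaviour is an assumption based on that comment:

    from plum import dispatch

    class Dist:
        pass

    class Dirac(Dist):
        pass

    @dispatch
    def choose(new: Dirac, old: Dirac):
        # Deterministic, so presumably the cached value can be reused.
        return old

    @dispatch
    def choose(new: Dist, old: Dist):
        # Do recode other distributions.
        return new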

Diff for: neuralprocesses/data/bimodal.py (+65, new file)

@@ -0,0 +1,65 @@
+import lab as B
+import numpy as np
+from lab.shape import Dimension
+
+from neuralprocesses.data import SyntheticGenerator, new_batch
+
+__all__ = ["BiModalGenerator"]
+
+
+class BiModalGenerator(SyntheticGenerator):
+    """Bi-modal distribution generator.
+
+    Further takes in arguments and keyword arguments from the constructor of
+    :class:`.data.SyntheticGenerator`. Moreover, also has the attributes of
+    :class:`.data.SyntheticGenerator`.
+    """
+
+    def __init__(self, *args, **kw_args):
+        super().__init__(*args, **kw_args)
+
+    def generate_batch(self):
+        with B.on_device(self.device):
+            set_batch, xcs, xc, nc, xts, xt, nt = new_batch(self, self.dim_y)
+            x = B.concat(xc, xt, axis=1)
+
+            # Draw a different random phase, amplitude, and period for every task in
+            # the batch.
+            self.state, rand = B.rand(
+                self.state,
+                self.float64,
+                3,
+                self.batch_size,
+                1,  # Broadcast over `n`.
+                1,  # There is only one input dimension.
+            )
+            phase = 2 * B.pi * rand[0]
+            amplitude = 1 + rand[1]
+            period = 1 + rand[2]
+
+            # Construct the noiseless function.
+            f = amplitude * B.sin(phase + (2 * B.pi / period) * x)
+
+            # Add noise with variance.
+            probs = B.cast(self.float64, np.array([0.5, 0.5]))
+            means = B.cast(self.float64, np.array([-0.1, 0.1]))
+            variance = 1
+            # Randomly choose from `means` with probabilities `probs`.
+            self.state, mean = B.choice(self.state, means, self.batch_size, p=probs)
+            self.state, randn = B.randn(
+                self.state,
+                self.float64,
+                self.batch_size,
+                # `nc` and `nt` are tensors rather than plain integers. Tell dispatch
+                # that they can be interpreted as dimensions of a shape.
+                Dimension(nc + nt),
+                1,
+            )
+            noise = B.sqrt(variance) * randn + mean[:, None, None]
+
+            # Construct the noisy function.
+            y = f + noise
+
+            batch = {}
+            set_batch(batch, y[:, :nc], y[:, nc:])
+            return batch
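
A usage sketch: the tasks are noisy sine waves whose observation noise is a mixture of two unit-variance Gaussians centred at -0.1 and 0.1. The constructor arguments follow `SyntheticGenerator` (the dtype-first calling convention and keyword names are assumptions), and the batch keys are assumed from the usual `new_batch` convention:

    import torch

    from neuralprocesses.data.bimodal import BiModalGenerator

    gen = BiModalGenerator(torch.float32, batch_size=16)
    batch = gen.generate_batch()
    # Keys assumed: "xc"/"yc" for the context set, "xt"/"yt" for the targets.
    print(batch["xt"].shape, batch["yt"].shape)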

Diff for: neuralprocesses/dist/__init__.py (+2)

@@ -1,6 +1,8 @@
+from .beta import *
 from .dirac import *
 from .dist import *
 from .geom import *
 from .normal import *
+from .spikeslab import *
 from .transformed import *
 from .uniform import *

Diff for: neuralprocesses/dist/beta.py (+119, new file)

@@ -0,0 +1,119 @@
+import lab as B
+from matrix.shape import broadcast
+from plum import parametric
+
+from .. import _dispatch
+from ..aggregate import Aggregate
+from ..mask import Masked
+from .dist import AbstractDistribution, shape_batch
+
+__all__ = ["Beta"]
+
+
+@parametric
+class Beta(AbstractDistribution):
+    """Beta distribution.
+
+    Args:
+        alpha (tensor): Shape parameter `alpha`.
+        beta (tensor): Shape parameter `beta`.
+        d (int): Dimensionality of the data.
+
+    Attributes:
+        alpha (tensor): Shape parameter `alpha`.
+        beta (tensor): Shape parameter `beta`.
+        d (int): Dimensionality of the data.
+    """
+
+    def __init__(self, alpha, beta, d):
+        self.alpha = alpha
+        self.beta = beta
+        self.d = d
+
+    @property
+    def mean(self):
+        return B.divide(self.alpha, B.add(self.alpha, self.beta))
+
+    @property
+    def var(self):
+        sum = B.add(self.alpha, self.beta)
+        with B.on_device(sum):
+            one = B.one(sum)
+        return B.divide(
+            B.multiply(self.alpha, self.beta),
+            B.multiply(B.multiply(sum, sum), B.add(sum, one)),
+        )
+
+    @_dispatch
+    def sample(
+        self: "Beta[Aggregate, Aggregate, Aggregate]",
+        state: B.RandomState,
+        dtype: B.DType,
+        *shape,
+    ):
+        samples = []
+        for ai, bi, di in zip(self.alpha, self.beta, self.d):
+            state, sample = Beta(ai, bi, di).sample(state, dtype, *shape)
+            samples.append(sample)
+        return state, Aggregate(*samples)
+
+    @_dispatch
+    def sample(
+        self: "Beta[B.Numeric, B.Numeric, B.Int]",
+        state: B.RandomState,
+        dtype: B.DType,
+        *shape,
+    ):
+        return B.randbeta(state, dtype, *shape, alpha=self.alpha, beta=self.beta)
+
+    @_dispatch
+    def logpdf(self: "Beta[Aggregate, Aggregate, Aggregate]", x: Aggregate):
+        return sum(
+            [
+                Beta(ai, bi, di).logpdf(xi)
+                for ai, bi, di, xi in zip(self.alpha, self.beta, self.d, x)
+            ],
+            0,
+        )
+
+    @_dispatch
+    def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: Masked):
+        x, mask = x.y, x.mask
+        with B.on_device(self.alpha):
+            safe = B.to_active_device(B.cast(B.dtype(self.alpha), 0.5))
+        # Make inputs safe.
+        x = mask * x + (1 - mask) * safe
+        # Run with safe inputs, and filter out the right logpdfs.
+        return self.logpdf(x, mask=mask)
+
+    @_dispatch
+    def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1):
+        logz = B.logbeta(self.alpha, self.beta)
+        logpdf = (self.alpha - 1) * B.log(x) + (self.beta - 1) * B.log(1 - x) - logz
+        return B.sum(mask * logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :])
+
+    def __str__(self):
+        return f"Beta({self.alpha}, {self.beta})"
+
+    def __repr__(self):
+        return f"Beta({self.alpha!r}, {self.beta!r})"
+
+
+@B.dtype.dispatch
+def dtype(dist: Beta):
+    return B.dtype(dist.alpha, dist.beta)
+
+
+@shape_batch.dispatch
+def shape_batch(dist: "Beta[B.Numeric, B.Numeric, B.Int]"):
+    return B.shape_broadcast(dist.alpha, dist.beta)[: -dist.d]
+
+
+@shape_batch.dispatch
+def shape_batch(dist: "Beta[Aggregate, Aggregate, Aggregate]"):
+    return broadcast(
+        *(
+            shape_batch(Beta(ai, bi, di))
+            for ai, bi, di in zip(dist.alpha, dist.beta, dist.d)
+        )
+    )
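
As a sanity check on the scalar case, the `logpdf` above, which computes (alpha - 1) log x + (beta - 1) log(1 - x) - log B(alpha, beta) and sums over the last `d` axes, should agree with the standard Beta density. A sketch comparing it against SciPy under the NumPy backend, assuming a version of `lab` that provides `B.logbeta` as used above:

    import numpy as np
    from scipy.stats import beta as scipy_beta

    from neuralprocesses.dist.beta import Beta

    alpha, b = np.float64(2.0), np.float64(5.0)
    x = np.array([0.3, 0.7])

    dist = Beta(alpha, b, 1)  # `d=1`: sum the log-density over the last axis.
    print(np.allclose(dist.logpdf(x), scipy_beta.logpdf(x, alpha, b).sum()))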
