Commit a6fd326

adding bernoulli model, support for inputs, and some test code with scipy solve banded functions
1 parent 447ab47 commit a6fd326

8 files changed, +638 -159 lines

examples/bernoulli_lds.py (+32 -16)
@@ -6,9 +6,9 @@
 from pybasicbayes.distributions import Regression
 from pybasicbayes.util.text import progprint_xrange
 from pypolyagamma.distributions import BernoulliRegression
-from pylds.models import CountLDS
+from pylds.models import CountLDS, DefaultBernoulliLDS

-npr.seed(0)
+npr.seed(1)

 # Parameters
 D_obs = 10
@@ -22,7 +22,7 @@

 A = 0.99*np.array([[np.cos(np.pi/24), -np.sin(np.pi/24)],
                    [np.sin(np.pi/24), np.cos(np.pi/24)]])
-B = np.ones((D_latent, D_input))
+B = np.zeros((D_latent, D_input))
 sigma_states = 0.01*np.eye(2)

 C = np.random.randn(D_obs, D_latent)
@@ -45,26 +45,43 @@
         M_0=np.zeros((D_latent, D_latent + D_input)),
         K_0=(D_latent + D_input) * np.eye(D_latent + D_input)),
     emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input))
-model.add_data(data, inputs=inputs)
+model.add_data(data, inputs=inputs, stateseq=np.zeros((T, D_latent)))

-# Run a Gibbs sampler
-N_samples = 500
+# Run a Gibbs sampler with Polya-gamma augmentation
+N_samples = 50
 def gibbs_update(model):
     model.resample_model()
     smoothed_obs = model.states_list[0].smooth()
-    return model.log_likelihood(), \
-        model.states_list[0].gaussian_states, \
-        smoothed_obs
+    ll = model.log_likelihood()
+    return ll, model.states_list[0].gaussian_states, smoothed_obs

-lls, z_smpls, smoothed_obss = \
+lls_gibbs, x_smpls_gibbs, y_smooth_gibbs = \
     zip(*[gibbs_update(model) for _ in progprint_xrange(N_samples)])

+# Fit with a Bernoulli LDS using Laplace approximation for comparison
+model = DefaultBernoulliLDS(D_obs, D_latent, D_input=D_input,
+                            C=0.01 * np.random.randn(D_obs, D_latent),
+                            D=0.01 * np.random.randn(D_obs, D_input))
+model.add_data(data, inputs=inputs, stateseq=np.zeros((T, D_latent)))
+
+N_iters = 50
+def em_update(model):
+    model.EM_step(verbose=True)
+    smoothed_obs = model.states_list[0].smooth()
+    ll = model.log_likelihood()
+    return ll, model.states_list[0].gaussian_states, smoothed_obs
+
+lls_em, x_smpls_em, y_smooth_em = \
+    zip(*[em_update(model) for _ in progprint_xrange(N_iters)])
+
 # Plot the log likelihood over iterations
 plt.figure(figsize=(10,6))
-plt.plot(lls,'-b')
-plt.plot([0,N_samples], truemodel.log_likelihood() * np.ones(2), '-k')
+plt.plot(lls_gibbs, label="gibbs")
+plt.plot(lls_em, label="em")
+plt.plot([0,N_samples], truemodel.log_likelihood() * np.ones(2), '-k', label="true")
 plt.xlabel('iteration')
 plt.ylabel('log likelihood')
+plt.legend(loc="lower right")

 # Plot the smoothed observations
 fig = plt.figure(figsize=(10,10))
@@ -80,9 +97,9 @@ def gibbs_update(model):
     given_ts = np.where(data[:,j]==1)[0]
     ax.plot(given_ts, np.ones_like(given_ts), 'ko', markersize=5)

-    # Plot the inferred rate
-    ax.plot([0], [0], 'b', lw=2, label="smoothed obs.")
-    ax.plot(smoothed_obss[-1][:,j], 'r', lw=2, label="smoothed pr.")
+    ax.plot([0], [0], 'ko', lw=2, label="data")
+    ax.plot(y_smooth_gibbs[-1][:, j], lw=2, label="gibbs probs")
+    ax.plot(y_smooth_em[-1][:, j], lw=2, label="em probs")

     if i == 0:
         plt.legend(loc="upper center", ncol=4, bbox_to_anchor=(0.5, 2.))
@@ -93,4 +110,3 @@ def gibbs_update(model):
     ax.set_ylabel("$x_%d(t)$" % (j+1))

 plt.show()
-
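For context, here is a minimal standalone sketch (not the pylds API; the dimensions, the seed, and the bias-free emission sigma(C x_t + D u_t) are assumptions of this sketch) of the generative model the example above simulates and then fits twice: once by Polya-gamma-augmented Gibbs on a CountLDS with BernoulliRegression emissions, and once by EM on a DefaultBernoulliLDS.

import numpy as np

npr = np.random
npr.seed(1)

T, D_obs, D_latent, D_input = 200, 10, 2, 1

# Stable rotation for the latent dynamics, as in the example above
theta = np.pi / 24
A = 0.99 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
B = np.zeros((D_latent, D_input))      # input-to-state weights
C = npr.randn(D_obs, D_latent)         # emission weights
D = np.zeros((D_obs, D_input))         # input-to-observation weights (assumed)
sigma_states = 0.01 * np.eye(D_latent)

inputs = npr.randn(T, D_input)
x = np.zeros((T, D_latent))
y = np.zeros((T, D_obs), dtype=int)
x[0] = npr.randn(D_latent)
for t in range(T):
    # Bernoulli emission through a logistic link: y_t ~ Bern(sigmoid(C x_t + D u_t))
    psi = C.dot(x[t]) + D.dot(inputs[t])
    y[t] = npr.rand(D_obs) < 1.0 / (1.0 + np.exp(-psi))
    # Linear-Gaussian latent dynamics with inputs: x_{t+1} = A x_t + B u_t + noise
    if t + 1 < T:
        x[t + 1] = A.dot(x[t]) + B.dot(inputs[t]) + \
            npr.multivariate_normal(np.zeros(D_latent), sigma_states)
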
pylds/distributions.py (+165 -27)
@@ -1,21 +1,23 @@
 import autograd.numpy as np
-from autograd import value_and_grad, hessian_vector_product
+from autograd import value_and_grad
 from autograd.scipy.special import gammaln

 from scipy.optimize import minimize

 from pybasicbayes.distributions import Regression
+from pybasicbayes.util.text import progprint_xrange
+

 class PoissonRegression(Regression):
     """
     Poisson regression with Gaussian distributed inputs and exp link:

-       y ~ Poisson(exp(Ax + b))
+       y ~ Poisson(exp(Ax))

     where x ~ N(mu, sigma)

     Currently, we only support maximum likelihood estimation of the
-    parameters, A and b, given the distribution over inputs, x, and
+    parameters A given the distribution over inputs, x, and
     the observed outputs, y.

     We compute the expected log likelihood in closed form (since
@@ -35,10 +37,18 @@ def __init__(self, D_out, D_in, A=None, verbose=False):

         self.sigma = None

+    @property
+    def D_in(self):
+        return self._D_in
+
+    @property
+    def D_out(self):
+        return self._D_out
+
     def log_likelihood(self,xy):
         assert isinstance(xy, tuple)
         x, y = xy
-        loglmbda = x.dot(self.A.T) + self.b.T
+        loglmbda = x.dot(self.A.T)
         lmbda = np.exp(loglmbda)
         return -gammaln(y+1) - lmbda + y * loglmbda

@@ -68,14 +78,6 @@ def expected_log_likelihood(self, mus, sigmas, y):

         return ll

-    @property
-    def D_in(self):
-        return self._D_in
-
-    @property
-    def D_out(self):
-        return self._D_out
-
     def predict(self, x):
         return np.exp(x.dot(self.A.T))

@@ -97,27 +99,30 @@ def max_likelihood(self, data, weights=None,stats=None):
     def max_expected_likelihood(self, stats, verbose=False):
         # These aren't really "sufficient" statistics, since we
         # need the mean and covariance for each time bin.
-        EyxT = np.sum([s[0] for s in stats], axis=0)
+        EyxuT = np.sum([s[0] for s in stats], axis=0)
         mus = np.vstack([s[1] for s in stats])
-        sigs = np.vstack([s[2] for s in stats])
-        masks = np.vstack(s[3] for s in stats)
+        sigmas = np.vstack([s[2] for s in stats])
+        inputs = np.vstack([s[3] for s in stats])
+        masks = np.vstack(s[4] for s in stats)
         T = mus.shape[0]
-        D = self.D_in
+
+        D_latent = mus.shape[1]
+        sigmas_vec = sigmas.reshape((T, D_latent**2))

         # Optimize each row of A independently
-        for n in range(self.D_out):
+        ns = progprint_xrange(self.D_out) if verbose else range(self.D_out)
+        for n in ns:

             # Flatten the covariance to enable vectorized calculations
-            sigs_vec = sigs.reshape((T,D**2))
             def ll_vec(an):

                 ll = 0
-                ll += np.dot(an, EyxT[n])
+                ll += np.dot(an, EyxuT[n])

                 # Vectorized log likelihood calculation
                 loglmbda = np.dot(mus, an)
-                aa_vec = np.outer(an, an).reshape((D ** 2,))
-                trms = np.exp(loglmbda + 0.5 * np.dot(sigs_vec, aa_vec))
+                aa_vec = np.outer(an[:D_latent], an[:D_latent]).reshape((D_latent ** 2,))
+                trms = np.exp(loglmbda + 0.5 * np.dot(sigmas_vec, aa_vec))
                 ll -= np.sum(trms[masks[:, n]])

                 if not np.isfinite(ll):
@@ -134,11 +139,144 @@ def cbk(x):
             res = minimize(value_and_grad(obj), self.A[n],
                            jac=True,
                            callback=cbk if verbose else None)
-            # res = minimize(value_and_grad(obj), self.A[n],
-            #                tol=1e-3,
-            #                method="Newton-CG",
-            #                jac=True,
-            #                hessp=hessian_vector_product(obj),
-            #                callback=cbk if verbose else None)
             assert res.success
             self.A[n] = res.x
+
+
+class BernoulliRegression(Regression):
+    """
+    Bernoulli regression with Gaussian distributed inputs and logistic link:
+
+       y ~ Bernoulli(logistic(Ax))
+
+    where x ~ N(mu, sigma)
+
+    Currently, we only support maximum likelihood estimation of the
+    parameter A given the distribution over inputs, x, and
+    the observed outputs, y.
+
+    We approximate the expected log likelihood with Monte Carlo.
+    """
+
+    def __init__(self, D_out, D_in, A=None, verbose=False):
+        self._D_out, self._D_in = D_out, D_in
+        self.verbose = verbose
+
+        if A is not None:
+            assert A.shape == (D_out, D_in)
+            self.A = A.copy()
+        else:
+            self.A = 0.01 * np.random.randn(D_out, D_in)
+
+        self.sigma = None
+
+    @property
+    def D_in(self):
+        return self._D_in
+
+    @property
+    def D_out(self):
+        return self._D_out
+
+    def log_likelihood(self,xy):
+        assert isinstance(xy, tuple)
+        x, y = xy
+        psi = x.dot(self.A.T)
+
+        # First term is linear
+        ll = y * psi
+
+        # Compute second term with log-sum-exp trick (see above)
+        logm = np.maximum(0, psi)
+        ll -= np.sum(logm)
+        ll -= np.sum(np.log(np.exp(-logm) + np.exp(psi - logm)))
+
+        return ll
+
+    def predict(self, x):
+        return 1 / (1 + np.exp(-x.dot(self.A.T)))
+
+    def rvs(self, x=None, size=1, return_xy=True):
+        x = np.random.normal(size=(size, self.D_in)) if x is None else x
+        y = np.random.rand(x.shape[0], self.D_out) < self.predict(x)
+        return np.hstack((x, y)) if return_xy else y
+
+    def max_likelihood(self, data, weights=None, stats=None):
+        """
+        Maximize the likelihood for given data
+        :param data:
+        :param weights:
+        :param stats:
+        :return:
+        """
+        if isinstance(data, list):
+            x = np.vstack([d[0] for d in data])
+            y = np.vstack([d[1] for d in data])
+        elif isinstance(data, tuple):
+            assert len(data) == 2
+        elif isinstance(data, np.ndarray):
+            x, y = data[:,:self.D_in], data[:, self.D_in:]
+        else:
+            raise Exception("Invalid data type")
+
+        from sklearn.linear_model import LogisticRegression
+        for n in progprint_xrange(self.D_out):
+            lr = LogisticRegression(fit_intercept=False)
+            lr.fit(x, y[:,n])
+            self.A[n] = lr.coef_
+
+
+    def max_expected_likelihood(self, stats, verbose=False, n_smpls=1):
+
+        # These aren't really "sufficient" statistics, since we
+        # need the mean and covariance for each time bin.
+        EyxuT = np.sum([s[0] for s in stats], axis=0)
+        mus = np.vstack([s[1] for s in stats])
+        sigmas = np.vstack([s[2] for s in stats])
+        inputs = np.vstack([s[3] for s in stats])
+        T = mus.shape[0]
+
+        D_latent = mus.shape[1]
+
+        # Draw Monte Carlo samples of x
+        sigmas_chol = np.linalg.cholesky(sigmas)
+        x_smpls = mus[:, :, None] + np.matmul(sigmas_chol, np.random.randn(T, D_latent, n_smpls))
+
+        # Optimize each row of A independently
+        ns = progprint_xrange(self.D_out) if verbose else range(self.D_out)
+        for n in ns:
+
+            def ll_vec(an):
+                ll = 0
+
+                # todo include mask
+                # First term is linear in psi
+                ll += np.dot(an, EyxuT[n])
+
+                # Second term depends only on x and cannot be computed in closed form
+                # Instead, Monte Carlo sample x
+                psi_smpls = np.einsum('tdm, d -> tm', x_smpls, an[:D_latent])
+                psi_smpls = psi_smpls + np.dot(inputs, an[D_latent:])[:, None]
+                logm = np.maximum(0, psi_smpls)
+                trm2_smpls = logm + np.log(np.exp(-logm) + np.exp(psi_smpls - logm))
+                ll -= np.sum(trm2_smpls) / n_smpls
+
+                if not np.isfinite(ll):
+                    return -np.inf
+
+                return ll / T
+
+            obj = lambda x: -ll_vec(x)
+
+            itr = [0]
+            def cbk(x):
+                itr[0] += 1
+                print("M_step iteration ", itr[0])
+
+            res = minimize(value_and_grad(obj), self.A[n],
+                           jac=True,
+                           # callback=cbk if verbose else None)
+                           callback=None)
+            assert res.success
+            self.A[n] = res.x
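
The vectorized objective in PoissonRegression.max_expected_likelihood relies on the Gaussian moment-generating function: for x ~ N(mu, Sigma), E[exp(a^T x)] = exp(a^T mu + 0.5 a^T Sigma a), which is exactly the `trms` expression above. A small standalone sanity check of that identity (the numbers and names here are illustrative, not from the commit):

import numpy as np

rng = np.random.RandomState(0)
a = np.array([0.5, -0.3])
mu = np.array([0.2, 0.1])
Sigma = np.array([[0.3, 0.1],
                  [0.1, 0.2]])

# Closed form: exp(a^T mu + 0.5 a^T Sigma a)
closed_form = np.exp(a.dot(mu) + 0.5 * a.dot(Sigma).dot(a))

# Monte Carlo estimate of E[exp(a^T x)] with x ~ N(mu, Sigma)
x = rng.multivariate_normal(mu, Sigma, size=500000)
monte_carlo = np.exp(x.dot(a)).mean()

print(closed_form, monte_carlo)   # the two should agree closely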
282+

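BernoulliRegression writes the Bernoulli log likelihood as y * psi - log(1 + exp(psi)) and evaluates the second term with the log-sum-exp trick, log(1 + exp(psi)) = m + log(exp(-m) + exp(psi - m)) with m = max(0, psi). In max_expected_likelihood that same term has no closed-form Gaussian expectation, so x is sampled via a Cholesky factor of each covariance and the term is averaged over samples. A minimal sketch of both ideas, assuming standalone numpy (the function and variable names are mine):

import numpy as np

def log1pexp_stable(psi):
    # log(1 + exp(psi)) = m + log(exp(-m) + exp(psi - m)),  m = max(0, psi),
    # which never exponentiates a large positive number
    m = np.maximum(0, psi)
    return m + np.log(np.exp(-m) + np.exp(psi - m))

psi = np.array([-800.0, -1.0, 0.0, 1.0, 800.0])
print(log1pexp_stable(psi))    # finite for all inputs
print(np.log1p(np.exp(psi)))   # the naive version overflows to inf at psi = 800

# Monte Carlo samples of x ~ N(mu, Sigma) via the Cholesky factor, as in
# max_expected_likelihood: x = mu + L eps with Sigma = L L^T
T, D, n_smpls = 4, 2, 3
mus = np.zeros((T, D))
sigmas = np.tile(0.1 * np.eye(D), (T, 1, 1))
L = np.linalg.cholesky(sigmas)                                    # shape (T, D, D)
x_smpls = mus[:, :, None] + np.matmul(L, np.random.randn(T, D, n_smpls))
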
0 commit comments
