|
1 |
| -# from __future__ import division |
2 |
| -# import numpy as np |
3 |
| -# import numpy.random as npr |
4 |
| -# import matplotlib.pyplot as plt |
5 |
| -# |
6 |
| -# from pybasicbayes.distributions import Regression |
7 |
| -# from pybasicbayes.util.text import progprint_xrange |
8 |
| -# |
9 |
| -# from pylds.models import LDS |
10 |
| -# |
11 |
| -# npr.seed(0) |
12 |
| -# |
13 |
| -# |
14 |
| -# ######################### |
15 |
| -# # set some parameters # |
16 |
| -# ######################### |
17 |
| -# D_obs = 1 |
18 |
| -# D_latent = 2 |
19 |
| -# D_input = 0 |
20 |
| -# T = 2000 |
21 |
| -# |
22 |
| -# mu_init = np.array([0.,1.]) |
23 |
| -# sigma_init = 0.01*np.eye(2) |
24 |
| -# |
25 |
| -# A = 0.99*np.array([[np.cos(np.pi/24), -np.sin(np.pi/24)], |
26 |
| -# [np.sin(np.pi/24), np.cos(np.pi/24)]]) |
27 |
| -# B = np.ones((D_latent, D_input)) |
28 |
| -# sigma_states = 0.01*np.eye(2) |
29 |
| -# |
30 |
| -# C = np.array([[10.,0.]]) |
31 |
| -# D = np.zeros((D_obs, D_input)) |
32 |
| -# sigma_obs = 0.01*np.eye(1) |
33 |
| -# |
34 |
| -# ################### |
35 |
| -# # generate data # |
36 |
| -# ################### |
37 |
| -# |
38 |
| -# truemodel = LDS( |
39 |
| -# dynamics_distn=Regression(A=np.hstack((A,B)), sigma=sigma_states), |
40 |
| -# emission_distn=Regression(A=np.hstack((C,D)), sigma=sigma_obs)) |
41 |
| -# |
42 |
| -# inputs = np.random.randn(T, D_input) |
43 |
| -# # inputs = np.zeros((T, D_input)) |
44 |
| -# data, stateseq = truemodel.generate(T, inputs=inputs) |
45 |
| -# |
46 |
| -# |
47 |
| -# ############### |
48 |
| -# # make model # |
49 |
| -# ############### |
50 |
| -# model = LDS( |
51 |
| -# dynamics_distn=Regression(nu_0=D_latent + 2, |
52 |
| -# S_0=D_latent * np.eye(D_latent), |
53 |
| -# M_0=np.zeros((D_latent, D_latent + D_input)), |
54 |
| -# K_0=(D_latent + D_input) * np.eye(D_latent + D_input), |
55 |
| -# # A=np.hstack((A,B)), sigma=sigma_states |
56 |
| -# ), |
57 |
| -# emission_distn=Regression(nu_0=D_obs + 2, |
58 |
| -# S_0=D_obs * np.eye(D_obs), |
59 |
| -# M_0=np.zeros((D_obs, D_latent + D_input)), |
60 |
| -# K_0=(D_latent + D_input) * np.eye(D_latent + D_input), |
61 |
| -# # A=np.hstack((C,D)), sigma=100*sigma_obs |
62 |
| -# ) |
63 |
| -# ) |
64 |
| -# model.add_data(data, inputs=inputs) |
65 |
| -# # model.emission_distn._initialize_mean_field() |
66 |
| -# # model.dynamics_distn._initialize_mean_field() |
67 |
| -# |
68 |
| -# ############### |
69 |
| -# # fit model # |
70 |
| -# ############### |
71 |
| -# def update(model): |
72 |
| -# return model.meanfield_coordinate_descent_step() |
73 |
| -# |
74 |
| -# for _ in progprint_xrange(100): |
75 |
| -# model.resample_model() |
76 |
| -# |
77 |
| -# N_steps = 100 |
78 |
| -# vlbs = [update(model) for _ in progprint_xrange(N_steps)] |
79 |
| -# model.resample_from_mf() |
80 |
| -# |
81 |
| -# plt.figure(figsize=(3,4)) |
82 |
| -# plt.plot([0, N_steps], truemodel.log_likelihood()*np.ones(2), '--k') |
83 |
| -# plt.plot(vlbs) |
84 |
| -# plt.xlabel('iteration') |
85 |
| -# plt.ylabel('variational lower bound') |
86 |
| -# plt.show() |
87 |
| -# |
88 |
| -# ################ |
89 |
| -# # smoothing # |
90 |
| -# ################ |
91 |
| -# smoothed_obs = model.states_list[0].meanfield_smooth() |
92 |
| -# |
93 |
| -# ################ |
94 |
| -# # predicting # |
95 |
| -# ################ |
96 |
| -# Nseed = 1700 |
97 |
| -# Npredict = 100 |
98 |
| -# prediction_seed = data[:Nseed] |
99 |
| -# |
100 |
| -# model.emission_distn.resample_from_mf() |
101 |
| -# predictions = model.sample_predictions(prediction_seed, Npredict) |
102 |
| -# |
103 |
| -# plt.figure() |
104 |
| -# plt.plot(data, 'k') |
105 |
| -# plt.plot(smoothed_obs[:Nseed], ':k') |
106 |
| -# plt.plot(Nseed + np.arange(Npredict), predictions, 'b') |
107 |
| -# plt.xlabel('time index') |
108 |
| -# plt.ylabel('prediction') |
109 |
| -# |
110 |
| -# plt.show() |
111 |
| - |
112 | 1 | from __future__ import division
|
113 | 2 | import numpy as np
|
114 | 3 | import numpy.random as npr
|
@@ -192,4 +81,3 @@ def update(model):
|
192 | 81 | plt.legend()
|
193 | 82 |
|
194 | 83 | plt.show()
|
195 |
| - |
0 commit comments