Skip to content

Commit

Permalink
Added SGHMC
Browse files Browse the repository at this point in the history
Added SGHMC LV/Spiral examples
  • Loading branch information
karen-sy committed Jan 19, 2021
1 parent e7a0bad commit 0ad23a3
Show file tree
Hide file tree
Showing 2 changed files with 307 additions and 0 deletions.
143 changes: 143 additions & 0 deletions BayesiaNODE_SGHMC_LV.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using Distributed
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Turing, Serialization, Plots
using JLD


u0 = Float32[1., 1.];
p = [1.5, 1., 3., 1.];
tspan = (0.0f0, 4.5f0);
tsteps = tspan[1]:0.1:tspan[2];

function lv(u, p, t)
x, y = u
α, β, γ, δ = p
dx = α*x - β*x*y
dy = δ*x*y - γ*y
du = [dx, dy]
end


trueodeprob = ODEProblem(lv, u0, tspan, p);
ode_data = Array(solve(trueodeprob, Tsit5(), saveat = tsteps));
y_train = ode_data[:, 1:35];

dudt2 = FastChain(FastDense(2, 10, tanh), FastDense(10, 2));

prob_node = NeuralODE(dudt2, (0., 4.5), Tsit5(), saveat = tsteps); #neural ode
train_prob = NeuralODE(dudt2, (0., 3.5), Tsit5(), saveat = tsteps[1:35]);


function predict_node(p) # predict with given params
Array(train_prob(u0, p))
end

function loss(p) # loss function to minimize
Float64(sum(abs2, y_train .- predict_node(p)))
end


####### Perform inference


### Fit neural ode to the data
@everywhere @model function fit_node(data)
σ ~ InverseGamma(2, 3)
p ~ MvNormal(pmin_lv, 0.06)

# Calculate predictions for the inputs given the params.
predicted = predict_node(p)

# observe each prediction.
for i = 1:size(predicted,2)
data[:,i] ~ MvNormal(predicted[:,i], σ)
end
end

@everywhere model = fit_node(y_train);
function perform_inference(lr, alpha, samplesize, num_chains)
alg = SGHMC(learning_rate=lr, momentum_decay=alpha)
chain = sample(model, alg, MCMCThreads(), samplesize, num_chains, init_theta=zero(pmin_lv), progress=true);
return chain
end

function map_loss(chain)
chain_array = Array(chain)
k = size(chain_array,1)
losses = loss.([chain_array[i,:] for i in 1:k])
return losses
end

# init at map point
using JLD
pinit = initial_params(dudt2);
opt = DiffEqFlux.sciml_train(loss, pinit, ADAM(0.05), maxiters = 1500)
opt2 = DiffEqFlux.sciml_train(loss, opt.minimizer, LBFGS(), allow_f_increases = true)
pmin = opt2.minimizer;
save("pmin_lv.jld", "pmin_lv", pmin)

using JLD
pmin = load("pmin_lv.jld")
pmin = pmin["pmin_lv"]
pmin_lv = pmin;


function plot_chain(chain, losses)
pl = plot()
chain_array = Array(chain)
len = size(chain_array,1)

training_end = 3.5

scatter!(tsteps, ode_data[1,:], color = :red, label = "Data: Var1", title = "Lotka Volterra Neural ODE")
scatter!(tsteps, ode_data[2,:], color = :blue, label = "Data: Var2")

for k in 1:300
resol = prob_node(u0, chain_array[rand(25:len), :])
plot!(tsteps[1:36], resol[1,1:36], alpha=0.04, color=:red, label = "")
plot!(tsteps[1:36], resol[2,1:36], alpha=0.04, color=:blue, label = "")
plot!(tsteps[36:45], resol[1,36:45], alpha=0.04, color=:purple, label = "")
plot!(tsteps[36:45], resol[2,36:45], alpha=0.04, color=:purple, label = "")
end

idx = findmin(losses)[2]
prediction = prob_node(u0, chain_array[idx, :])
plot!(tsteps, prediction[1,:], color=:black, w=2, label = "")
plot!(tsteps, prediction[2,:], color=:black, w=2, label = "Training: Best fit prediction")
plot!(tsteps[36:end], prediction[1,:][36:end], color = :purple, w = 2, label = "")
plot!(tsteps[36:end], prediction[2,:][36:end], color = :purple, w = 2, label = "Forecasting: Best fit prediction", ylims = (-1.5, 10))

display(plot!([training_end-0.0001,training_end+0.0001],[-1,5],lw=3,color=:green,label="Training Data End", linestyle = :dash))

############## CONTOUR PLOTS #######################
pl2 = scatter(ode_data[1,:], ode_data[2,:], color = :red, label = "Data", xlabel = "Var1", ylabel = "Var2", title = "Lotka Volterra Neural ODE")

for k in 1:300
resol = prob_node(u0, chain_array[rand(100:len), :])
plot!(resol[1,:][1:36],resol[2,:][1:36], alpha=0.04, color = :red, label = "")
plot!(resol[1,:][36:end],resol[2,:][36:end], alpha=0.1, color = :purple, label = "")
end

plot!(prediction[1,:], prediction[2,:], color = :black, w = 2, label = "Training: Best fit prediction")
display(plot!(prediction[1,:][36:end], prediction[2,:][36:end], color = :purple, w = 2, label = "Forecasting: Best fit prediction", ylims = (-2, 7), xlims = (-0.5, 7.5)))


return pl, pl2;
end

## ---------------------------------------------------


samples = 1000

lr = 2e-7; md = 0.1; num_chains = 7;

chain = perform_inference(lr, md, samples, num_chains);
for i in 1:num_chains
losses = map_loss(chain[:,:,i])
pl = plot(1:samples, losses); display(pl)
savefig(pl, string(lr, "_", md, "_", samples, "_", "chain_", i, "_losses", ".png"))
pl_ch, pl2 = plot_chain(chain[:,:,i], losses)
savefig(pl_ch, string(lr, "_", md, "_", samples, "_", "chain_", i, "_predictions", ".png"))
savefig(pl2, string(lr, "_", md, "_", samples, "_", "chain_", i, "_contour", ".png"))

end
164 changes: 164 additions & 0 deletions BayesiaNODE_SGHMC_spiral.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
using Distributed
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Turing, Serialization, Plots
using JLD


u0 = [2.0; 0.0]
datasize = 60
tspan = (0.0, 1.2)
tsteps = range(tspan[1], tspan[2], length = datasize)

function spiral(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end


trueodeprob = ODEProblem(spiral, u0, tspan);
ode_data = Array(solve(trueodeprob, Tsit5(), saveat = tsteps));
y_train = ode_data[:, 1:50];

dudt2 = FastChain(FastDense(2, 50, tanh),
FastDense(50, 2))

prob_node = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps); #neural ode
train_prob = NeuralODE(dudt2, (0., 1.0), Tsit5(), saveat = tsteps[1:50]);


function predict_node(p) # predict with given params
Array(train_prob(u0, p))
end

function loss(p) # loss function to minimize
pred = predict_node(p)
return Float64(sum(abs2, y_train .- pred))
end


## -----------------------------
####### Perform inference


### Fit neural ode to the data
@everywhere @model function fit_node(data)
σ ~ InverseGamma(2, 3)
p ~ MvNormal(pmin_spiral, 0.1)
# Calculate predictions for the inputs given the params.
predicted = train_prob(u0, p)
# observe each prediction.
for i = 1:size(predicted,2)
data[:,i] ~ MvNormal(predicted[:,i], σ)
end
end

@everywhere model = fit_node(y_train); # fit model to average simulated data

function perform_inference(lr, alpha, samplesize, pmin, num_chains)
alg = SGHMC(learning_rate=lr, momentum_decay=alpha)
chain = sample(model, alg, MCMCThreads(), samplesize, num_chains, progress=true);
return chain
end

function map_loss(chain)
chain_array = Array(chain)
k = size(chain_array,1)
losses = loss.([chain_array[i,:] for i in 1:k])
return losses
end

callback = function (p, l, param; doplot = true)
# plot current prediction against data
display(l)
plt = scatter(ode_data[1,:], ode_data[2,:], label = "data")
sol = prob_node(u0, param);
scatter!(plt, sol[1,:], sol[2,:], label = "prediction")
if doplot
display(plot(plt))
end
return false
end

# init at map point
using JLD
pinit = initial_params(dudt2);
opt = DiffEqFlux.sciml_train(loss, train_prob.p, ADAM(0.05), maxiters = 1500)


pmin = opt.minimizer;
save("pmin_spiral.jld", "pmin_spiral", pmin)

using JLD
pmin = load("pmin_spiral.jld")
pmin_spiral = pmin["pmin_spiral"]


sol = prob_node(u0, pmin_spiral);
plot()
display(scatter!(sol[1,:], sol[2,:]))
display(scatter!(ode_data[1,:], ode_data[2,:]))


function plot_chain(chain, losses)
pl = plot()
chain_array = Array(chain)
len = size(chain_array,1)

training_end = 1.0
tei = 50 #training_end_idx

scatter!(tsteps, ode_data[1,:], color = :red, label = "Data: Var1", title = "Spiral Neural ODE")
scatter!(tsteps, ode_data[2,:], color = :blue, label = "Data: Var2")
plot!([training_end-0.0001,training_end+0.0001],[-2.2,1.3],lw=3,color=:green,label="Training Data End", linestyle = :dash)


for k in 1:300
resol = prob_node(u0, chain_array[rand(100:len), :])
plot!(tsteps[1:tei], resol[1,:][1:tei], alpha=0.04, color = :red, label = "")
plot!(tsteps[1:tei], resol[2,:][1:tei], alpha=0.04, color = :blue, label = "")
plot!(tsteps[tei:end], resol[1,:][tei:end], alpha=0.04, color = :purple, label = "")
plot!(tsteps[tei:end], resol[2,:][tei:end], alpha=0.04, color = :purple, label = "")
end

idx = findmin(losses)[2]
prediction = prob_node(u0, chain_array[idx, :])
plot!(tsteps, prediction[1,:], color=:black, w=2, label = "")
plot!(tsteps, prediction[2,:], color=:black, w=2, label = "Training: Best fit prediction", ylims = (-2.5, 3.5))
plot!(tsteps[tei:end], prediction[1,:][tei:end], color = :purple, w = 2, label = "")
plot!(tsteps[tei:end], prediction[2,:][tei:end], color = :purple, w = 2, label = "Forecasting: Best fit prediction", ylims = (-2.5, 3.5))

display(plot!([training_end-0.0001,training_end+0.0001],[-1,5],lw=3,color=:green,label="Training Data End", linestyle = :dash))


################## COUNTOUR PLOTS ###################################

pl2 = scatter(ode_data[1,:], ode_data[2,:], color = :red, label = "Data", xlabel = "Var1", ylabel = "Var2", title = "Spiral Neural ODE")

for k in 1:300
resol = prob_node(u0, chain_array[rand(50:len), :])
plot!(resol[1,:][1:tei],resol[2,:][1:tei], alpha=0.04, color = :red, label = "")
plot!(resol[1,:][tei:end],resol[2,:][tei:end], alpha=0.1, color = :purple, label = "")

end

plot!(prediction[1,:], prediction[2,:], color = :black, w = 2, label = "Training: Best fit prediction", ylims = (-2.5, 3.5))
display(plot!(prediction[1,:][tei:end], prediction[2,:][tei:end], color = :purple, w = 2, label = "Forecasting: Best fit prediction", ylims = (-2.5, 3.5)))

return pl, pl2;
end


## ---------------------------------------------------
#
samples = 750

lr = 1.0e-6; md = 0.15;
num_chains = 6;
chain = perform_inference(lr, md, samples, pmin, num_chains);
for i in 1:num_chains
losses = map_loss(chain[:,:,i])
pl = plot(1:samples, losses); display(pl)
savefig(pl, string("spiral_", lr, "_", md, "_", samples, "_", "chain_", i+4, "_losses", ".png"))
pl_ch, pl2 = plot_chain(chain[:,:,i], losses)
savefig(pl_ch, string("spiral_", lr, "_", md, "_", samples, "_", "chain_", i+4, "_predictions", ".png"))
savefig(pl2, string("spiral_", lr, "_", md, "_", samples, "_", "chain_", i+4, "_contour", ".png"))
end

0 comments on commit 0ad23a3

Please sign in to comment.