diff --git a/Project.toml b/Project.toml index 2fb0b1c7a..186039852 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.4.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -66,6 +67,7 @@ AbstractPPL = "0.8.4, 0.9, 0.10, 0.11" Bijectors = "0.13, 0.14, 0.15" BridgeStan = "2" DataFrames = "1" +DataStructures = "0.18.22" DifferentiationInterface = "0.6.48" Distributions = "0.25" DocStringExtensions = "0.9" diff --git a/examples/stan/601-hmm.json b/examples/stan/601-hmm.json new file mode 100644 index 000000000..794e5842b --- /dev/null +++ b/examples/stan/601-hmm.json @@ -0,0 +1,153 @@ +{ + "N": 600, + "observations": [2.20041627e+00, 1.95769524e+00, -8.52087759e-01, -3.49801576e+00, + -2.75466819e+00, -7.84143831e+00, -7.00086086e+00, -1.00700985e+01, + -1.12625951e+01, -1.14712207e+01, -7.63410793e+00, -6.82938344e+00, + -8.94909208e+00, -1.05940554e+01, -1.01677096e+01, -1.07991865e+01, + -1.21380587e+01, -1.26044233e+01, -9.45156068e+00, -1.01137595e+01, + -1.20918822e+01, -8.58104339e+00, -9.46199003e+00, -6.16620129e+00, + -4.99686144e+00, -7.00417901e+00, -6.63052420e+00, -4.64233530e+00, + -3.70324733e+00, -3.32442030e+00, -4.88425406e-02, -5.51073794e-01, + -2.90981188e-02, 3.46099579e+00, -2.40716819e+00, -1.61163678e+00, + -5.16084106e-02, -3.91487830e+00, -2.05140073e+00, -1.28237909e+00, + -1.19528909e+00, -5.26269931e-01, -1.20447035e+00, 1.93218523e-01, + 3.08078980e+00, 1.80337827e+00, 1.57832252e+00, 9.34944592e-01, + 7.28679787e-01, 2.02044301e+00, 4.83385362e+00, 4.28691024e+00, + 6.39388537e+00, 3.85840269e+00, 2.80378245e+00, 5.52823636e+00, + 3.65488463e+00, 1.16215512e+00, -2.67688803e-01, -7.58287776e-01, + -7.02438545e-01, -6.81065563e-01, 4.78604164e+00, 2.50092926e+00, + 4.90315224e+00, 2.88597671e+00, 5.79131041e+00, 4.51960872e+00, + 2.96660337e+00, 2.05281631e+00, 2.94518154e+00, 3.80861207e+00, + 3.57769218e+00, 4.72358097e+00, 5.27504634e+00, 2.72113257e+00, + 5.62923032e+00, 4.89088010e+00, 4.85389967e+00, 2.52088724e+00, + 1.68862456e+00, 1.45954408e+00, 1.67063240e+00, 3.21727514e+00, + -2.66006003e+00, 1.96019559e+00, -9.23003292e-01, 2.02073660e-03, + -1.10989865e+00, -3.57514428e+00, -1.29680815e+00, 1.62500296e-01, + -1.53874849e-01, -5.00880738e-02, 7.12391454e-01, 2.04897962e+00, + 2.58430799e+00, -8.10793344e-02, 8.01232671e-01, 1.26356208e+00, + 2.42295161e+00, 4.78682256e+00, 1.64580644e+00, 7.49921116e-01, + 4.67797903e+00, 3.96665687e+00, 6.75268362e-01, 3.93938545e+00, + 4.02636161e+00, 4.44271098e+00, 5.19042137e+00, 2.58723543e+00, + 7.18735268e+00, 6.71400675e+00, 7.33104925e+00, 8.07291814e+00, + 1.01641232e+01, 7.12567322e+00, 5.56664294e+00, 8.06508369e+00, + 8.06292109e+00, 4.96482253e+00, 7.52618800e+00, 8.00781373e+00, + 3.75931312e+00, 9.41197214e-01, 7.20632190e+00, 7.21676797e+00, + 6.53389120e+00, 5.23761666e+00, 3.48087341e+00, 2.21970007e+00, + 4.29720023e+00, 4.70966628e+00, 4.09379447e+00, 6.49047977e+00, + 6.27283245e+00, 4.31824266e+00, 7.48014444e+00, 9.19625336e+00, + 9.37229733e+00, 1.08146351e+01, 1.11088290e+01, 1.02598495e+01, + 1.13194604e+01, 1.19072907e+01, 1.02393721e+01, 1.02645342e+01, + 1.01497331e+01, 1.41101105e+01, 8.85600537e+00, 7.99558228e+00, + 1.08893333e+01, 1.03842444e+01, 9.92441316e+00, 1.06790997e+01, + 1.12610941e+01, 1.02065620e+01, 8.74433367e+00, 7.94298791e+00, + 7.08367963e+00, 6.09791682e+00, 6.30358218e+00, 8.68326355e+00, + 1.12112083e+01, 1.15750269e+01, 1.38522936e+01, 9.37299819e+00, + 5.78993972e+00, 4.49590123e+00, 6.51469197e+00, 7.36551889e+00, + 7.32101484e+00, 7.67118026e+00, 6.33513478e+00, 6.60466730e+00, + 7.38936309e+00, 7.38584474e+00, 6.47593338e+00, 5.28929490e+00, + 5.40888051e+00, 7.68078795e+00, 9.59892683e+00, 5.20831396e+00, + 5.16531652e+00, 5.24548731e+00, 5.56486992e+00, 6.77264209e+00, + 1.61973862e+00, 5.37412033e+00, 3.03930250e+00, 4.47814021e+00, + 3.44894920e+00, 2.34977267e+00, 1.00274187e+00, 4.78782641e+00, + 5.00688524e+00, 5.42917704e+00, 8.74264437e+00, 1.08546136e+01, + 1.26508115e+01, 1.38369640e+01, 1.22559499e+01, 1.35193366e+01, + 1.45558898e+01, 1.98940529e+01, 1.56230795e+01, 1.44094233e+01, + 1.72861607e+01, 1.33902066e+01, 1.58302603e+01, 1.70575462e+01, + 1.46093763e+01, 1.56669888e+01, 1.61703988e+01, 1.75911871e+01, + 1.77323057e+01, 1.69723792e+01, 1.89714998e+01, 2.14827767e+01, + 1.71719057e+01, 2.01649086e+01, 2.02643282e+01, 2.23352866e+01, + 2.45251186e+01, 2.75240294e+01, 2.67906918e+01, 2.34239784e+01, + 2.47779506e+01, 2.41075077e+01, 2.38932096e+01, 2.37542881e+01, + 2.41000283e+01, 2.73505371e+01, 2.64419694e+01, 2.52751711e+01, + 2.80470395e+01, 2.57457760e+01, 2.59011799e+01, 2.48515667e+01, + 2.17817917e+01, 2.50912782e+01, 2.57779044e+01, 2.83524943e+01, + 2.57348726e+01, 2.64355250e+01, 2.59956499e+01, 2.67754598e+01, + 2.15529923e+01, 2.09129092e+01, 2.19243167e+01, 2.19974779e+01, + 2.19960943e+01, 2.21024865e+01, 2.52887950e+01, 2.40804220e+01, + 2.27381049e+01, 2.00207687e+01, 2.19041866e+01, 2.49407407e+01, + 2.27611733e+01, 2.64211439e+01, 2.79037728e+01, 2.62650187e+01, + 2.95082709e+01, 2.63434330e+01, 2.69195724e+01, 2.72645202e+01, + 2.42261173e+01, 2.61695505e+01, 2.38557180e+01, 2.48194219e+01, + 2.40467891e+01, 2.76022611e+01, 2.42635759e+01, 2.30312713e+01, + 1.95100536e+01, 2.10866278e+01, 2.27955861e+01, 2.51106600e+01, + 2.54197451e+01, 2.32745902e+01, 3.19060151e+01, 3.00306215e+01, + 3.11005590e+01, 3.29650208e+01, 3.57058316e+01, 3.67591464e+01, + 3.59652717e+01, 3.86935166e+01, 3.57911075e+01, 3.79579306e+01, + 3.84987982e+01, 3.56516264e+01, 3.08808791e+01, 3.14919873e+01, + 3.37580507e+01, 3.67153048e+01, 3.71685145e+01, 3.87817383e+01, + 3.96963411e+01, 3.83564527e+01, 3.93413077e+01, 3.75321922e+01, + 4.15746594e+01, 4.36385157e+01, 3.98425106e+01, 4.15453595e+01, + 3.94474951e+01, 3.80429200e+01, 3.58932846e+01, 3.55484494e+01, + 3.57650208e+01, 3.55469244e+01, 3.13031299e+01, 3.35075406e+01, + 3.30004314e+01, 3.33909262e+01, 3.38864614e+01, 3.51344117e+01, + 3.47694398e+01, 3.61716779e+01, 3.64870273e+01, 3.60862257e+01, + 3.77519484e+01, 3.93528421e+01, 3.94286091e+01, 4.03535430e+01, + 3.91922900e+01, 3.71025976e+01, 3.22791905e+01, 3.18571947e+01, + 3.00682943e+01, 3.15075832e+01, 3.11689635e+01, 3.36183250e+01, + 3.36876822e+01, 3.54169048e+01, 3.52179400e+01, 3.62088136e+01, + 3.30185438e+01, 4.01510887e+01, 3.64507373e+01, 3.77575895e+01, + 3.63044233e+01, 3.64477644e+01, 3.68859079e+01, 3.79617659e+01, + 3.77345991e+01, 3.73876702e+01, 4.00806336e+01, 4.19850377e+01, + 4.11555981e+01, 4.42295454e+01, 4.16191062e+01, 3.84851625e+01, + 3.89853669e+01, 4.07976633e+01, 4.27051670e+01, 4.00062468e+01, + 4.30978454e+01, 4.23601404e+01, 4.41922518e+01, 4.39965820e+01, + 4.67794112e+01, 5.02978572e+01, 4.56351465e+01, 4.48239990e+01, + 4.05876623e+01, 4.14133895e+01, 4.14028065e+01, 4.18386992e+01, + 4.13414260e+01, 4.11379412e+01, 4.11907486e+01, 4.09471608e+01, + 3.95399716e+01, 3.69100236e+01, 3.66860475e+01, 3.79205085e+01, + 3.34496798e+01, 3.78740709e+01, 3.47468014e+01, 3.24855140e+01, + 3.40201234e+01, 3.61644552e+01, 3.63968850e+01, 3.42714727e+01, + 3.41741941e+01, 3.11705640e+01, 2.83597819e+01, 3.07470347e+01, + 2.56121553e+01, 2.59008892e+01, 2.57916997e+01, 2.88839503e+01, + 2.96389389e+01, 2.76744072e+01, 2.61646143e+01, 2.57337329e+01, + 2.47063686e+01, 2.58303525e+01, 2.47858266e+01, 2.87182967e+01, + 3.04326208e+01, 2.94320397e+01, 3.07871136e+01, 2.99429604e+01, + 2.81048675e+01, 2.87370975e+01, 2.96210477e+01, 2.75839731e+01, + 2.91019014e+01, 2.92451746e+01, 2.84477817e+01, 2.91807574e+01, + 2.92471887e+01, 3.04269667e+01, 3.15569941e+01, 3.38214070e+01, + 3.40074353e+01, 3.16579505e+01, 3.38981861e+01, 3.28793324e+01, + 3.52557179e+01, 3.39301843e+01, 3.63010869e+01, 4.05274543e+01, + 4.07032346e+01, 3.85513408e+01, 4.02953088e+01, 3.70458613e+01, + 3.72419718e+01, 3.70039193e+01, 3.77764164e+01, 3.56303451e+01, + 3.62959585e+01, 3.78946072e+01, 3.62709061e+01, 3.76634104e+01, + 3.91337749e+01, 3.84004126e+01, 3.81241726e+01, 4.15581530e+01, + 4.01142431e+01, 3.63541133e+01, 3.86243449e+01, 4.01430134e+01, + 3.92610382e+01, 3.59249212e+01, 3.73114540e+01, 3.78569320e+01, + 3.84827639e+01, 3.73911966e+01, 3.60235311e+01, 3.55317825e+01, + 3.79022022e+01, 3.80714924e+01, 3.83597586e+01, 3.78601350e+01, + 3.73587552e+01, 3.60520322e+01, 3.63999056e+01, 3.57917551e+01, + 3.61548360e+01, 3.50228969e+01, 3.56458695e+01, 3.42961012e+01, + 3.58128162e+01, 3.52237851e+01, 3.47537293e+01, 3.68640322e+01, + 3.41322362e+01, 3.66964622e+01, 3.34686593e+01, 3.44682303e+01, + 3.45441304e+01, 3.58217338e+01, 3.58910303e+01, 3.73505089e+01, + 3.77782400e+01, 3.46415997e+01, 3.58564269e+01, 3.53247241e+01, + 3.36910112e+01, 3.27083918e+01, 3.13210632e+01, 2.77079547e+01, + 3.00823366e+01, 2.99031872e+01, 2.79379527e+01, 2.74939928e+01, + 2.75225223e+01, 2.82877226e+01, 2.83256334e+01, 2.83580797e+01, + 3.04150303e+01, 2.66598925e+01, 2.91794583e+01, 2.83726774e+01, + 2.51347763e+01, 2.25963168e+01, 2.18457418e+01, 2.42479872e+01, + 2.67623007e+01, 2.25385780e+01, 2.73269381e+01, 2.59379632e+01, + 2.10230150e+01, 2.07278808e+01, 2.00192286e+01, 2.01415859e+01, + 2.17867287e+01, 2.43700543e+01, 2.45504939e+01, 2.47710967e+01, + 2.38890926e+01, 2.54430988e+01, 2.19310121e+01, 2.01861465e+01, + 2.38134047e+01, 2.19633916e+01, 2.16153891e+01, 2.45069550e+01, + 2.04463694e+01, 1.90655406e+01, 2.09490496e+01, 2.29965558e+01, + 2.57579529e+01, 2.54823802e+01, 2.56986140e+01, 2.50657061e+01, + 2.61572543e+01, 2.35304540e+01, 2.03513795e+01, 2.47338959e+01, + 2.85361200e+01, 2.67531185e+01, 2.53851547e+01, 2.93807397e+01, + 2.93031483e+01, 3.07399296e+01, 2.82640182e+01, 2.87514656e+01, + 2.70374233e+01, 2.65811547e+01, 2.61418710e+01, 2.58550020e+01, + 2.71578776e+01, 2.64238308e+01, 2.84622023e+01, 2.79479797e+01, + 2.54783638e+01, 2.60418242e+01, 2.79443392e+01, 3.07094261e+01, + 2.99400838e+01, 2.82979800e+01, 3.07511641e+01, 3.33209394e+01, + 3.33287246e+01, 3.74255324e+01, 3.66325921e+01, 3.67588263e+01, + 3.43328954e+01, 3.80575146e+01, 3.61886716e+01, 3.53697709e+01, + 3.52197474e+01, 3.59416153e+01, 3.71610542e+01, 3.50041339e+01, + 3.61407066e+01, 3.55777199e+01, 3.57989925e+01, 3.61486268e+01, + 3.79354584e+01, 3.87753346e+01, 4.04170961e+01, 3.85544136e+01, + 3.75831990e+01, 3.90366313e+01, 3.49632852e+01, 3.27692015e+01, + 3.66075790e+01, 3.82664250e+01, 4.10635836e+01, 4.01232858e+01, + 3.82659471e+01, 3.86222563e+01, 3.95396663e+01, 4.07177845e+01, + 4.16468646e+01, 4.14681492e+01, 3.87589157e+01, 4.25642139e+01, + 4.28169508e+01, 4.05391734e+01, 3.96425358e+01, 3.65827556e+01] +} \ No newline at end of file diff --git a/examples/stan/601-hmm.stan b/examples/stan/601-hmm.stan new file mode 100644 index 000000000..db0c70178 --- /dev/null +++ b/examples/stan/601-hmm.stan @@ -0,0 +1,25 @@ +data { + int N; + vector[N] observations; +} + +parameters { + real log_sigma_transition; + vector[N] latents; +} + +transformed parameters { + real sigma_transition = exp(log_sigma_transition); +} + +model { + log_sigma_transition ~ normal(1,1); + + latents[1] ~ normal(2,0.5); + observations[1] ~ normal(latents[1],1); + + for (t in 2:N) { + latents[t] ~ normal(latents[t-1], sigma_transition); + observations[t] ~ normal(latents[t],1); + } +} \ No newline at end of file diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 5df5da48a..2a9e36d57 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -38,6 +38,7 @@ using StaticArraysCore using Statistics using StatsBase using ZipFile +using DataStructures import Base: Forward, @kwdef, show, print, merge, keys import Base.Threads.@threads @@ -68,7 +69,7 @@ export pigeons, Inputs, PT, # getting information out of an execution: stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample, # variational references: - GaussianReference, + GaussianReference, TreeReference, DenseGaussianReference, # samplers SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix diff --git a/src/explorers/GradientBasedSampler.jl b/src/explorers/GradientBasedSampler.jl index c558b955e..995b1555a 100644 --- a/src/explorers/GradientBasedSampler.jl +++ b/src/explorers/GradientBasedSampler.jl @@ -20,6 +20,6 @@ function gradient_based_sampler_recorders!(recorders, explorer::GradientBasedSam push!(recorders, buffers) push!(recorders, Pigeons.ad_buffers) if hasproperty(explorer, :preconditioner) && explorer.preconditioner isa AdaptedDiagonalPreconditioner - push!(recorders, _transformed_online) # for mass matrix adaptation + push!(recorders, _transformed_online_full) # for mass matrix adaptation end end diff --git a/src/includes.jl b/src/includes.jl index 5600eb669..e755b0036 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -79,6 +79,8 @@ include("targets/DistributionLogPotential.jl") include("pt/checks.jl") include("explorers/BufferedAD.jl") include("variational/GaussianReference.jl") +include("variational/TreeReference.jl") +include("variational/DenseGaussianReference.jl") include("variational/VariationalReference.jl") include("paths/ScaledPrecisionNormalPath.jl") include("explorers/invariance_test.jl") diff --git a/src/pt/pigeons.jl b/src/pt/pigeons.jl index 483d7fe6a..fd47c281c 100644 --- a/src/pt/pigeons.jl +++ b/src/pt/pigeons.jl @@ -111,7 +111,7 @@ function explore!(pt, replica, explorer) # for the online stats, we ignore pt.inputs.extended_traces # because the recorders do not support grouping by chains @record_if_requested!(replica.recorders, :online, extract_sample(replica.state, log_potential, pt.inputs.extractor)) - @record_if_requested!(replica.recorders, :_transformed_online, replica.state) + @record_if_requested!(replica.recorders, :_transformed_online_full, replica.state) end if pt.inputs.extended_traces || is_target(pt.shared.tempering.swap_graphs, replica.chain) @record_if_requested!( diff --git a/src/recorders/OnlineStateRecorder.jl b/src/recorders/OnlineStateRecorder.jl index e7a1a7a57..6fd9ff59d 100644 --- a/src/recorders/OnlineStateRecorder.jl +++ b/src/recorders/OnlineStateRecorder.jl @@ -2,10 +2,11 @@ See [`online()`](@ref). """ @kwdef struct OnlineStateRecorder + full::Bool = false stats::Dict{Pair{Symbol, Type}, Any} = Dict{Pair{Symbol, Type}, Any}() end -OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(copy(from_another.stats)) +OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(from_another.full, copy(from_another.stats)) """ $SIGNATURES @@ -30,9 +31,12 @@ get_transformed_statistic(reduced_recorders, variable_name::Symbol, t::Type{T}) get_statistic(reduced_recorders, variable_name, t, false) function get_statistic(reduced_recorders, variable_name::Symbol, ::Type{T}, original_param = true) where {T} - recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online + recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online_full key = Pair(variable_name, T) - v = value(recorder.stats[key]) + v = value(recorder.stats[key]) + if T==CovMatrix + return v + end return value.(v) end @@ -68,7 +72,7 @@ recorded_continuous_variables(state) = continuous_variables(state) `OnlineStat` types to be computed when the [`online()`] recorder is enabled. """ -const registered_online_types = [Mean, Variance] +const registered_online_types = [Mean, Variance, CovMatrix] """ $SIGNATURES @@ -101,6 +105,13 @@ initialize_online_state_recorder!(stats, state) = initialize_online_state_recorder!(stats, state, stat_type) end +initialize_online_state_recorder!(stats, state, ::Type{CovMatrix}) = + for name in recorded_continuous_variables(state) + var = variable(state, name) + key = Pair(name, CovMatrix) + stats[key] = CovMatrix() + end + initialize_online_state_recorder!(stats, state, ::Type{T}) where {T} = for name in recorded_continuous_variables(state) var = variable(state, name) diff --git a/src/recorders/recorder.jl b/src/recorders/recorder.jl index e090d13f7..996087382 100644 --- a/src/recorders/recorder.jl +++ b/src/recorders/recorder.jl @@ -99,7 +99,8 @@ transformed to be defined on an unconstrained space. This is used internally by [`explorer`](@ref)'s for adaptation purposes (in particular, pre-conditioning and variational references). """ -@provides recorder _transformed_online() = OnlineStateRecorder() +@provides recorder _transformed_online() = OnlineStateRecorder() +@provides recorder _transformed_online_full() = OnlineStateRecorder(full=true) """ Restart and round-trip counts. diff --git a/src/variational/DenseGaussianReference.jl b/src/variational/DenseGaussianReference.jl new file mode 100644 index 000000000..68becc164 --- /dev/null +++ b/src/variational/DenseGaussianReference.jl @@ -0,0 +1,95 @@ +""" +A Gaussian dense variational reference (i.e., with a dense covariance matrix). +""" +@kwdef mutable struct DenseGaussianReference + mean::Vector{Any} = Vector{Any}() + covariance::Matrix{Float64} = zeros(Float64, 0, 0) + precision::Matrix{Float64} = zeros(Float64, 0, 0) + cholesky::Any = zeros(Float64, 0, 0) + which_variable::Vector{Any} = Vector{Any}() + which_index::Vector{Int} = Vector{Int}() + identity_gaussian::Any = zeros(Float64, 0, 0) + first_tuning_round::Int = 10 # TODO: this should be moved elsewhere? + + function DenseGaussianReference(mean, covariance, precision, cholesky, which_variable, which_index, identity_gaussian, first_tuning_round) + @assert first_tuning_round ≥ 1 + new(mean, covariance, precision, cholesky, which_variable, which_index, identity_gaussian, first_tuning_round) + end +end + +dim(variational::DenseGaussianReference) = length(variational.mean) +function activate_variational(variational::DenseGaussianReference, iterators::Iterators) + iterators.round ≥ variational.first_tuning_round ? true : false +end +variational_recorder_builders(::DenseGaussianReference) = [_transformed_online_full] + +function update_reference!(reduced_recorders, variational::DenseGaussianReference, state) + isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + + empty!(variational.which_variable) + empty!(variational.which_index) + empty!(variational.mean) + empty!(variational.which_variable) + empty!(variational.which_index) + + eps = 1e-6 + temp_covariance = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.covariance = temp_covariance + eps * I + + variational.mean = get_transformed_statistic(reduced_recorders, :singleton_variable, Mean) + variational.identity_gaussian = MvNormal(zeros(length(variational.mean)), I) + variational.precision = inv(variational.covariance) + variational.cholesky = cholesky(variational.covariance).L + + for var_name in continuous_variables(state) + for i = 1:length(variable(state, var_name)) + push!(variational.which_variable, var_name) + push!(variational.which_index, i) + end + end + +end + +function sample_iid!(variational::DenseGaussianReference, replica, shared) + z = rand(variational.identity_gaussian) + sample = variational.mean + variational.cholesky * z + + for i = 1:length(variational.mean) + update_state!(replica.state, variational.which_variable[i], variational.which_index[i], sample[i]) + end +end + +function (variational::DenseGaussianReference)(state) + flattened_state = Vector{Float64}() + + for i = 1:length(variational.mean) + name = variational.which_variable[i] + index = variational.which_index[i] + push!(flattened_state, Pigeons.variable(state, name)[index]) + end + + return -0.5 * (transpose(flattened_state - variational.mean) * variational.precision * (flattened_state - variational.mean)) +end + + + +# LogDensityProblemsAD implementation (currently only for special case of a singleton variable) + +LogDensityProblems.logdensity(log_potential::DenseGaussianReference, x) = + log_potential(x) + +function LogDensityProblems.dimension(log_potential::DenseGaussianReference) + return length(log_potential.mean) +end + +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType, log_potential::DenseGaussianReference, replica::Replica) = + BufferedAD(log_potential, replica.recorders.buffers) + +function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{DenseGaussianReference}, x) + variational = log_potential.enclosed + buffer = log_potential.buffer + mean = variational.mean + precision = variational.precision + buffer .= -precision * (x - mean) + return LogDensityProblems.logdensity(variational, x), buffer +end \ No newline at end of file diff --git a/src/variational/GaussianReference.jl b/src/variational/GaussianReference.jl index a2e225c26..33a523118 100644 --- a/src/variational/GaussianReference.jl +++ b/src/variational/GaussianReference.jl @@ -4,7 +4,7 @@ A Gaussian mean-field variational reference (i.e., with a diagonal covariance ma @kwdef mutable struct GaussianReference mean::Dict{Symbol, Any} = Dict{Symbol, Any}() standard_deviation::Dict{Symbol, Any} = Dict{Symbol, Any}() - first_tuning_round::Int = 6 # TODO: this should be moved elsewhere? + first_tuning_round::Int = 11 # TODO: this should be moved elsewhere? function GaussianReference(mean, standard_deviation, first_tuning_round) @assert length(mean) == length(standard_deviation) @@ -17,7 +17,7 @@ dim(variational::GaussianReference) = length(variational.mean) function activate_variational(variational::GaussianReference, iterators::Iterators) iterators.round ≥ variational.first_tuning_round ? true : false end -variational_recorder_builders(::GaussianReference) = [_transformed_online] +variational_recorder_builders(::GaussianReference) = [_transformed_online_full] function update_reference!(reduced_recorders, variational::GaussianReference, state) isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl new file mode 100644 index 000000000..79cb4a606 --- /dev/null +++ b/src/variational/TreeReference.jl @@ -0,0 +1,216 @@ +""" +A Gaussian tree variational reference +""" + +@kwdef mutable struct TreeReference + edge_set::Vector{Tuple{Float64, Float64, Int, Int}} = Vector{Tuple{Float64, Float64, Int, Int}}() + mean::Vector{Float64} = Vector{Float64}() + standard_deviation::Vector{Float64} = Vector{Float64}() + which_variable::Vector{Symbol} = Vector{Symbol}() + which_index::Vector{Int} = Vector{Int}() + iid_sample_set::Vector{Float64} = Vector{Float64}() + covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) + first_tuning_round::Int = 10 + + function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + @assert first_tuning_round ≥ 1 + new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + end +end + + +dim(variational::TreeReference) = length(variational.mean) +function activate_variational(variational::TreeReference, iterators) + iterators.round ≥ variational.first_tuning_round ? true : false +end + +variational_recorder_builders(::TreeReference) = [_transformed_online_full] + + +function update_reference!(reduced_recorders, variational::TreeReference, state) + isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + + empty!(variational.mean) + empty!(variational.standard_deviation) + empty!(variational.which_variable) + empty!(variational.which_index) + variational.edge_set = [] + + for var_name in continuous_variables(state) + temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) + temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) + + dimension = length(temp_mean) + for i = 1:dimension + push!(variational.mean, temp_mean[i]) + push!(variational.standard_deviation, temp_std[i]) + + push!(variational.which_variable, var_name) + push!(variational.which_index, i) + end + end + @assert length(variational.mean) == length(variational.standard_deviation) + + variational.iid_sample_set = zeros(length(variational.mean)) + variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.edge_set = build_tree(variational) +end + + +function build_tree(variational::TreeReference) + dim = length(variational.mean) + + adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() + for i in 1:dim + adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() + end + + for i = 1:dim + for j = (i+1):dim + normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) + rho = variational.covariance_matrix[i,j] / normalization + rho = clamp(rho, -0.99, 0.99) + I = -0.5*log(1-rho^2) + + push!(adjacency_list[i], (I, rho, i, j)) + push!(adjacency_list[j], (I, rho, j, i)) + end + end + root = 1 + return directed_max_tree(adjacency_list, root) +end + + +function directed_max_tree(adjacency_list, root) + total_number_of_nodes = length(keys(adjacency_list)) + mst = Vector{Tuple{Float64, Float64, Int, Int}}() + pq = BinaryMaxHeap{Tuple{Float64, Float64, Int, Int}}() + visited_nodes = Set{Int}() + + push!(visited_nodes, root) + for edge in adjacency_list[root] + push!(pq, edge) + end + + while !isempty(pq) && length(mst)