From a1f4b97017bb5d3f5162f25aed2cc734c817e624 Mon Sep 17 00:00:00 2001 From: Jack Date: Thu, 26 Jun 2025 16:19:07 -0700 Subject: [PATCH 01/27] init TreeReference.jl class, added todos --- src/variational/TreeReference.jl | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 src/variational/TreeReference.jl diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl new file mode 100644 index 000000000..7f14ce91d --- /dev/null +++ b/src/variational/TreeReference.jl @@ -0,0 +1,69 @@ +""" +A Gaussian tree variational reference +""" + +#TODO +@kwdef mutable struct TreeReference + edges + num_nodes + first_tuning_round +end + +#TODO +@kwdef mutable struct node + data + mean + variance +end + + +dim(variational::TreeReference) = variational.num_nodes +#TODO +function activate_variational(variational::TreeReference, iterators::Iterators) +end +#TODO +variational_recorder_builders(::TreeReference) = [_transformed_online] + + + +#TODO +function update_reference!(reduced_recorders, variational::TreeReference, state) +end +#TODO +function compute_mutual_info() +end +#TODO +function tree_decomposition() +end +#TODO +function prims_spanning_tree() +end + + + +#TODO +function sample_iid!(variational::TreeReference, replica, shared) +end +#TODO +function (variational::TreeReference)(state) +end +#TODO +function tree_logdensity() +end + + + +# LogDensityProblemsAD implementation (currently only for special case of a singleton variable) +#TODO +LogDensityProblems.logdensity(log_potential::TreeReference, x) = + tree_logdensity() + +#TODO +function LogDensityProblems.dimension(log_potential::TreeReference) +end +#TODO +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) = + BufferedAD(log_potential, replica.recorders.buffers) +#TODO +function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) +end \ No newline at end of file From 45d35107291755073031e5e6449cc5be3f9890c7 Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 28 Jun 2025 15:48:51 -0700 Subject: [PATCH 02/27] added constructor --- src/variational/TreeReference.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 7f14ce91d..4f624ebb3 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -2,38 +2,36 @@ A Gaussian tree variational reference """ -#TODO + @kwdef mutable struct TreeReference - edges - num_nodes - first_tuning_round -end + edges::Vector{Tuple{Symbol, Symbol, Float32}} = Vector{Tuple{Symbol, Symbol, Float32}}() + num_nodes::Int = 0 + first_tuning_round::Int = 6 -#TODO -@kwdef mutable struct node - data - mean - variance + function TreeReference(edges, num_nodes, first_tuning_round) + @assert length(edges)==num_nodes-1 || length(edges)==(num_nodes*(num_nodes-1))/2 + @assert first_tuning_round ≥ 1 + new(edges, num_nodes, first_tuning_round) + end end dim(variational::TreeReference) = variational.num_nodes -#TODO function activate_variational(variational::TreeReference, iterators::Iterators) + iterators.round ≥ variational.first_tuning_round ? true : false end -#TODO -variational_recorder_builders(::TreeReference) = [_transformed_online] +variational_recorder_builders(::TreeReference) = [_transformed_online] #TODO function update_reference!(reduced_recorders, variational::TreeReference, state) end #TODO -function compute_mutual_info() +function tree_decomposition() end #TODO -function tree_decomposition() +function compute_mutual_info() end #TODO function prims_spanning_tree() From 92c0ab6d010b8277fd61b869dfbf9e8f69ad45af Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 28 Jun 2025 23:05:36 -0700 Subject: [PATCH 03/27] added logdensity evaluation for 1D variables --- src/variational/TreeReference.jl | 55 ++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 4f624ebb3..df5cb0c8e 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -4,9 +4,11 @@ A Gaussian tree variational reference @kwdef mutable struct TreeReference - edges::Vector{Tuple{Symbol, Symbol, Float32}} = Vector{Tuple{Symbol, Symbol, Float32}}() + edges::Vector{Tuple{Symbol, Symbol, Any}} = Vector{Tuple{Symbol, Symbol, Any}}() num_nodes::Int = 0 - first_tuning_round::Int = 6 + node_means::Dict{Symbol, Any} = Dict{Symbol, Any}() + node_variances::Dict{Symbol, Any} = Dict{Symbol, Any}() + first_tuning_round::Int = 6 function TreeReference(edges, num_nodes, first_tuning_round) @assert length(edges)==num_nodes-1 || length(edges)==(num_nodes*(num_nodes-1))/2 @@ -33,6 +35,7 @@ end #TODO function compute_mutual_info() end + #TODO function prims_spanning_tree() end @@ -42,26 +45,60 @@ end #TODO function sample_iid!(variational::TreeReference, replica, shared) end -#TODO + + function (variational::TreeReference)(state) + log_pdf = 0.0 + + marginal_var_name = continuous_variables(state)[1] + marginal_state = variable(state, marginal_var_name) + marginal_mean = variational.node_means[marginal_var_name] + marginal_variance = variational.node_variances[marginal_var_name] + log_pdf += logpdf(Normal(marginal_mean, sqrt(marginal_variance)), marginal_state) + + for edges in variational.edges + parent_var_name = edges[1] + child_var_name = edges[2] + + state_at_parent = variable(state, parent_var_name) + state_at_child = variable(state, child_var_name) + + log_pdf += tree_logdensity(variational, child_var_name, parent_var_name, state_at_child, state_at_parent) + end + return log_pdf end + + +function tree_logdensity(variational::TreeReference, child_var_name, parent_var_name, state_at_child, state_at_parent) + child_mean = variational.node_means[child_var_name] + parent_mean = variational.node_means[parent_var_name] + child_variance = variational.node_variances[child_var_name] + parent_variance = variational.node_variances[child_var_name] + + rho = get_rho(parent_var_name, child_var_name) + + new_mu = child_mean + rho * (sqrt(child_variance) / sqrt(parent_variance)) * (state_at_parent - parent_mean) + new_sigma = sqrt((1-rho^2) * child_variance) + + logdensity = logpdf(Normal(new_mu, new_sigma), state_at_child) + + return logdensity +end + #TODO -function tree_logdensity() +function get_rho(var_name1, var_name2) end # LogDensityProblemsAD implementation (currently only for special case of a singleton variable) #TODO -LogDensityProblems.logdensity(log_potential::TreeReference, x) = - tree_logdensity() - +LogDensityProblems.logdensity(log_potential::TreeReference, x) #TODO function LogDensityProblems.dimension(log_potential::TreeReference) end #TODO -LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) = - BufferedAD(log_potential, replica.recorders.buffers) +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) #TODO function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) end \ No newline at end of file From 7975cfdcaa8eced6285e36cac8a044a503a7a6f9 Mon Sep 17 00:00:00 2001 From: Jack Date: Sun, 29 Jun 2025 15:09:18 -0700 Subject: [PATCH 04/27] refactored to use std instead of variance --- src/variational/TreeReference.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index df5cb0c8e..e4e6704a8 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -5,15 +5,15 @@ A Gaussian tree variational reference @kwdef mutable struct TreeReference edges::Vector{Tuple{Symbol, Symbol, Any}} = Vector{Tuple{Symbol, Symbol, Any}}() - num_nodes::Int = 0 - node_means::Dict{Symbol, Any} = Dict{Symbol, Any}() - node_variances::Dict{Symbol, Any} = Dict{Symbol, Any}() + means::Dict{Symbol, Any} = Dict{Symbol, Any}() + standard_deviations::Dict{Symbol, Any} = Dict{Symbol, Any}() first_tuning_round::Int = 6 + num_nodes::Int = 0 - function TreeReference(edges, num_nodes, first_tuning_round) + function TreeReference(edges, means, standard_deviations, first_tuning_round, num_nodes) @assert length(edges)==num_nodes-1 || length(edges)==(num_nodes*(num_nodes-1))/2 @assert first_tuning_round ≥ 1 - new(edges, num_nodes, first_tuning_round) + new(edges, means, standard_deviations, first_tuning_round, num_nodes) end end @@ -44,6 +44,9 @@ end #TODO function sample_iid!(variational::TreeReference, replica, shared) + for var_name in continuous_variables(replica.state) + for i in eachindex(variational.mean[var_name]) + val = randn(replica.rng) * variational_standard end @@ -52,9 +55,9 @@ function (variational::TreeReference)(state) marginal_var_name = continuous_variables(state)[1] marginal_state = variable(state, marginal_var_name) - marginal_mean = variational.node_means[marginal_var_name] - marginal_variance = variational.node_variances[marginal_var_name] - log_pdf += logpdf(Normal(marginal_mean, sqrt(marginal_variance)), marginal_state) + marginal_mean = variational.means[marginal_var_name] + marginal_standard_deviation = variational.standard_deviations[marginal_var_name] + log_pdf += logpdf(Normal(marginal_mean, marginal_standard_deviation), marginal_state) for edges in variational.edges parent_var_name = edges[1] @@ -70,15 +73,15 @@ end function tree_logdensity(variational::TreeReference, child_var_name, parent_var_name, state_at_child, state_at_parent) - child_mean = variational.node_means[child_var_name] - parent_mean = variational.node_means[parent_var_name] - child_variance = variational.node_variances[child_var_name] - parent_variance = variational.node_variances[child_var_name] + child_mean = variational.means[child_var_name] + parent_mean = variational.means[parent_var_name] + child_standard_deviation = variational.standard_deviations[child_var_name] + parent_standard_deviation = variational.standard_deviations[child_var_name] rho = get_rho(parent_var_name, child_var_name) - new_mu = child_mean + rho * (sqrt(child_variance) / sqrt(parent_variance)) * (state_at_parent - parent_mean) - new_sigma = sqrt((1-rho^2) * child_variance) + new_mu = child_mean + rho * (child_standard_deviation / parent_standard_deviation) * (state_at_parent - parent_mean) + new_sigma = sqrt((1-rho^2) * (child_standard_deviation)^2) logdensity = logpdf(Normal(new_mu, new_sigma), state_at_child) From 7cf863b54a07e65f512f08231e8792e9ab5cb555 Mon Sep 17 00:00:00 2001 From: Jack Date: Mon, 30 Jun 2025 12:51:15 -0700 Subject: [PATCH 05/27] added greedy search algorithm --- src/variational/TreeReference.jl | 74 +++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 24 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index e4e6704a8..923901433 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -2,23 +2,26 @@ A Gaussian tree variational reference """ +import Pkg +Pkg.add("DataStructures") +using DataStructures @kwdef mutable struct TreeReference - edges::Vector{Tuple{Symbol, Symbol, Any}} = Vector{Tuple{Symbol, Symbol, Any}}() - means::Dict{Symbol, Any} = Dict{Symbol, Any}() - standard_deviations::Dict{Symbol, Any} = Dict{Symbol, Any}() - first_tuning_round::Int = 6 - num_nodes::Int = 0 - - function TreeReference(edges, means, standard_deviations, first_tuning_round, num_nodes) - @assert length(edges)==num_nodes-1 || length(edges)==(num_nodes*(num_nodes-1))/2 + edge_set::Vector{Tuple{Float64, Symbol, Symbol}} = Vector{Tuple{Float64, Symbol, Symbol}}() + mean::Dict{Symbol, Any} = Dict{Symbol, Any}() + standard_deviation::Dict{Symbol, Any} = Dict{Symbol, Any}() + first_tuning_round::Int = 6 + + function TreeReference(edge_set, mean, standard_deviation, first_tuning_round) + dim = length(mean) + @assert length(edge_set)==dim-1 || length(edge_set)==(dim*(dim-1))/2 @assert first_tuning_round ≥ 1 - new(edges, means, standard_deviations, first_tuning_round, num_nodes) + new(edge_set, mean, standard_deviation, first_tuning_round) end end -dim(variational::TreeReference) = variational.num_nodes +dim(variational::TreeReference) = length(variational.mean) function activate_variational(variational::TreeReference, iterators::Iterators) iterators.round ≥ variational.first_tuning_round ? true : false end @@ -36,17 +39,40 @@ end function compute_mutual_info() end -#TODO -function prims_spanning_tree() + +function directed_max_tree(adjacency_list, root) + total_number_of_nodes = length(keys(adjacency_list)) + mst = Vector{Tuple{Symbol, Symbol}}() + pq = BinaryMaxHeap{Tuple{Float64, Symbol, Symbol}}() + visited_nodes = Set{Symbol}() + push!(visited_nodes, root) + + for edge in adjacency_list[root] + push!(pq, edge) + end + + while !isempty(pq) && length(mst) Date: Thu, 3 Jul 2025 17:11:06 -0700 Subject: [PATCH 06/27] added update_ref! mechanism, light refactoring --- src/variational/TreeReference.jl | 85 +++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 923901433..91f1d91cb 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -7,16 +7,15 @@ Pkg.add("DataStructures") using DataStructures @kwdef mutable struct TreeReference - edge_set::Vector{Tuple{Float64, Symbol, Symbol}} = Vector{Tuple{Float64, Symbol, Symbol}}() - mean::Dict{Symbol, Any} = Dict{Symbol, Any}() - standard_deviation::Dict{Symbol, Any} = Dict{Symbol, Any}() + edge_set::Vector{Any} = Vector{Any}() + mean::Dict{Tuple{Symbol, Vector{Any}}} = Dict{Symbol, Vector{Any}}() + standard_deviation::Dict{Symbol, Vector{Any}} = Dict{Symbol, Vector{Any}}() + which_variable::Vector{Symbol} first_tuning_round::Int = 6 - function TreeReference(edge_set, mean, standard_deviation, first_tuning_round) - dim = length(mean) - @assert length(edge_set)==dim-1 || length(edge_set)==(dim*(dim-1))/2 + function TreeReference(edge_set, mean, standard_deviation, which_variable, first_tuning_round) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, first_tuning_round) + new(edge_set, mean, standard_deviation, which_variable, first_tuning_round) end end @@ -31,34 +30,62 @@ variational_recorder_builders(::TreeReference) = [_transformed_online] #TODO function update_reference!(reduced_recorders, variational::TreeReference, state) + isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + + for var_name in continuous_variables(state) + temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) + temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) + + dimension = length(temp_mean) + for i = 1:dimension + push!(variational.mean, temp_mean[i]) + push!(variational.standard_deviation, temp_std[i]) + end + end + @assert length(variational.mean) == length(variational.standard_deviation) + + adjacency_list::Dict{Symbol, Vector{Any}} = Dict{Symbol, Vector{Any}}() + + total_number_of_nodes = length(variational.mean) + for i = 1:total_number_of_nodes + for j = 1:total_number_of_nodes + if i != j + I = compute_mutual_info(i, j) + + push!(adjacency_list[i], (I, i, j)) + push!(adjacency_list[j], (I, i, j)) + end + end + end + root = 1 + variational.edge_set = directed_max_tree(adjacency_list, root) end -#TODO -function tree_decomposition() -end -#TODO -function compute_mutual_info() + + +function compute_mutual_info(i, j) + return -0.5*log(1-get_rho(i, j)^2) end function directed_max_tree(adjacency_list, root) total_number_of_nodes = length(keys(adjacency_list)) - mst = Vector{Tuple{Symbol, Symbol}}() - pq = BinaryMaxHeap{Tuple{Float64, Symbol, Symbol}}() - visited_nodes = Set{Symbol}() + mst = Vector{Tuple{Int, Int}}() + pq = BinaryMaxHeap{Tuple{Float64, Int, Int}}() + visited_nodes = Set{Int}() + push!(visited_nodes, root) - for edge in adjacency_list[root] push!(pq, edge) end while !isempty(pq) && length(mst) Date: Thu, 3 Jul 2025 17:22:34 -0700 Subject: [PATCH 07/27] added iid reference sampling mechanism --- src/variational/TreeReference.jl | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 91f1d91cb..8dcdac0e4 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -11,11 +11,12 @@ using DataStructures mean::Dict{Tuple{Symbol, Vector{Any}}} = Dict{Symbol, Vector{Any}}() standard_deviation::Dict{Symbol, Vector{Any}} = Dict{Symbol, Vector{Any}}() which_variable::Vector{Symbol} + which_index::Vector{Int} first_tuning_round::Int = 6 - function TreeReference(edge_set, mean, standard_deviation, which_variable, first_tuning_round) + function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, first_tuning_round) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, which_variable, first_tuning_round) + new(edge_set, mean, standard_deviation, which_variable, which_index, first_tuning_round) end end @@ -28,7 +29,6 @@ end variational_recorder_builders(::TreeReference) = [_transformed_online] -#TODO function update_reference!(reduced_recorders, variational::TreeReference, state) isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") @@ -40,6 +40,9 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) for i = 1:dimension push!(variational.mean, temp_mean[i]) push!(variational.standard_deviation, temp_std[i]) + + push!(variational.which_variable, var_name) + push!(variational.which_index, i) end end @assert length(variational.mean) == length(variational.standard_deviation) @@ -53,7 +56,7 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) I = compute_mutual_info(i, j) push!(adjacency_list[i], (I, i, j)) - push!(adjacency_list[j], (I, i, j)) + push!(adjacency_list[j], (I, j, i)) end end end @@ -98,7 +101,18 @@ end function sample_iid!(variational::TreeReference, replica, shared) + new_state::Vector{Int} = Vector{Int}() + + marginal_val = randn(replica.rng) * variational.standard_deviation[1] + variational.mean[1] + push!(new_state, marginal_val) + update_state!(replica.state, which_variable[1], 1, marginal_val) + for edge in variational.edge_set + params = tree_logdensity(variational, which_variable[edge[3]], which_variable[edge[2]], new_state[edge[2]]) + val = rand(replica.rng) * params[2] + params[1] + + update_state!(replica.state, which_variable[edge[3]], which_index[edge[3]], val) + end end From e612b3a3e08c7fb1cfb64e4a7fd696b82c0d017d Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 4 Jul 2025 15:57:17 -0700 Subject: [PATCH 08/27] added partial singleton ad implementation --- src/variational/TreeReference.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 8dcdac0e4..f21d78883 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -125,7 +125,7 @@ function (variational::TreeReference)(state) marginal_standard_deviation = variational.standard_deviation[marginal_var_name] log_pdf += logpdf(Normal(marginal_mean, marginal_standard_deviation), marginal_state) - for edge in variational.edge_set + for edge in variational.edge_set parent_var_name = which_variable[edge[1]] child_var_name = which_variable[edge[2]] @@ -155,18 +155,27 @@ end #TODO function get_rho(var_name1, var_name2) + return 0 end # LogDensityProblemsAD implementation (currently only for special case of a singleton variable) #TODO -LogDensityProblems.logdensity(log_potential::TreeReference, x) -#TODO +LogDensityProblems.logdensity(log_potential::TreeReference, x) = 0 + function LogDensityProblems.dimension(log_potential::TreeReference) + @assert length(log_potential.mean) == 1 && haskey(log_potential.mean, :singleton_variable) "Differentiation of TreeReference assuming a single flat vector called :singleton_variable at the moment. Found: $(keys(log_potential.mean))" end -#TODO -LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) -#TODO + +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) = + BufferedAD(log_potential, replica.recorders.buffers) + function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) + variational = log_potential.enclosed + buffer = log_potential.buffer + mean = variational.mean[:singleton_variable] + standard_deviation = variational.standard_deviation[:singleton_variable] + @. buffer = - 1.0/(standard_deviation^2) * (x - mean) + return LogDensityProblems.logdensity(variational, x), buffer end \ No newline at end of file From 8794ac378accb457aafab3bfd37ae0989df4695e Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 4 Jul 2025 22:59:48 -0700 Subject: [PATCH 09/27] integrated treeref into codebase dependencies --- Project.toml | 2 ++ src/Pigeons.jl | 3 ++- src/includes.jl | 1 + src/variational/TreeReference.jl | 4 ---- src/variational/VariationalReference.jl | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 2fb0b1c7a..186039852 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.4.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -66,6 +67,7 @@ AbstractPPL = "0.8.4, 0.9, 0.10, 0.11" Bijectors = "0.13, 0.14, 0.15" BridgeStan = "2" DataFrames = "1" +DataStructures = "0.18.22" DifferentiationInterface = "0.6.48" Distributions = "0.25" DocStringExtensions = "0.9" diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 5df5da48a..0194c5b08 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -38,6 +38,7 @@ using StaticArraysCore using Statistics using StatsBase using ZipFile +using DataStructures import Base: Forward, @kwdef, show, print, merge, keys import Base.Threads.@threads @@ -68,7 +69,7 @@ export pigeons, Inputs, PT, # getting information out of an execution: stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample, # variational references: - GaussianReference, + GaussianReference, TreeReference, # samplers SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix diff --git a/src/includes.jl b/src/includes.jl index 5600eb669..f5dc5db0f 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -79,6 +79,7 @@ include("targets/DistributionLogPotential.jl") include("pt/checks.jl") include("explorers/BufferedAD.jl") include("variational/GaussianReference.jl") +include("variational/TreeReference.jl") include("variational/VariationalReference.jl") include("paths/ScaledPrecisionNormalPath.jl") include("explorers/invariance_test.jl") diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index f21d78883..fd382547e 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -2,10 +2,6 @@ A Gaussian tree variational reference """ -import Pkg -Pkg.add("DataStructures") -using DataStructures - @kwdef mutable struct TreeReference edge_set::Vector{Any} = Vector{Any}() mean::Dict{Tuple{Symbol, Vector{Any}}} = Dict{Symbol, Vector{Any}}() diff --git a/src/variational/VariationalReference.jl b/src/variational/VariationalReference.jl index b95f130d4..20d8d0b0d 100644 --- a/src/variational/VariationalReference.jl +++ b/src/variational/VariationalReference.jl @@ -3,7 +3,7 @@ Methods common to all variational references =# # Currently implemented variational references -const VariationalReference = Union{GaussianReference} +const VariationalReference = Union{GaussianReference, TreeReference} # Elide the AD buffering system # Reasoning: From 5bd273881aae9d0bc4127ffdf30aa6745301d1f4 Mon Sep 17 00:00:00 2001 From: Jack Date: Sun, 6 Jul 2025 12:33:57 -0700 Subject: [PATCH 10/27] fixed compilation bugs --- src/variational/TreeReference.jl | 84 +++++++++++++++++--------------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index fd382547e..6cae0bc6d 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -4,21 +4,22 @@ A Gaussian tree variational reference @kwdef mutable struct TreeReference edge_set::Vector{Any} = Vector{Any}() - mean::Dict{Tuple{Symbol, Vector{Any}}} = Dict{Symbol, Vector{Any}}() - standard_deviation::Dict{Symbol, Vector{Any}} = Dict{Symbol, Vector{Any}}() - which_variable::Vector{Symbol} - which_index::Vector{Int} + mean::Vector{Any} = Vector{Any}() + standard_deviation::Vector{Any} = Vector{Any}() + which_variable::Vector{Any} = Vector{Any}() + which_index::Vector{Int} = Vector{Int}() + iid_sample_set::Vector{Any} = Vector{Int}() first_tuning_round::Int = 6 - function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, first_tuning_round) + function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, which_variable, which_index, first_tuning_round) + new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round) end end dim(variational::TreeReference) = length(variational.mean) -function activate_variational(variational::TreeReference, iterators::Iterators) +function activate_variational(variational::TreeReference, iterators) iterators.round ≥ variational.first_tuning_round ? true : false end @@ -27,7 +28,7 @@ variational_recorder_builders(::TreeReference) = [_transformed_online] function update_reference!(reduced_recorders, variational::TreeReference, state) isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") - + for var_name in continuous_variables(state) temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) @@ -42,10 +43,15 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) end end @assert length(variational.mean) == length(variational.standard_deviation) + total_number_of_nodes = length(variational.mean) - adjacency_list::Dict{Symbol, Vector{Any}} = Dict{Symbol, Vector{Any}}() + variational.iid_sample_set = zeros(total_number_of_nodes) + + adjacency_list::Dict{Any, Any} = Dict{Any, Vector{Any}}() + for i in 1:total_number_of_nodes + adjacency_list[i] = Vector{Any}() + end - total_number_of_nodes = length(variational.mean) for i = 1:total_number_of_nodes for j = 1:total_number_of_nodes if i != j @@ -68,9 +74,9 @@ end function directed_max_tree(adjacency_list, root) total_number_of_nodes = length(keys(adjacency_list)) - mst = Vector{Tuple{Int, Int}}() - pq = BinaryMaxHeap{Tuple{Float64, Int, Int}}() - visited_nodes = Set{Int}() + mst = Vector{Any}() + pq = BinaryMaxHeap{Any}() + visited_nodes = Set{Any}() push!(visited_nodes, root) for edge in adjacency_list[root] @@ -82,7 +88,7 @@ function directed_max_tree(adjacency_list, root) if !(popped[3] in visited_nodes) push!(visited_nodes, popped[3]) - push!(mst, (popped[2], popped[3])) + push!(mst, popped) for new_edge in adjacency_list[popped[3]] if !(new_edge[3] in visited_nodes) @@ -97,17 +103,16 @@ end function sample_iid!(variational::TreeReference, replica, shared) - new_state::Vector{Int} = Vector{Int}() - marginal_val = randn(replica.rng) * variational.standard_deviation[1] + variational.mean[1] - push!(new_state, marginal_val) - update_state!(replica.state, which_variable[1], 1, marginal_val) + variational.iid_sample_set[1] = marginal_val + update_state!(replica.state, variational.which_variable[1], 1, marginal_val) for edge in variational.edge_set - params = tree_logdensity(variational, which_variable[edge[3]], which_variable[edge[2]], new_state[edge[2]]) + params = tree_logdensity(variational, edge[3], edge[2], variational.iid_sample_set[edge[2]]) #TODO val = rand(replica.rng) * params[2] + params[1] + variational.iid_sample_set[edge[3]] = val - update_state!(replica.state, which_variable[edge[3]], which_index[edge[3]], val) + update_state!(replica.state, variational.which_variable[edge[3]], variational.which_index[edge[3]], val) end end @@ -115,42 +120,41 @@ end function (variational::TreeReference)(state) log_pdf = 0.0 - marginal_var_name = continuous_variables(state)[1] - marginal_state = variable(state, marginal_var_name) - marginal_mean = variational.mean[marginal_var_name] - marginal_standard_deviation = variational.standard_deviation[marginal_var_name] - log_pdf += logpdf(Normal(marginal_mean, marginal_standard_deviation), marginal_state) + marginal_state = variable(state, variational.which_variable[1])[1] + marginal_mean = variational.mean[1] + marginal_standard_deviation = variational.standard_deviation[1] + log_pdf += logpdf(Normal(marginal_mean, marginal_standard_deviation), marginal_state)[1] for edge in variational.edge_set - parent_var_name = which_variable[edge[1]] - child_var_name = which_variable[edge[2]] + parent_var_name = variational.which_variable[edge[2]] + child_var_name = variational.which_variable[edge[3]] - state_at_parent = variable(state, parent_var_name) - state_at_child = variable(state, child_var_name) + state_at_parent = variable(state, parent_var_name)[variational.which_index[edge[2]]] + state_at_child = variable(state, child_var_name)[variational.which_index[edge[3]]] - cond_params = tree_logdensity(variational, child_var_name, parent_var_name, state_at_parent) + cond_params = tree_logdensity(variational, edge[3], edge[2], state_at_parent) log_pdf += logpdf(Normal(cond_params[1], cond_params[2]), state_at_child) end return log_pdf end -function tree_logdensity(variational::TreeReference, child_var_name, parent_var_name, state_at_parent) - child_mean = variational.mean[child_var_name] - parent_mean = variational.mean[parent_var_name] - child_standard_deviation = variational.standard_deviation[child_var_name] - parent_standard_deviation = variational.standard_deviation[child_var_name] +function tree_logdensity(variational::TreeReference, child_num, parent_num, state_at_parent) + child_mean = variational.mean[child_num] + parent_mean = variational.mean[parent_num] + child_standard_deviation = variational.standard_deviation[child_num] + parent_standard_deviation = variational.standard_deviation[parent_num] - rho = get_rho(parent_var_name, child_var_name) + rho = get_rho(parent_num, child_num) - new_mu = child_mean + rho * (child_standard_deviation / parent_standard_deviation) * (state_at_parent - parent_mean) + new_mu = child_mean + rho * (child_standard_deviation / parent_standard_deviation) * (state_at_parent .- parent_mean) new_sigma = sqrt((1-rho^2) * (child_standard_deviation)^2) return (new_mu, new_sigma) end #TODO -function get_rho(var_name1, var_name2) +function get_rho(parent_num, child_num) return 0 end @@ -160,13 +164,15 @@ end #TODO LogDensityProblems.logdensity(log_potential::TreeReference, x) = 0 +#TODO function LogDensityProblems.dimension(log_potential::TreeReference) @assert length(log_potential.mean) == 1 && haskey(log_potential.mean, :singleton_variable) "Differentiation of TreeReference assuming a single flat vector called :singleton_variable at the moment. Found: $(keys(log_potential.mean))" end -LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType. log_potential::TreeReference, replica::Replica) = +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType, log_potential::TreeReference, replica::Replica) = BufferedAD(log_potential, replica.recorders.buffers) +#TODO function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) variational = log_potential.enclosed buffer = log_potential.buffer From e2f764eb6e79ec4d9def6b7f046109642a949309 Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 11 Jul 2025 16:14:46 -0700 Subject: [PATCH 11/27] refactored for dense/diag dispatch --- src/recorders/OnlineStateRecorder.jl | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/recorders/OnlineStateRecorder.jl b/src/recorders/OnlineStateRecorder.jl index e7a1a7a57..f1f3dc406 100644 --- a/src/recorders/OnlineStateRecorder.jl +++ b/src/recorders/OnlineStateRecorder.jl @@ -3,6 +3,7 @@ See [`online()`](@ref). """ @kwdef struct OnlineStateRecorder stats::Dict{Pair{Symbol, Type}, Any} = Dict{Pair{Symbol, Type}, Any}() + is_full::Bool = false end OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(copy(from_another.stats)) @@ -84,7 +85,23 @@ function register_online_type(type) end end -function record!(recorder::OnlineStateRecorder, state) +record!(recorder::OnlineStateRecorder, state) = + record!(recorder, state, Val(recorder.is_full)) + +#TODO +function record!(recorder::OnlineStateRecorder, state, is_full::Val{true}) + if isempty(recorder.stats) + initialize_online_state_recorder!(recorder.stats, state) + end + for name in recorded_continuous_variables(state) + for stat in registered_online_types # NB: the more natural "for key in keys(recorder.stats)" leads to allocations in the inner loop + key = Pair(name, stat) + fit!(recorder.stats[key], variable(state, name)) + end + end +end + +function record!(recorder::OnlineStateRecorder, state, is_full::Val{false}) if isempty(recorder.stats) initialize_online_state_recorder!(recorder.stats, state) end From a73b5dfc717646afa6dd487529bdad319bfe34b0 Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 26 Jul 2025 19:46:39 -0700 Subject: [PATCH 12/27] added covariance mechanism, improved code efficiency --- src/variational/TreeReference.jl | 129 ++++++++++++++----------------- 1 file changed, 56 insertions(+), 73 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 6cae0bc6d..93133fcc8 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -3,17 +3,18 @@ A Gaussian tree variational reference """ @kwdef mutable struct TreeReference - edge_set::Vector{Any} = Vector{Any}() - mean::Vector{Any} = Vector{Any}() - standard_deviation::Vector{Any} = Vector{Any}() - which_variable::Vector{Any} = Vector{Any}() + edge_set::Vector{Tuple{Float64, Float64, Int, Int}} = Vector{Tuple{Float64, Float64, Int, Int}}() + mean::Vector{Float64} = Vector{Float64}() + standard_deviation::Vector{Float64} = Vector{Float64}() + which_variable::Vector{Symbol} = Vector{Symbol}() which_index::Vector{Int} = Vector{Int}() - iid_sample_set::Vector{Any} = Vector{Int}() - first_tuning_round::Int = 6 - - function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round) + iid_sample_set::Vector{Float64} = Vector{Float64}() + first_tuning_round::Int = 11 + covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) + + function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round, covariance_matrix) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round) + new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round, covariance_matrix) end end @@ -23,12 +24,18 @@ function activate_variational(variational::TreeReference, iterators) iterators.round ≥ variational.first_tuning_round ? true : false end -variational_recorder_builders(::TreeReference) = [_transformed_online] +variational_recorder_builders(::TreeReference) = [_transformed_online_full] function update_reference!(reduced_recorders, variational::TreeReference, state) isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + empty!(variational.mean) + empty!(variational.standard_deviation) + empty!(variational.which_variable) + empty!(variational.which_index) + variational.edge_set = [] + for var_name in continuous_variables(state) temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) @@ -47,36 +54,36 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) variational.iid_sample_set = zeros(total_number_of_nodes) - adjacency_list::Dict{Any, Any} = Dict{Any, Vector{Any}}() + adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() for i in 1:total_number_of_nodes - adjacency_list[i] = Vector{Any}() + adjacency_list[i] = Vector{Int}() end - for i = 1:total_number_of_nodes - for j = 1:total_number_of_nodes - if i != j - I = compute_mutual_info(i, j) + variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) - push!(adjacency_list[i], (I, i, j)) - push!(adjacency_list[j], (I, j, i)) - end + for i = 1:total_number_of_nodes + for j = (i+1):total_number_of_nodes + normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) + rho = variational.covariance_matrix[i,j] / normalization + I = -0.5*log(1-rho^2) + + push!(adjacency_list[i], (rho, I, i, j)) + push!(adjacency_list[j], (rho, I, j, i)) end end root = 1 variational.edge_set = directed_max_tree(adjacency_list, root) -end - -function compute_mutual_info(i, j) - return -0.5*log(1-get_rho(i, j)^2) + empty!(adjacency_list) end + function directed_max_tree(adjacency_list, root) total_number_of_nodes = length(keys(adjacency_list)) - mst = Vector{Any}() - pq = BinaryMaxHeap{Any}() - visited_nodes = Set{Any}() + mst = Vector{Tuple{Float64, Float64, Int, Int}}() + pq = BinaryMaxHeap{Tuple{Float64, Float64, Int, Int}}() + visited_nodes = Set{Int}() push!(visited_nodes, root) for edge in adjacency_list[root] @@ -86,12 +93,12 @@ function directed_max_tree(adjacency_list, root) while !isempty(pq) && length(mst) Date: Sat, 26 Jul 2025 19:47:16 -0700 Subject: [PATCH 13/27] migrated dense reference from other branch --- src/variational/DenseGaussianReference.jl | 92 +++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 src/variational/DenseGaussianReference.jl diff --git a/src/variational/DenseGaussianReference.jl b/src/variational/DenseGaussianReference.jl new file mode 100644 index 000000000..562b4bacd --- /dev/null +++ b/src/variational/DenseGaussianReference.jl @@ -0,0 +1,92 @@ +""" +A Gaussian dense variational reference (i.e., with a dense covariance matrix). +""" +@kwdef mutable struct DenseGaussianReference + mean::Vector{Any} = Vector{Any}() + covariance::Matrix{Float64} = zeros(Float64, 0, 0) + precision::Matrix{Float64} = zeros(Float64, 0, 0) + cholesky::Any = zeros(Float64, 0, 0) + which_variable::Vector{Any} = Vector{Any}() + which_index::Vector{Int} = Vector{Int}() + first_tuning_round::Int = 11 # TODO: this should be moved elsewhere? + + function DenseGaussianReference(mean, covariance, precision, cholesky, which_variable, which_index, first_tuning_round) + @assert first_tuning_round ≥ 1 + new(mean, covariance, precision, cholesky, which_variable, which_index, first_tuning_round) + end +end + +dim(variational::DenseGaussianReference) = length(variational.mean) +function activate_variational(variational::DenseGaussianReference, iterators::Iterators) + iterators.round ≥ variational.first_tuning_round ? true : false +end +variational_recorder_builders(::DenseGaussianReference) = [_transformed_online_full] + +function update_reference!(reduced_recorders, variational::DenseGaussianReference, state) + isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + + empty!(variational.which_variable) + empty!(variational.which_index) + empty!(variational.mean) + + variational.covariance = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.precision = inv(variational.covariance) + variational.cholesky = cholesky(variational.covariance).L + variational.mean = get_transformed_statistic(reduced_recorders, :singleton_variable, Mean) + + for var_name in continuous_variables(state) + for i = 1:length(variable(state, var_name)) + push!(variational.which_variable, var_name) + push!(variational.which_index, i) + + end + end + +end + +function sample_iid!(variational::DenseGaussianReference, replica, shared) + dim = length(variational.mean) + identity_gaussian = MvNormal(zeros(dim), I) + z = rand(identity_gaussian) + + sample = variational.mean + variational.cholesky * z + + for i in 1:dim + update_state!(replica.state, variational.which_variable[i], variational.which_index[i], sample[i]) + end +end + +function (variational::DenseGaussianReference)(state) + flattened_state = Vector{Float64}() + + for i in 1:length(variational.mean) + name = variational.which_variable[i] + index = variational.which_index[i] + push!(flattened_state, Pigeons.variable(state, name)[index]) + end + + return -0.5 * (transpose(flattened_state - variational.mean) * variational.precision * (flattened_state - variational.mean)) +end + + + +# LogDensityProblemsAD implementation (currently only for special case of a singleton variable) + +LogDensityProblems.logdensity(log_potential::DenseGaussianReference, x) = + log_potential(x) + +function LogDensityProblems.dimension(log_potential::DenseGaussianReference) + return length(log_potential.mean) +end + +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType, log_potential::DenseGaussianReference, replica::Replica) = + BufferedAD(log_potential, replica.recorders.buffers) + +function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{DenseGaussianReference}, x) + variational = log_potential.enclosed + buffer = log_potential.buffer + mean = variational.mean + precision = variational.precision + @. buffer = -precision * (x - mean) + return LogDensityProblems.logdensity(variational, x), buffer +end \ No newline at end of file From f4f9270a4a3d3607b606d25302bff629d040428f Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 26 Jul 2025 19:49:49 -0700 Subject: [PATCH 14/27] incorporated references, covariance mechanism into dependencies --- src/Pigeons.jl | 2 +- src/explorers/GradientBasedSampler.jl | 2 +- src/includes.jl | 1 + src/pt/pigeons.jl | 2 +- src/recorders/OnlineStateRecorder.jl | 38 +++++++++++-------------- src/recorders/recorder.jl | 3 +- src/variational/VariationalReference.jl | 2 +- 7 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 0194c5b08..2a9e36d57 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -69,7 +69,7 @@ export pigeons, Inputs, PT, # getting information out of an execution: stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample, # variational references: - GaussianReference, TreeReference, + GaussianReference, TreeReference, DenseGaussianReference, # samplers SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix diff --git a/src/explorers/GradientBasedSampler.jl b/src/explorers/GradientBasedSampler.jl index c558b955e..995b1555a 100644 --- a/src/explorers/GradientBasedSampler.jl +++ b/src/explorers/GradientBasedSampler.jl @@ -20,6 +20,6 @@ function gradient_based_sampler_recorders!(recorders, explorer::GradientBasedSam push!(recorders, buffers) push!(recorders, Pigeons.ad_buffers) if hasproperty(explorer, :preconditioner) && explorer.preconditioner isa AdaptedDiagonalPreconditioner - push!(recorders, _transformed_online) # for mass matrix adaptation + push!(recorders, _transformed_online_full) # for mass matrix adaptation end end diff --git a/src/includes.jl b/src/includes.jl index f5dc5db0f..e755b0036 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -80,6 +80,7 @@ include("pt/checks.jl") include("explorers/BufferedAD.jl") include("variational/GaussianReference.jl") include("variational/TreeReference.jl") +include("variational/DenseGaussianReference.jl") include("variational/VariationalReference.jl") include("paths/ScaledPrecisionNormalPath.jl") include("explorers/invariance_test.jl") diff --git a/src/pt/pigeons.jl b/src/pt/pigeons.jl index 483d7fe6a..fd47c281c 100644 --- a/src/pt/pigeons.jl +++ b/src/pt/pigeons.jl @@ -111,7 +111,7 @@ function explore!(pt, replica, explorer) # for the online stats, we ignore pt.inputs.extended_traces # because the recorders do not support grouping by chains @record_if_requested!(replica.recorders, :online, extract_sample(replica.state, log_potential, pt.inputs.extractor)) - @record_if_requested!(replica.recorders, :_transformed_online, replica.state) + @record_if_requested!(replica.recorders, :_transformed_online_full, replica.state) end if pt.inputs.extended_traces || is_target(pt.shared.tempering.swap_graphs, replica.chain) @record_if_requested!( diff --git a/src/recorders/OnlineStateRecorder.jl b/src/recorders/OnlineStateRecorder.jl index f1f3dc406..6fd9ff59d 100644 --- a/src/recorders/OnlineStateRecorder.jl +++ b/src/recorders/OnlineStateRecorder.jl @@ -2,11 +2,11 @@ See [`online()`](@ref). """ @kwdef struct OnlineStateRecorder + full::Bool = false stats::Dict{Pair{Symbol, Type}, Any} = Dict{Pair{Symbol, Type}, Any}() - is_full::Bool = false end -OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(copy(from_another.stats)) +OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(from_another.full, copy(from_another.stats)) """ $SIGNATURES @@ -31,9 +31,12 @@ get_transformed_statistic(reduced_recorders, variable_name::Symbol, t::Type{T}) get_statistic(reduced_recorders, variable_name, t, false) function get_statistic(reduced_recorders, variable_name::Symbol, ::Type{T}, original_param = true) where {T} - recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online + recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online_full key = Pair(variable_name, T) - v = value(recorder.stats[key]) + v = value(recorder.stats[key]) + if T==CovMatrix + return v + end return value.(v) end @@ -69,7 +72,7 @@ recorded_continuous_variables(state) = continuous_variables(state) `OnlineStat` types to be computed when the [`online()`] recorder is enabled. """ -const registered_online_types = [Mean, Variance] +const registered_online_types = [Mean, Variance, CovMatrix] """ $SIGNATURES @@ -85,23 +88,7 @@ function register_online_type(type) end end -record!(recorder::OnlineStateRecorder, state) = - record!(recorder, state, Val(recorder.is_full)) - -#TODO -function record!(recorder::OnlineStateRecorder, state, is_full::Val{true}) - if isempty(recorder.stats) - initialize_online_state_recorder!(recorder.stats, state) - end - for name in recorded_continuous_variables(state) - for stat in registered_online_types # NB: the more natural "for key in keys(recorder.stats)" leads to allocations in the inner loop - key = Pair(name, stat) - fit!(recorder.stats[key], variable(state, name)) - end - end -end - -function record!(recorder::OnlineStateRecorder, state, is_full::Val{false}) +function record!(recorder::OnlineStateRecorder, state) if isempty(recorder.stats) initialize_online_state_recorder!(recorder.stats, state) end @@ -118,6 +105,13 @@ initialize_online_state_recorder!(stats, state) = initialize_online_state_recorder!(stats, state, stat_type) end +initialize_online_state_recorder!(stats, state, ::Type{CovMatrix}) = + for name in recorded_continuous_variables(state) + var = variable(state, name) + key = Pair(name, CovMatrix) + stats[key] = CovMatrix() + end + initialize_online_state_recorder!(stats, state, ::Type{T}) where {T} = for name in recorded_continuous_variables(state) var = variable(state, name) diff --git a/src/recorders/recorder.jl b/src/recorders/recorder.jl index e090d13f7..996087382 100644 --- a/src/recorders/recorder.jl +++ b/src/recorders/recorder.jl @@ -99,7 +99,8 @@ transformed to be defined on an unconstrained space. This is used internally by [`explorer`](@ref)'s for adaptation purposes (in particular, pre-conditioning and variational references). """ -@provides recorder _transformed_online() = OnlineStateRecorder() +@provides recorder _transformed_online() = OnlineStateRecorder() +@provides recorder _transformed_online_full() = OnlineStateRecorder(full=true) """ Restart and round-trip counts. diff --git a/src/variational/VariationalReference.jl b/src/variational/VariationalReference.jl index 20d8d0b0d..5f2e8de46 100644 --- a/src/variational/VariationalReference.jl +++ b/src/variational/VariationalReference.jl @@ -3,7 +3,7 @@ Methods common to all variational references =# # Currently implemented variational references -const VariationalReference = Union{GaussianReference, TreeReference} +const VariationalReference = Union{GaussianReference, TreeReference, DenseGaussianReference} # Elide the AD buffering system # Reasoning: From d1177aee7876bed1ae9346eab82486a2a7c2551a Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 26 Jul 2025 19:51:07 -0700 Subject: [PATCH 15/27] added HMM problem for testing new references --- examples/stan/1000-hmm.json | 253 ++++++++++++++++++++++++++++++++++++ examples/stan/1000-hmm.stan | 25 ++++ 2 files changed, 278 insertions(+) create mode 100644 examples/stan/1000-hmm.json create mode 100644 examples/stan/1000-hmm.stan diff --git a/examples/stan/1000-hmm.json b/examples/stan/1000-hmm.json new file mode 100644 index 000000000..fa8206b92 --- /dev/null +++ b/examples/stan/1000-hmm.json @@ -0,0 +1,253 @@ +{ + "N": 1000, + "observations": [2.20041627e+00, 1.95769524e+00, -8.52087759e-01, -3.49801576e+00, + -2.75466819e+00, -7.84143831e+00, -7.00086086e+00, -1.00700985e+01, + -1.12625951e+01, -1.14712207e+01, -7.63410793e+00, -6.82938344e+00, + -8.94909208e+00, -1.05940554e+01, -1.01677096e+01, -1.07991865e+01, + -1.21380587e+01, -1.26044233e+01, -9.45156068e+00, -1.01137595e+01, + -1.20918822e+01, -8.58104339e+00, -9.46199003e+00, -6.16620129e+00, + -4.99686144e+00, -7.00417901e+00, -6.63052420e+00, -4.64233530e+00, + -3.70324733e+00, -3.32442030e+00, -4.88425406e-02, -5.51073794e-01, + -2.90981188e-02, 3.46099579e+00, -2.40716819e+00, -1.61163678e+00, + -5.16084106e-02, -3.91487830e+00, -2.05140073e+00, -1.28237909e+00, + -1.19528909e+00, -5.26269931e-01, -1.20447035e+00, 1.93218523e-01, + 3.08078980e+00, 1.80337827e+00, 1.57832252e+00, 9.34944592e-01, + 7.28679787e-01, 2.02044301e+00, 4.83385362e+00, 4.28691024e+00, + 6.39388537e+00, 3.85840269e+00, 2.80378245e+00, 5.52823636e+00, + 3.65488463e+00, 1.16215512e+00, -2.67688803e-01, -7.58287776e-01, + -7.02438545e-01, -6.81065563e-01, 4.78604164e+00, 2.50092926e+00, + 4.90315224e+00, 2.88597671e+00, 5.79131041e+00, 4.51960872e+00, + 2.96660337e+00, 2.05281631e+00, 2.94518154e+00, 3.80861207e+00, + 3.57769218e+00, 4.72358097e+00, 5.27504634e+00, 2.72113257e+00, + 5.62923032e+00, 4.89088010e+00, 4.85389967e+00, 2.52088724e+00, + 1.68862456e+00, 1.45954408e+00, 1.67063240e+00, 3.21727514e+00, + -2.66006003e+00, 1.96019559e+00, -9.23003292e-01, 2.02073660e-03, + -1.10989865e+00, -3.57514428e+00, -1.29680815e+00, 1.62500296e-01, + -1.53874849e-01, -5.00880738e-02, 7.12391454e-01, 2.04897962e+00, + 2.58430799e+00, -8.10793344e-02, 8.01232671e-01, 1.26356208e+00, + 2.42295161e+00, 4.78682256e+00, 1.64580644e+00, 7.49921116e-01, + 4.67797903e+00, 3.96665687e+00, 6.75268362e-01, 3.93938545e+00, + 4.02636161e+00, 4.44271098e+00, 5.19042137e+00, 2.58723543e+00, + 7.18735268e+00, 6.71400675e+00, 7.33104925e+00, 8.07291814e+00, + 1.01641232e+01, 7.12567322e+00, 5.56664294e+00, 8.06508369e+00, + 8.06292109e+00, 4.96482253e+00, 7.52618800e+00, 8.00781373e+00, + 3.75931312e+00, 9.41197214e-01, 7.20632190e+00, 7.21676797e+00, + 6.53389120e+00, 5.23761666e+00, 3.48087341e+00, 2.21970007e+00, + 4.29720023e+00, 4.70966628e+00, 4.09379447e+00, 6.49047977e+00, + 6.27283245e+00, 4.31824266e+00, 7.48014444e+00, 9.19625336e+00, + 9.37229733e+00, 1.08146351e+01, 1.11088290e+01, 1.02598495e+01, + 1.13194604e+01, 1.19072907e+01, 1.02393721e+01, 1.02645342e+01, + 1.01497331e+01, 1.41101105e+01, 8.85600537e+00, 7.99558228e+00, + 1.08893333e+01, 1.03842444e+01, 9.92441316e+00, 1.06790997e+01, + 1.12610941e+01, 1.02065620e+01, 8.74433367e+00, 7.94298791e+00, + 7.08367963e+00, 6.09791682e+00, 6.30358218e+00, 8.68326355e+00, + 1.12112083e+01, 1.15750269e+01, 1.38522936e+01, 9.37299819e+00, + 5.78993972e+00, 4.49590123e+00, 6.51469197e+00, 7.36551889e+00, + 7.32101484e+00, 7.67118026e+00, 6.33513478e+00, 6.60466730e+00, + 7.38936309e+00, 7.38584474e+00, 6.47593338e+00, 5.28929490e+00, + 5.40888051e+00, 7.68078795e+00, 9.59892683e+00, 5.20831396e+00, + 5.16531652e+00, 5.24548731e+00, 5.56486992e+00, 6.77264209e+00, + 1.61973862e+00, 5.37412033e+00, 3.03930250e+00, 4.47814021e+00, + 3.44894920e+00, 2.34977267e+00, 1.00274187e+00, 4.78782641e+00, + 5.00688524e+00, 5.42917704e+00, 8.74264437e+00, 1.08546136e+01, + 1.26508115e+01, 1.38369640e+01, 1.22559499e+01, 1.35193366e+01, + 1.45558898e+01, 1.98940529e+01, 1.56230795e+01, 1.44094233e+01, + 1.72861607e+01, 1.33902066e+01, 1.58302603e+01, 1.70575462e+01, + 1.46093763e+01, 1.56669888e+01, 1.61703988e+01, 1.75911871e+01, + 1.77323057e+01, 1.69723792e+01, 1.89714998e+01, 2.14827767e+01, + 1.71719057e+01, 2.01649086e+01, 2.02643282e+01, 2.23352866e+01, + 2.45251186e+01, 2.75240294e+01, 2.67906918e+01, 2.34239784e+01, + 2.47779506e+01, 2.41075077e+01, 2.38932096e+01, 2.37542881e+01, + 2.41000283e+01, 2.73505371e+01, 2.64419694e+01, 2.52751711e+01, + 2.80470395e+01, 2.57457760e+01, 2.59011799e+01, 2.48515667e+01, + 2.17817917e+01, 2.50912782e+01, 2.57779044e+01, 2.83524943e+01, + 2.57348726e+01, 2.64355250e+01, 2.59956499e+01, 2.67754598e+01, + 2.15529923e+01, 2.09129092e+01, 2.19243167e+01, 2.19974779e+01, + 2.19960943e+01, 2.21024865e+01, 2.52887950e+01, 2.40804220e+01, + 2.27381049e+01, 2.00207687e+01, 2.19041866e+01, 2.49407407e+01, + 2.27611733e+01, 2.64211439e+01, 2.79037728e+01, 2.62650187e+01, + 2.95082709e+01, 2.63434330e+01, 2.69195724e+01, 2.72645202e+01, + 2.42261173e+01, 2.61695505e+01, 2.38557180e+01, 2.48194219e+01, + 2.40467891e+01, 2.76022611e+01, 2.42635759e+01, 2.30312713e+01, + 1.95100536e+01, 2.10866278e+01, 2.27955861e+01, 2.51106600e+01, + 2.54197451e+01, 2.32745902e+01, 3.19060151e+01, 3.00306215e+01, + 3.11005590e+01, 3.29650208e+01, 3.57058316e+01, 3.67591464e+01, + 3.59652717e+01, 3.86935166e+01, 3.57911075e+01, 3.79579306e+01, + 3.84987982e+01, 3.56516264e+01, 3.08808791e+01, 3.14919873e+01, + 3.37580507e+01, 3.67153048e+01, 3.71685145e+01, 3.87817383e+01, + 3.96963411e+01, 3.83564527e+01, 3.93413077e+01, 3.75321922e+01, + 4.15746594e+01, 4.36385157e+01, 3.98425106e+01, 4.15453595e+01, + 3.94474951e+01, 3.80429200e+01, 3.58932846e+01, 3.55484494e+01, + 3.57650208e+01, 3.55469244e+01, 3.13031299e+01, 3.35075406e+01, + 3.30004314e+01, 3.33909262e+01, 3.38864614e+01, 3.51344117e+01, + 3.47694398e+01, 3.61716779e+01, 3.64870273e+01, 3.60862257e+01, + 3.77519484e+01, 3.93528421e+01, 3.94286091e+01, 4.03535430e+01, + 3.91922900e+01, 3.71025976e+01, 3.22791905e+01, 3.18571947e+01, + 3.00682943e+01, 3.15075832e+01, 3.11689635e+01, 3.36183250e+01, + 3.36876822e+01, 3.54169048e+01, 3.52179400e+01, 3.62088136e+01, + 3.30185438e+01, 4.01510887e+01, 3.64507373e+01, 3.77575895e+01, + 3.63044233e+01, 3.64477644e+01, 3.68859079e+01, 3.79617659e+01, + 3.77345991e+01, 3.73876702e+01, 4.00806336e+01, 4.19850377e+01, + 4.11555981e+01, 4.42295454e+01, 4.16191062e+01, 3.84851625e+01, + 3.89853669e+01, 4.07976633e+01, 4.27051670e+01, 4.00062468e+01, + 4.30978454e+01, 4.23601404e+01, 4.41922518e+01, 4.39965820e+01, + 4.67794112e+01, 5.02978572e+01, 4.56351465e+01, 4.48239990e+01, + 4.05876623e+01, 4.14133895e+01, 4.14028065e+01, 4.18386992e+01, + 4.13414260e+01, 4.11379412e+01, 4.11907486e+01, 4.09471608e+01, + 3.95399716e+01, 3.69100236e+01, 3.66860475e+01, 3.79205085e+01, + 3.34496798e+01, 3.78740709e+01, 3.47468014e+01, 3.24855140e+01, + 3.40201234e+01, 3.61644552e+01, 3.63968850e+01, 3.42714727e+01, + 3.41741941e+01, 3.11705640e+01, 2.83597819e+01, 3.07470347e+01, + 2.56121553e+01, 2.59008892e+01, 2.57916997e+01, 2.88839503e+01, + 2.96389389e+01, 2.76744072e+01, 2.61646143e+01, 2.57337329e+01, + 2.47063686e+01, 2.58303525e+01, 2.47858266e+01, 2.87182967e+01, + 3.04326208e+01, 2.94320397e+01, 3.07871136e+01, 2.99429604e+01, + 2.81048675e+01, 2.87370975e+01, 2.96210477e+01, 2.75839731e+01, + 2.91019014e+01, 2.92451746e+01, 2.84477817e+01, 2.91807574e+01, + 2.92471887e+01, 3.04269667e+01, 3.15569941e+01, 3.38214070e+01, + 3.40074353e+01, 3.16579505e+01, 3.38981861e+01, 3.28793324e+01, + 3.52557179e+01, 3.39301843e+01, 3.63010869e+01, 4.05274543e+01, + 4.07032346e+01, 3.85513408e+01, 4.02953088e+01, 3.70458613e+01, + 3.72419718e+01, 3.70039193e+01, 3.77764164e+01, 3.56303451e+01, + 3.62959585e+01, 3.78946072e+01, 3.62709061e+01, 3.76634104e+01, + 3.91337749e+01, 3.84004126e+01, 3.81241726e+01, 4.15581530e+01, + 4.01142431e+01, 3.63541133e+01, 3.86243449e+01, 4.01430134e+01, + 3.92610382e+01, 3.59249212e+01, 3.73114540e+01, 3.78569320e+01, + 3.84827639e+01, 3.73911966e+01, 3.60235311e+01, 3.55317825e+01, + 3.79022022e+01, 3.80714924e+01, 3.83597586e+01, 3.78601350e+01, + 3.73587552e+01, 3.60520322e+01, 3.63999056e+01, 3.57917551e+01, + 3.61548360e+01, 3.50228969e+01, 3.56458695e+01, 3.42961012e+01, + 3.58128162e+01, 3.52237851e+01, 3.47537293e+01, 3.68640322e+01, + 3.41322362e+01, 3.66964622e+01, 3.34686593e+01, 3.44682303e+01, + 3.45441304e+01, 3.58217338e+01, 3.58910303e+01, 3.73505089e+01, + 3.77782400e+01, 3.46415997e+01, 3.58564269e+01, 3.53247241e+01, + 3.36910112e+01, 3.27083918e+01, 3.13210632e+01, 2.77079547e+01, + 3.00823366e+01, 2.99031872e+01, 2.79379527e+01, 2.74939928e+01, + 2.75225223e+01, 2.82877226e+01, 2.83256334e+01, 2.83580797e+01, + 3.04150303e+01, 2.66598925e+01, 2.91794583e+01, 2.83726774e+01, + 2.51347763e+01, 2.25963168e+01, 2.18457418e+01, 2.42479872e+01, + 2.67623007e+01, 2.25385780e+01, 2.73269381e+01, 2.59379632e+01, + 2.10230150e+01, 2.07278808e+01, 2.00192286e+01, 2.01415859e+01, + 2.17867287e+01, 2.43700543e+01, 2.45504939e+01, 2.47710967e+01, + 2.38890926e+01, 2.54430988e+01, 2.19310121e+01, 2.01861465e+01, + 2.38134047e+01, 2.19633916e+01, 2.16153891e+01, 2.45069550e+01, + 2.04463694e+01, 1.90655406e+01, 2.09490496e+01, 2.29965558e+01, + 2.57579529e+01, 2.54823802e+01, 2.56986140e+01, 2.50657061e+01, + 2.61572543e+01, 2.35304540e+01, 2.03513795e+01, 2.47338959e+01, + 2.85361200e+01, 2.67531185e+01, 2.53851547e+01, 2.93807397e+01, + 2.93031483e+01, 3.07399296e+01, 2.82640182e+01, 2.87514656e+01, + 2.70374233e+01, 2.65811547e+01, 2.61418710e+01, 2.58550020e+01, + 2.71578776e+01, 2.64238308e+01, 2.84622023e+01, 2.79479797e+01, + 2.54783638e+01, 2.60418242e+01, 2.79443392e+01, 3.07094261e+01, + 2.99400838e+01, 2.82979800e+01, 3.07511641e+01, 3.33209394e+01, + 3.33287246e+01, 3.74255324e+01, 3.66325921e+01, 3.67588263e+01, + 3.43328954e+01, 3.80575146e+01, 3.61886716e+01, 3.53697709e+01, + 3.52197474e+01, 3.59416153e+01, 3.71610542e+01, 3.50041339e+01, + 3.61407066e+01, 3.55777199e+01, 3.57989925e+01, 3.61486268e+01, + 3.79354584e+01, 3.87753346e+01, 4.04170961e+01, 3.85544136e+01, + 3.75831990e+01, 3.90366313e+01, 3.49632852e+01, 3.27692015e+01, + 3.66075790e+01, 3.82664250e+01, 4.10635836e+01, 4.01232858e+01, + 3.82659471e+01, 3.86222563e+01, 3.95396663e+01, 4.07177845e+01, + 4.16468646e+01, 4.14681492e+01, 3.87589157e+01, 4.25642139e+01, + 4.28169508e+01, 4.05391734e+01, 3.96425358e+01, 3.65827556e+01, + 3.83378639e+01, 3.78017070e+01, 4.20828011e+01, 3.95418102e+01, + 4.31436797e+01, 4.05330045e+01, 3.90917448e+01, 3.92669923e+01, + 3.89403550e+01, 3.78645983e+01, 4.00010289e+01, 4.11466577e+01, + 4.09166206e+01, 4.28376642e+01, 4.83032535e+01, 5.03440683e+01, + 5.04939470e+01, 4.83873177e+01, 4.88424153e+01, 4.60204642e+01, + 4.53628392e+01, 4.53208678e+01, 4.34104582e+01, 4.72712188e+01, + 4.53085845e+01, 4.87469883e+01, 4.53176722e+01, 4.59848480e+01, + 4.68180202e+01, 4.80254853e+01, 4.88952277e+01, 4.89675131e+01, + 4.89846423e+01, 4.77669309e+01, 4.68676581e+01, 4.72767848e+01, + 4.66397904e+01, 4.74731207e+01, 4.85923869e+01, 5.10479851e+01, + 5.09645735e+01, 4.63719640e+01, 4.87186987e+01, 4.72966687e+01, + 4.89071728e+01, 4.73546804e+01, 4.87901205e+01, 4.83914422e+01, + 4.86255307e+01, 4.61155063e+01, 4.73960112e+01, 4.60700542e+01, + 4.37176963e+01, 4.52532335e+01, 4.65388609e+01, 5.06470491e+01, + 4.98672825e+01, 5.14808476e+01, 5.25245687e+01, 5.18767123e+01, + 5.01147751e+01, 5.19041234e+01, 4.97524862e+01, 5.12598980e+01, + 5.05654567e+01, 5.13628797e+01, 4.96992870e+01, 5.05866731e+01, + 4.98398190e+01, 4.96328435e+01, 4.92854321e+01, 5.33002064e+01, + 5.14676067e+01, 5.11365110e+01, 5.51071319e+01, 5.13009479e+01, + 5.31046041e+01, 5.24629503e+01, 5.36683054e+01, 5.60098629e+01, + 5.52283081e+01, 5.62461682e+01, 5.80666498e+01, 5.93170298e+01, + 6.11729440e+01, 6.14663833e+01, 5.99074269e+01, 5.70475155e+01, + 5.57467469e+01, 5.53562365e+01, 5.16084196e+01, 5.38203806e+01, + 5.26558492e+01, 5.59798364e+01, 5.59254414e+01, 5.89710219e+01, + 5.80706137e+01, 5.49073495e+01, 5.54256550e+01, 5.45916219e+01, + 5.60633665e+01, 5.76954172e+01, 5.64758096e+01, 5.95694095e+01, + 5.82511840e+01, 5.83089690e+01, 5.67660766e+01, 5.89150907e+01, + 6.39865309e+01, 6.47192476e+01, 6.32750374e+01, 5.99555479e+01, + 6.31998530e+01, 6.25046699e+01, 6.06357692e+01, 6.25754608e+01, + 6.24183937e+01, 6.19650579e+01, 6.56295351e+01, 6.78886249e+01, + 6.68003985e+01, 6.62405935e+01, 7.15603668e+01, 6.92797252e+01, + 7.19722185e+01, 7.15217642e+01, 7.10314056e+01, 7.45748984e+01, + 7.09084928e+01, 6.75038378e+01, 6.89372006e+01, 6.53693364e+01, + 6.88181343e+01, 6.75895752e+01, 6.08994529e+01, 6.32462543e+01, + 6.58479365e+01, 6.62897445e+01, 6.75779635e+01, 6.48719351e+01, + 6.46889864e+01, 6.32792958e+01, 6.17130927e+01, 6.42011267e+01, + 5.67388212e+01, 6.13849245e+01, 6.24991451e+01, 6.10164266e+01, + 6.32098091e+01, 6.03996109e+01, 5.83264172e+01, 6.09408768e+01, + 6.25592304e+01, 6.04607721e+01, 6.23145977e+01, 6.35697235e+01, + 6.30387394e+01, 6.16345805e+01, 6.11449057e+01, 6.37358160e+01, + 6.29064326e+01, 6.30331826e+01, 6.26331013e+01, 6.31896768e+01, + 6.37297928e+01, 5.62072318e+01, 5.60973969e+01, 5.66564049e+01, + 5.78190653e+01, 5.99665542e+01, 5.73184637e+01, 5.50657201e+01, + 5.39367730e+01, 5.48769411e+01, 5.13593369e+01, 5.08861712e+01, + 5.26016574e+01, 5.24615858e+01, 5.60127981e+01, 5.48129187e+01, + 5.77475665e+01, 5.76449134e+01, 5.55324824e+01, 5.82712202e+01, + 5.28443004e+01, 5.58971797e+01, 5.48283501e+01, 5.62996898e+01, + 5.32566918e+01, 5.44371150e+01, 4.97249489e+01, 5.21936462e+01, + 5.32619388e+01, 5.22416434e+01, 5.03770476e+01, 5.06533824e+01, + 5.19585498e+01, 4.95978199e+01, 5.07628022e+01, 4.97096006e+01, + 5.04437666e+01, 4.93671130e+01, 5.31059827e+01, 5.12641705e+01, + 4.99410381e+01, 5.11623926e+01, 5.06482031e+01, 4.62325177e+01, + 4.75640127e+01, 4.66129257e+01, 4.41874806e+01, 4.36189522e+01, + 4.65535771e+01, 4.57677232e+01, 4.34520084e+01, 4.22660168e+01, + 4.69396993e+01, 4.71220580e+01, 4.63055727e+01, 4.61986738e+01, + 4.63830465e+01, 4.73434232e+01, 4.39444361e+01, 4.28300201e+01, + 4.40471823e+01, 4.20143829e+01, 4.42918783e+01, 4.56706574e+01, + 4.88734642e+01, 4.62027279e+01, 4.80358036e+01, 4.75777879e+01, + 4.58832198e+01, 4.94304927e+01, 4.82845984e+01, 4.66429346e+01, + 4.67056668e+01, 4.69122986e+01, 4.27901620e+01, 4.65325967e+01, + 4.83390420e+01, 4.70025843e+01, 4.79409464e+01, 5.17068551e+01, + 5.05584049e+01, 4.97489082e+01, 4.57898409e+01, 4.53285840e+01, + 4.81430637e+01, 4.89080671e+01, 5.03626445e+01, 5.02400699e+01, + 4.94825214e+01, 4.54937190e+01, 5.19526355e+01, 5.38132293e+01, + 4.99436520e+01, 4.83595928e+01, 4.78333140e+01, 4.48498755e+01, + 4.46929281e+01, 4.09217192e+01, 4.28240121e+01, 4.65186237e+01, + 4.61604551e+01, 4.84110465e+01, 4.61393856e+01, 4.67949869e+01, + 4.69190818e+01, 4.42192962e+01, 4.62411790e+01, 4.62952609e+01, + 4.56967639e+01, 4.94963276e+01, 4.76438980e+01, 4.76246976e+01, + 4.62150732e+01, 4.43248525e+01, 4.70872994e+01, 4.57373936e+01, + 4.50707527e+01, 4.55083211e+01, 4.64562081e+01, 4.50560377e+01, + 4.50077580e+01, 4.51886837e+01, 4.62188716e+01, 4.80803749e+01, + 4.45905514e+01, 4.94097478e+01, 4.87136217e+01, 4.70088604e+01, + 4.82228625e+01, 4.86098235e+01, 4.81864026e+01, 4.87803107e+01, + 4.91864136e+01, 5.28993466e+01, 5.47071607e+01, 5.53396866e+01, + 5.90826705e+01, 5.63086868e+01, 5.70123754e+01, 6.21627017e+01, + 5.97786562e+01, 5.83973692e+01, 5.79965186e+01, 5.60436028e+01, + 5.90281943e+01, 5.93664942e+01, 5.88925447e+01, 5.97547768e+01, + 5.85818570e+01, 5.92996480e+01, 5.97213295e+01, 5.90132611e+01, + 6.00813692e+01, 6.10462490e+01, 5.96630644e+01, 6.09650978e+01, + 6.17591429e+01, 6.27805127e+01, 6.52614308e+01, 6.64653215e+01, + 6.48972785e+01, 6.34602292e+01, 6.25913131e+01, 6.37276648e+01, + 6.39805631e+01, 6.65411796e+01, 6.69554024e+01, 6.64835310e+01, + 6.27351487e+01, 6.37736792e+01, 6.33913638e+01, 6.41161732e+01, + 6.17439655e+01, 5.96170382e+01, 6.33738869e+01, 6.36763719e+01, + 6.35608969e+01, 5.93835735e+01, 6.27860841e+01, 6.15587271e+01, + 5.99924406e+01, 5.87257680e+01, 5.80450010e+01, 5.95806869e+01, + 6.04869804e+01, 6.07560040e+01, 5.65849642e+01, 5.56785583e+01, + 5.78695710e+01, 5.42369067e+01, 5.43403858e+01, 5.22663184e+01, + 5.08715347e+01, 5.28165264e+01, 5.37169510e+01, 5.28806414e+01, + 5.68591051e+01, 5.64357773e+01, 5.69022166e+01, 5.35198183e+01, + 5.50060076e+01, 5.46257864e+01, 5.21519161e+01, 5.39814688e+01, + 5.20783610e+01, 5.52712187e+01, 4.90838566e+01, 5.30739350e+01, + 5.22890425e+01, 5.49519898e+01, 5.41107932e+01, 5.15971948e+01, + 5.16711718e+01, 5.16623935e+01, 4.76658924e+01, 4.66589874e+01, + 4.28605236e+01, 4.26718233e+01, 4.32431638e+01, 4.08936186e+01, + 4.16692229e+01, 4.05938496e+01, 4.34593718e+01, 4.26736559e+01, + 4.18790859e+01, 4.25754888e+01, 4.73663448e+01, 4.84002881e+01, + 4.63102423e+01, 4.59397977e+01, 4.57514586e+01, 4.72229676e+01, + 4.49929558e+01, 4.63036309e+01, 4.49127403e+01, 4.69778551e+01] +} \ No newline at end of file diff --git a/examples/stan/1000-hmm.stan b/examples/stan/1000-hmm.stan new file mode 100644 index 000000000..db0c70178 --- /dev/null +++ b/examples/stan/1000-hmm.stan @@ -0,0 +1,25 @@ +data { + int N; + vector[N] observations; +} + +parameters { + real log_sigma_transition; + vector[N] latents; +} + +transformed parameters { + real sigma_transition = exp(log_sigma_transition); +} + +model { + log_sigma_transition ~ normal(1,1); + + latents[1] ~ normal(2,0.5); + observations[1] ~ normal(latents[1],1); + + for (t in 2:N) { + latents[t] ~ normal(latents[t-1], sigma_transition); + observations[t] ~ normal(latents[t],1); + } +} \ No newline at end of file From 7cce4d07b810687f4a8a13d12329c5879e883f1a Mon Sep 17 00:00:00 2001 From: Jack Date: Sat, 26 Jul 2025 19:52:39 -0700 Subject: [PATCH 16/27] added the new full recorder --- src/variational/GaussianReference.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/variational/GaussianReference.jl b/src/variational/GaussianReference.jl index a2e225c26..33a523118 100644 --- a/src/variational/GaussianReference.jl +++ b/src/variational/GaussianReference.jl @@ -4,7 +4,7 @@ A Gaussian mean-field variational reference (i.e., with a diagonal covariance ma @kwdef mutable struct GaussianReference mean::Dict{Symbol, Any} = Dict{Symbol, Any}() standard_deviation::Dict{Symbol, Any} = Dict{Symbol, Any}() - first_tuning_round::Int = 6 # TODO: this should be moved elsewhere? + first_tuning_round::Int = 11 # TODO: this should be moved elsewhere? function GaussianReference(mean, standard_deviation, first_tuning_round) @assert length(mean) == length(standard_deviation) @@ -17,7 +17,7 @@ dim(variational::GaussianReference) = length(variational.mean) function activate_variational(variational::GaussianReference, iterators::Iterators) iterators.round ≥ variational.first_tuning_round ? true : false end -variational_recorder_builders(::GaussianReference) = [_transformed_online] +variational_recorder_builders(::GaussianReference) = [_transformed_online_full] function update_reference!(reduced_recorders, variational::GaussianReference, state) isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") From 4842120c992121501b673ae832138c0f22f4be60 Mon Sep 17 00:00:00 2001 From: Jack Date: Sun, 27 Jul 2025 01:20:20 -0700 Subject: [PATCH 17/27] added AD gradient mechanism --- src/variational/TreeReference.jl | 45 ++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 93133fcc8..98192944f 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -164,3 +164,48 @@ function tree_logdensity(variational::TreeReference, child_num, parent_num, stat return (new_mu, new_sigma) end +function tree_gradient(variational::TreeReference, state) + gradient = 0.0 + + marginal_state = variable(state, variational.which_variable[1])[1] + marginal_mean = variational.mean[1] + marginal_standard_deviation = variational.standard_deviation[1] + gradient += -(marginal_state - marginal_mean) / marginal_standard_deviation^2 + + for edge in variational.edge_set + child_idx = edge[4] + parent_idx = edge[3] + + parent_var_name = variational.which_variable[parent_idx] + child_var_name = variational.which_variable[child_idx] + + state_at_parent = variable(state, parent_var_name)[variational.which_index[parent_idx]] + state_at_child = variable(state, child_var_name)[variational.which_index[child_idx]] + + mu, sigma = tree_logdensity(variational, child_idx, parent_idx, state_at_parent, edge[1]) + gradient += -(state_at_child - mu) / sigma^2 + end + return gradient +end + + + +# LogDensityProblemsAD implementation (currently only for special case of a singleton variable) + +LogDensityProblems.logdensity(log_potential::TreeReference, x) = + log_potential(x) + +function LogDensityProblems.dimension(log_potential::TreeReference) + dim = length(log_potential.edge_set) + 1 + return dim +end + +LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType, log_potential::TreeReference, replica::Replica) = + BufferedAD(log_potential, replica.recorders.buffers) + +function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) + variational = log_potential.enclosed + buffer = log_potential.buffer + @. buffer = tree_gradient(log_potential, x) + return LogDensityProblems.logdensity(variational, x), buffer +end \ No newline at end of file From f5ce66e697f7f01db888269e9e02f79ff1c27503 Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 5 Aug 2025 18:40:59 -0700 Subject: [PATCH 18/27] Changed dimension for toy example --- examples/stan/{1000-hmm.json => 601-hmm.json} | 148 +++--------------- examples/stan/{1000-hmm.stan => 601-hmm.stan} | 0 2 files changed, 24 insertions(+), 124 deletions(-) rename examples/stan/{1000-hmm.json => 601-hmm.json} (51%) rename examples/stan/{1000-hmm.stan => 601-hmm.stan} (100%) diff --git a/examples/stan/1000-hmm.json b/examples/stan/601-hmm.json similarity index 51% rename from examples/stan/1000-hmm.json rename to examples/stan/601-hmm.json index fa8206b92..794e5842b 100644 --- a/examples/stan/1000-hmm.json +++ b/examples/stan/601-hmm.json @@ -1,5 +1,5 @@ { - "N": 1000, + "N": 600, "observations": [2.20041627e+00, 1.95769524e+00, -8.52087759e-01, -3.49801576e+00, -2.75466819e+00, -7.84143831e+00, -7.00086086e+00, -1.00700985e+01, -1.12625951e+01, -1.14712207e+01, -7.63410793e+00, -6.82938344e+00, @@ -15,7 +15,7 @@ 7.28679787e-01, 2.02044301e+00, 4.83385362e+00, 4.28691024e+00, 6.39388537e+00, 3.85840269e+00, 2.80378245e+00, 5.52823636e+00, 3.65488463e+00, 1.16215512e+00, -2.67688803e-01, -7.58287776e-01, - -7.02438545e-01, -6.81065563e-01, 4.78604164e+00, 2.50092926e+00, + -7.02438545e-01, -6.81065563e-01, 4.78604164e+00, 2.50092926e+00, 4.90315224e+00, 2.88597671e+00, 5.79131041e+00, 4.51960872e+00, 2.96660337e+00, 2.05281631e+00, 2.94518154e+00, 3.80861207e+00, 3.57769218e+00, 4.72358097e+00, 5.27504634e+00, 2.72113257e+00, @@ -27,9 +27,9 @@ 2.58430799e+00, -8.10793344e-02, 8.01232671e-01, 1.26356208e+00, 2.42295161e+00, 4.78682256e+00, 1.64580644e+00, 7.49921116e-01, 4.67797903e+00, 3.96665687e+00, 6.75268362e-01, 3.93938545e+00, - 4.02636161e+00, 4.44271098e+00, 5.19042137e+00, 2.58723543e+00, + 4.02636161e+00, 4.44271098e+00, 5.19042137e+00, 2.58723543e+00, 7.18735268e+00, 6.71400675e+00, 7.33104925e+00, 8.07291814e+00, - 1.01641232e+01, 7.12567322e+00, 5.56664294e+00, 8.06508369e+00, + 1.01641232e+01, 7.12567322e+00, 5.56664294e+00, 8.06508369e+00, 8.06292109e+00, 4.96482253e+00, 7.52618800e+00, 8.00781373e+00, 3.75931312e+00, 9.41197214e-01, 7.20632190e+00, 7.21676797e+00, 6.53389120e+00, 5.23761666e+00, 3.48087341e+00, 2.21970007e+00, @@ -39,7 +39,7 @@ 1.13194604e+01, 1.19072907e+01, 1.02393721e+01, 1.02645342e+01, 1.01497331e+01, 1.41101105e+01, 8.85600537e+00, 7.99558228e+00, 1.08893333e+01, 1.03842444e+01, 9.92441316e+00, 1.06790997e+01, - 1.12610941e+01, 1.02065620e+01, 8.74433367e+00, 7.94298791e+00, + 1.12610941e+01, 1.02065620e+01, 8.74433367e+00, 7.94298791e+00, 7.08367963e+00, 6.09791682e+00, 6.30358218e+00, 8.68326355e+00, 1.12112083e+01, 1.15750269e+01, 1.38522936e+01, 9.37299819e+00, 5.78993972e+00, 4.49590123e+00, 6.51469197e+00, 7.36551889e+00, @@ -47,10 +47,10 @@ 7.38936309e+00, 7.38584474e+00, 6.47593338e+00, 5.28929490e+00, 5.40888051e+00, 7.68078795e+00, 9.59892683e+00, 5.20831396e+00, 5.16531652e+00, 5.24548731e+00, 5.56486992e+00, 6.77264209e+00, - 1.61973862e+00, 5.37412033e+00, 3.03930250e+00, 4.47814021e+00, + 1.61973862e+00, 5.37412033e+00, 3.03930250e+00, 4.47814021e+00, 3.44894920e+00, 2.34977267e+00, 1.00274187e+00, 4.78782641e+00, 5.00688524e+00, 5.42917704e+00, 8.74264437e+00, 1.08546136e+01, - 1.26508115e+01, 1.38369640e+01, 1.22559499e+01, 1.35193366e+01, + 1.26508115e+01, 1.38369640e+01, 1.22559499e+01, 1.35193366e+01, 1.45558898e+01, 1.98940529e+01, 1.56230795e+01, 1.44094233e+01, 1.72861607e+01, 1.33902066e+01, 1.58302603e+01, 1.70575462e+01, 1.46093763e+01, 1.56669888e+01, 1.61703988e+01, 1.75911871e+01, @@ -74,7 +74,7 @@ 3.11005590e+01, 3.29650208e+01, 3.57058316e+01, 3.67591464e+01, 3.59652717e+01, 3.86935166e+01, 3.57911075e+01, 3.79579306e+01, 3.84987982e+01, 3.56516264e+01, 3.08808791e+01, 3.14919873e+01, - 3.37580507e+01, 3.67153048e+01, 3.71685145e+01, 3.87817383e+01, + 3.37580507e+01, 3.67153048e+01, 3.71685145e+01, 3.87817383e+01, 3.96963411e+01, 3.83564527e+01, 3.93413077e+01, 3.75321922e+01, 4.15746594e+01, 4.36385157e+01, 3.98425106e+01, 4.15453595e+01, 3.94474951e+01, 3.80429200e+01, 3.58932846e+01, 3.55484494e+01, @@ -82,7 +82,7 @@ 3.30004314e+01, 3.33909262e+01, 3.38864614e+01, 3.51344117e+01, 3.47694398e+01, 3.61716779e+01, 3.64870273e+01, 3.60862257e+01, 3.77519484e+01, 3.93528421e+01, 3.94286091e+01, 4.03535430e+01, - 3.91922900e+01, 3.71025976e+01, 3.22791905e+01, 3.18571947e+01, + 3.91922900e+01, 3.71025976e+01, 3.22791905e+01, 3.18571947e+01, 3.00682943e+01, 3.15075832e+01, 3.11689635e+01, 3.36183250e+01, 3.36876822e+01, 3.54169048e+01, 3.52179400e+01, 3.62088136e+01, 3.30185438e+01, 4.01510887e+01, 3.64507373e+01, 3.77575895e+01, @@ -95,9 +95,9 @@ 4.05876623e+01, 4.14133895e+01, 4.14028065e+01, 4.18386992e+01, 4.13414260e+01, 4.11379412e+01, 4.11907486e+01, 4.09471608e+01, 3.95399716e+01, 3.69100236e+01, 3.66860475e+01, 3.79205085e+01, - 3.34496798e+01, 3.78740709e+01, 3.47468014e+01, 3.24855140e+01, + 3.34496798e+01, 3.78740709e+01, 3.47468014e+01, 3.24855140e+01, 3.40201234e+01, 3.61644552e+01, 3.63968850e+01, 3.42714727e+01, - 3.41741941e+01, 3.11705640e+01, 2.83597819e+01, 3.07470347e+01, + 3.41741941e+01, 3.11705640e+01, 2.83597819e+01, 3.07470347e+01, 2.56121553e+01, 2.59008892e+01, 2.57916997e+01, 2.88839503e+01, 2.96389389e+01, 2.76744072e+01, 2.61646143e+01, 2.57337329e+01, 2.47063686e+01, 2.58303525e+01, 2.47858266e+01, 2.87182967e+01, @@ -108,33 +108,33 @@ 3.40074353e+01, 3.16579505e+01, 3.38981861e+01, 3.28793324e+01, 3.52557179e+01, 3.39301843e+01, 3.63010869e+01, 4.05274543e+01, 4.07032346e+01, 3.85513408e+01, 4.02953088e+01, 3.70458613e+01, - 3.72419718e+01, 3.70039193e+01, 3.77764164e+01, 3.56303451e+01, + 3.72419718e+01, 3.70039193e+01, 3.77764164e+01, 3.56303451e+01, 3.62959585e+01, 3.78946072e+01, 3.62709061e+01, 3.76634104e+01, 3.91337749e+01, 3.84004126e+01, 3.81241726e+01, 4.15581530e+01, 4.01142431e+01, 3.63541133e+01, 3.86243449e+01, 4.01430134e+01, 3.92610382e+01, 3.59249212e+01, 3.73114540e+01, 3.78569320e+01, - 3.84827639e+01, 3.73911966e+01, 3.60235311e+01, 3.55317825e+01, + 3.84827639e+01, 3.73911966e+01, 3.60235311e+01, 3.55317825e+01, 3.79022022e+01, 3.80714924e+01, 3.83597586e+01, 3.78601350e+01, 3.73587552e+01, 3.60520322e+01, 3.63999056e+01, 3.57917551e+01, 3.61548360e+01, 3.50228969e+01, 3.56458695e+01, 3.42961012e+01, 3.58128162e+01, 3.52237851e+01, 3.47537293e+01, 3.68640322e+01, - 3.41322362e+01, 3.66964622e+01, 3.34686593e+01, 3.44682303e+01, + 3.41322362e+01, 3.66964622e+01, 3.34686593e+01, 3.44682303e+01, 3.45441304e+01, 3.58217338e+01, 3.58910303e+01, 3.73505089e+01, 3.77782400e+01, 3.46415997e+01, 3.58564269e+01, 3.53247241e+01, - 3.36910112e+01, 3.27083918e+01, 3.13210632e+01, 2.77079547e+01, + 3.36910112e+01, 3.27083918e+01, 3.13210632e+01, 2.77079547e+01, 3.00823366e+01, 2.99031872e+01, 2.79379527e+01, 2.74939928e+01, - 2.75225223e+01, 2.82877226e+01, 2.83256334e+01, 2.83580797e+01, + 2.75225223e+01, 2.82877226e+01, 2.83256334e+01, 2.83580797e+01, 3.04150303e+01, 2.66598925e+01, 2.91794583e+01, 2.83726774e+01, 2.51347763e+01, 2.25963168e+01, 2.18457418e+01, 2.42479872e+01, - 2.67623007e+01, 2.25385780e+01, 2.73269381e+01, 2.59379632e+01, - 2.10230150e+01, 2.07278808e+01, 2.00192286e+01, 2.01415859e+01, + 2.67623007e+01, 2.25385780e+01, 2.73269381e+01, 2.59379632e+01, + 2.10230150e+01, 2.07278808e+01, 2.00192286e+01, 2.01415859e+01, 2.17867287e+01, 2.43700543e+01, 2.45504939e+01, 2.47710967e+01, - 2.38890926e+01, 2.54430988e+01, 2.19310121e+01, 2.01861465e+01, + 2.38890926e+01, 2.54430988e+01, 2.19310121e+01, 2.01861465e+01, 2.38134047e+01, 2.19633916e+01, 2.16153891e+01, 2.45069550e+01, - 2.04463694e+01, 1.90655406e+01, 2.09490496e+01, 2.29965558e+01, + 2.04463694e+01, 1.90655406e+01, 2.09490496e+01, 2.29965558e+01, 2.57579529e+01, 2.54823802e+01, 2.56986140e+01, 2.50657061e+01, 2.61572543e+01, 2.35304540e+01, 2.03513795e+01, 2.47338959e+01, - 2.85361200e+01, 2.67531185e+01, 2.53851547e+01, 2.93807397e+01, + 2.85361200e+01, 2.67531185e+01, 2.53851547e+01, 2.93807397e+01, 2.93031483e+01, 3.07399296e+01, 2.82640182e+01, 2.87514656e+01, 2.70374233e+01, 2.65811547e+01, 2.61418710e+01, 2.58550020e+01, 2.71578776e+01, 2.64238308e+01, 2.84622023e+01, 2.79479797e+01, @@ -146,108 +146,8 @@ 3.61407066e+01, 3.55777199e+01, 3.57989925e+01, 3.61486268e+01, 3.79354584e+01, 3.87753346e+01, 4.04170961e+01, 3.85544136e+01, 3.75831990e+01, 3.90366313e+01, 3.49632852e+01, 3.27692015e+01, - 3.66075790e+01, 3.82664250e+01, 4.10635836e+01, 4.01232858e+01, - 3.82659471e+01, 3.86222563e+01, 3.95396663e+01, 4.07177845e+01, + 3.66075790e+01, 3.82664250e+01, 4.10635836e+01, 4.01232858e+01, + 3.82659471e+01, 3.86222563e+01, 3.95396663e+01, 4.07177845e+01, 4.16468646e+01, 4.14681492e+01, 3.87589157e+01, 4.25642139e+01, - 4.28169508e+01, 4.05391734e+01, 3.96425358e+01, 3.65827556e+01, - 3.83378639e+01, 3.78017070e+01, 4.20828011e+01, 3.95418102e+01, - 4.31436797e+01, 4.05330045e+01, 3.90917448e+01, 3.92669923e+01, - 3.89403550e+01, 3.78645983e+01, 4.00010289e+01, 4.11466577e+01, - 4.09166206e+01, 4.28376642e+01, 4.83032535e+01, 5.03440683e+01, - 5.04939470e+01, 4.83873177e+01, 4.88424153e+01, 4.60204642e+01, - 4.53628392e+01, 4.53208678e+01, 4.34104582e+01, 4.72712188e+01, - 4.53085845e+01, 4.87469883e+01, 4.53176722e+01, 4.59848480e+01, - 4.68180202e+01, 4.80254853e+01, 4.88952277e+01, 4.89675131e+01, - 4.89846423e+01, 4.77669309e+01, 4.68676581e+01, 4.72767848e+01, - 4.66397904e+01, 4.74731207e+01, 4.85923869e+01, 5.10479851e+01, - 5.09645735e+01, 4.63719640e+01, 4.87186987e+01, 4.72966687e+01, - 4.89071728e+01, 4.73546804e+01, 4.87901205e+01, 4.83914422e+01, - 4.86255307e+01, 4.61155063e+01, 4.73960112e+01, 4.60700542e+01, - 4.37176963e+01, 4.52532335e+01, 4.65388609e+01, 5.06470491e+01, - 4.98672825e+01, 5.14808476e+01, 5.25245687e+01, 5.18767123e+01, - 5.01147751e+01, 5.19041234e+01, 4.97524862e+01, 5.12598980e+01, - 5.05654567e+01, 5.13628797e+01, 4.96992870e+01, 5.05866731e+01, - 4.98398190e+01, 4.96328435e+01, 4.92854321e+01, 5.33002064e+01, - 5.14676067e+01, 5.11365110e+01, 5.51071319e+01, 5.13009479e+01, - 5.31046041e+01, 5.24629503e+01, 5.36683054e+01, 5.60098629e+01, - 5.52283081e+01, 5.62461682e+01, 5.80666498e+01, 5.93170298e+01, - 6.11729440e+01, 6.14663833e+01, 5.99074269e+01, 5.70475155e+01, - 5.57467469e+01, 5.53562365e+01, 5.16084196e+01, 5.38203806e+01, - 5.26558492e+01, 5.59798364e+01, 5.59254414e+01, 5.89710219e+01, - 5.80706137e+01, 5.49073495e+01, 5.54256550e+01, 5.45916219e+01, - 5.60633665e+01, 5.76954172e+01, 5.64758096e+01, 5.95694095e+01, - 5.82511840e+01, 5.83089690e+01, 5.67660766e+01, 5.89150907e+01, - 6.39865309e+01, 6.47192476e+01, 6.32750374e+01, 5.99555479e+01, - 6.31998530e+01, 6.25046699e+01, 6.06357692e+01, 6.25754608e+01, - 6.24183937e+01, 6.19650579e+01, 6.56295351e+01, 6.78886249e+01, - 6.68003985e+01, 6.62405935e+01, 7.15603668e+01, 6.92797252e+01, - 7.19722185e+01, 7.15217642e+01, 7.10314056e+01, 7.45748984e+01, - 7.09084928e+01, 6.75038378e+01, 6.89372006e+01, 6.53693364e+01, - 6.88181343e+01, 6.75895752e+01, 6.08994529e+01, 6.32462543e+01, - 6.58479365e+01, 6.62897445e+01, 6.75779635e+01, 6.48719351e+01, - 6.46889864e+01, 6.32792958e+01, 6.17130927e+01, 6.42011267e+01, - 5.67388212e+01, 6.13849245e+01, 6.24991451e+01, 6.10164266e+01, - 6.32098091e+01, 6.03996109e+01, 5.83264172e+01, 6.09408768e+01, - 6.25592304e+01, 6.04607721e+01, 6.23145977e+01, 6.35697235e+01, - 6.30387394e+01, 6.16345805e+01, 6.11449057e+01, 6.37358160e+01, - 6.29064326e+01, 6.30331826e+01, 6.26331013e+01, 6.31896768e+01, - 6.37297928e+01, 5.62072318e+01, 5.60973969e+01, 5.66564049e+01, - 5.78190653e+01, 5.99665542e+01, 5.73184637e+01, 5.50657201e+01, - 5.39367730e+01, 5.48769411e+01, 5.13593369e+01, 5.08861712e+01, - 5.26016574e+01, 5.24615858e+01, 5.60127981e+01, 5.48129187e+01, - 5.77475665e+01, 5.76449134e+01, 5.55324824e+01, 5.82712202e+01, - 5.28443004e+01, 5.58971797e+01, 5.48283501e+01, 5.62996898e+01, - 5.32566918e+01, 5.44371150e+01, 4.97249489e+01, 5.21936462e+01, - 5.32619388e+01, 5.22416434e+01, 5.03770476e+01, 5.06533824e+01, - 5.19585498e+01, 4.95978199e+01, 5.07628022e+01, 4.97096006e+01, - 5.04437666e+01, 4.93671130e+01, 5.31059827e+01, 5.12641705e+01, - 4.99410381e+01, 5.11623926e+01, 5.06482031e+01, 4.62325177e+01, - 4.75640127e+01, 4.66129257e+01, 4.41874806e+01, 4.36189522e+01, - 4.65535771e+01, 4.57677232e+01, 4.34520084e+01, 4.22660168e+01, - 4.69396993e+01, 4.71220580e+01, 4.63055727e+01, 4.61986738e+01, - 4.63830465e+01, 4.73434232e+01, 4.39444361e+01, 4.28300201e+01, - 4.40471823e+01, 4.20143829e+01, 4.42918783e+01, 4.56706574e+01, - 4.88734642e+01, 4.62027279e+01, 4.80358036e+01, 4.75777879e+01, - 4.58832198e+01, 4.94304927e+01, 4.82845984e+01, 4.66429346e+01, - 4.67056668e+01, 4.69122986e+01, 4.27901620e+01, 4.65325967e+01, - 4.83390420e+01, 4.70025843e+01, 4.79409464e+01, 5.17068551e+01, - 5.05584049e+01, 4.97489082e+01, 4.57898409e+01, 4.53285840e+01, - 4.81430637e+01, 4.89080671e+01, 5.03626445e+01, 5.02400699e+01, - 4.94825214e+01, 4.54937190e+01, 5.19526355e+01, 5.38132293e+01, - 4.99436520e+01, 4.83595928e+01, 4.78333140e+01, 4.48498755e+01, - 4.46929281e+01, 4.09217192e+01, 4.28240121e+01, 4.65186237e+01, - 4.61604551e+01, 4.84110465e+01, 4.61393856e+01, 4.67949869e+01, - 4.69190818e+01, 4.42192962e+01, 4.62411790e+01, 4.62952609e+01, - 4.56967639e+01, 4.94963276e+01, 4.76438980e+01, 4.76246976e+01, - 4.62150732e+01, 4.43248525e+01, 4.70872994e+01, 4.57373936e+01, - 4.50707527e+01, 4.55083211e+01, 4.64562081e+01, 4.50560377e+01, - 4.50077580e+01, 4.51886837e+01, 4.62188716e+01, 4.80803749e+01, - 4.45905514e+01, 4.94097478e+01, 4.87136217e+01, 4.70088604e+01, - 4.82228625e+01, 4.86098235e+01, 4.81864026e+01, 4.87803107e+01, - 4.91864136e+01, 5.28993466e+01, 5.47071607e+01, 5.53396866e+01, - 5.90826705e+01, 5.63086868e+01, 5.70123754e+01, 6.21627017e+01, - 5.97786562e+01, 5.83973692e+01, 5.79965186e+01, 5.60436028e+01, - 5.90281943e+01, 5.93664942e+01, 5.88925447e+01, 5.97547768e+01, - 5.85818570e+01, 5.92996480e+01, 5.97213295e+01, 5.90132611e+01, - 6.00813692e+01, 6.10462490e+01, 5.96630644e+01, 6.09650978e+01, - 6.17591429e+01, 6.27805127e+01, 6.52614308e+01, 6.64653215e+01, - 6.48972785e+01, 6.34602292e+01, 6.25913131e+01, 6.37276648e+01, - 6.39805631e+01, 6.65411796e+01, 6.69554024e+01, 6.64835310e+01, - 6.27351487e+01, 6.37736792e+01, 6.33913638e+01, 6.41161732e+01, - 6.17439655e+01, 5.96170382e+01, 6.33738869e+01, 6.36763719e+01, - 6.35608969e+01, 5.93835735e+01, 6.27860841e+01, 6.15587271e+01, - 5.99924406e+01, 5.87257680e+01, 5.80450010e+01, 5.95806869e+01, - 6.04869804e+01, 6.07560040e+01, 5.65849642e+01, 5.56785583e+01, - 5.78695710e+01, 5.42369067e+01, 5.43403858e+01, 5.22663184e+01, - 5.08715347e+01, 5.28165264e+01, 5.37169510e+01, 5.28806414e+01, - 5.68591051e+01, 5.64357773e+01, 5.69022166e+01, 5.35198183e+01, - 5.50060076e+01, 5.46257864e+01, 5.21519161e+01, 5.39814688e+01, - 5.20783610e+01, 5.52712187e+01, 4.90838566e+01, 5.30739350e+01, - 5.22890425e+01, 5.49519898e+01, 5.41107932e+01, 5.15971948e+01, - 5.16711718e+01, 5.16623935e+01, 4.76658924e+01, 4.66589874e+01, - 4.28605236e+01, 4.26718233e+01, 4.32431638e+01, 4.08936186e+01, - 4.16692229e+01, 4.05938496e+01, 4.34593718e+01, 4.26736559e+01, - 4.18790859e+01, 4.25754888e+01, 4.73663448e+01, 4.84002881e+01, - 4.63102423e+01, 4.59397977e+01, 4.57514586e+01, 4.72229676e+01, - 4.49929558e+01, 4.63036309e+01, 4.49127403e+01, 4.69778551e+01] + 4.28169508e+01, 4.05391734e+01, 3.96425358e+01, 3.65827556e+01] } \ No newline at end of file diff --git a/examples/stan/1000-hmm.stan b/examples/stan/601-hmm.stan similarity index 100% rename from examples/stan/1000-hmm.stan rename to examples/stan/601-hmm.stan From 858fccc51fd3616579a1deedfab61782f2e5a481 Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 5 Aug 2025 18:43:26 -0700 Subject: [PATCH 19/27] corrected gradient, fixed bugs --- src/variational/TreeReference.jl | 34 ++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index 98192944f..d0f80030f 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -9,12 +9,12 @@ A Gaussian tree variational reference which_variable::Vector{Symbol} = Vector{Symbol}() which_index::Vector{Int} = Vector{Int}() iid_sample_set::Vector{Float64} = Vector{Float64}() - first_tuning_round::Int = 11 covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) + first_tuning_round::Int = 10 - function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round, covariance_matrix) + function TreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, first_tuning_round, covariance_matrix) + new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) end end @@ -56,7 +56,7 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() for i in 1:total_number_of_nodes - adjacency_list[i] = Vector{Int}() + adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() end variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) @@ -65,10 +65,11 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) for j = (i+1):total_number_of_nodes normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) rho = variational.covariance_matrix[i,j] / normalization + rho = clamp(rho, -0.99, 0.99) I = -0.5*log(1-rho^2) - push!(adjacency_list[i], (rho, I, i, j)) - push!(adjacency_list[j], (rho, I, j, i)) + push!(adjacency_list[i], (I, rho, i, j)) + push!(adjacency_list[j], (I, rho, j, i)) end end root = 1 @@ -118,8 +119,8 @@ function sample_iid!(variational::TreeReference, replica, shared) parent_idx = edge[3] child_idx = edge[4] - mu, sigma = tree_logdensity(variational, child_idx, parent_idx, variational.iid_sample_set[parent_idx], edge[1]) - val = rand(replica.rng) * sigma + mu + mu, sigma = tree_logdensity(variational, child_idx, parent_idx, variational.iid_sample_set[parent_idx], edge[2]) + val = randn(replica.rng) * sigma + mu variational.iid_sample_set[child_idx] = val update_state!(replica.state, variational.which_variable[child_idx], variational.which_index[child_idx], val) @@ -133,7 +134,7 @@ function (variational::TreeReference)(state) marginal_state = variable(state, variational.which_variable[1])[1] marginal_mean = variational.mean[1] marginal_standard_deviation = variational.standard_deviation[1] - log_pdf += Distributions.logpdf(Distributions.Normal(marginal_mean, marginal_standard_deviation), marginal_state)[1] + log_pdf += Distributions.logpdf(Distributions.Normal(marginal_mean, marginal_standard_deviation), marginal_state) for edge in variational.edge_set child_idx = edge[4] @@ -145,7 +146,7 @@ function (variational::TreeReference)(state) state_at_parent = variable(state, parent_var_name)[variational.which_index[parent_idx]] state_at_child = variable(state, child_var_name)[variational.which_index[child_idx]] - mu, sigma = tree_logdensity(variational, child_idx, parent_idx, state_at_parent, edge[1]) + mu, sigma = tree_logdensity(variational, child_idx, parent_idx, state_at_parent, edge[2]) log_pdf += Distributions.logpdf(Distributions.Normal(mu, sigma), state_at_child) end return log_pdf @@ -165,12 +166,12 @@ function tree_logdensity(variational::TreeReference, child_num, parent_num, stat end function tree_gradient(variational::TreeReference, state) - gradient = 0.0 + gradient = zeros(length(variational.mean)) marginal_state = variable(state, variational.which_variable[1])[1] marginal_mean = variational.mean[1] marginal_standard_deviation = variational.standard_deviation[1] - gradient += -(marginal_state - marginal_mean) / marginal_standard_deviation^2 + gradient[1] = -(marginal_state - marginal_mean) / marginal_standard_deviation^2 for edge in variational.edge_set child_idx = edge[4] @@ -182,8 +183,11 @@ function tree_gradient(variational::TreeReference, state) state_at_parent = variable(state, parent_var_name)[variational.which_index[parent_idx]] state_at_child = variable(state, child_var_name)[variational.which_index[child_idx]] - mu, sigma = tree_logdensity(variational, child_idx, parent_idx, state_at_parent, edge[1]) - gradient += -(state_at_child - mu) / sigma^2 + mu, sigma = tree_logdensity(variational, child_idx, parent_idx, state_at_parent, edge[2]) + delta = -(state_at_child - mu) / sigma^2 + + gradient[child_idx] += delta + gradient[parent_idx] += delta * (edge[2] * variational.standard_deviation[child_idx] / variational.standard_deviation[parent_idx]) end return gradient end @@ -206,6 +210,6 @@ LogDensityProblemsAD.ADgradient(kind::ADTypes.AbstractADType, log_potential::Tre function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{TreeReference}, x) variational = log_potential.enclosed buffer = log_potential.buffer - @. buffer = tree_gradient(log_potential, x) + buffer .= tree_gradient(variational, x) return LogDensityProblems.logdensity(variational, x), buffer end \ No newline at end of file From 5ce40d6e48dadd39a44f3c3c1a08b935c2d9f9e1 Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 5 Aug 2025 18:44:10 -0700 Subject: [PATCH 20/27] added jitter term for cholesky stability, refactored --- src/variational/DenseGaussianReference.jl | 29 +++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/variational/DenseGaussianReference.jl b/src/variational/DenseGaussianReference.jl index 562b4bacd..68becc164 100644 --- a/src/variational/DenseGaussianReference.jl +++ b/src/variational/DenseGaussianReference.jl @@ -8,11 +8,12 @@ A Gaussian dense variational reference (i.e., with a dense covariance matrix). cholesky::Any = zeros(Float64, 0, 0) which_variable::Vector{Any} = Vector{Any}() which_index::Vector{Int} = Vector{Int}() - first_tuning_round::Int = 11 # TODO: this should be moved elsewhere? + identity_gaussian::Any = zeros(Float64, 0, 0) + first_tuning_round::Int = 10 # TODO: this should be moved elsewhere? - function DenseGaussianReference(mean, covariance, precision, cholesky, which_variable, which_index, first_tuning_round) + function DenseGaussianReference(mean, covariance, precision, cholesky, which_variable, which_index, identity_gaussian, first_tuning_round) @assert first_tuning_round ≥ 1 - new(mean, covariance, precision, cholesky, which_variable, which_index, first_tuning_round) + new(mean, covariance, precision, cholesky, which_variable, which_index, identity_gaussian, first_tuning_round) end end @@ -28,30 +29,32 @@ function update_reference!(reduced_recorders, variational::DenseGaussianReferenc empty!(variational.which_variable) empty!(variational.which_index) empty!(variational.mean) + empty!(variational.which_variable) + empty!(variational.which_index) + + eps = 1e-6 + temp_covariance = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.covariance = temp_covariance + eps * I - variational.covariance = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.mean = get_transformed_statistic(reduced_recorders, :singleton_variable, Mean) + variational.identity_gaussian = MvNormal(zeros(length(variational.mean)), I) variational.precision = inv(variational.covariance) variational.cholesky = cholesky(variational.covariance).L - variational.mean = get_transformed_statistic(reduced_recorders, :singleton_variable, Mean) for var_name in continuous_variables(state) for i = 1:length(variable(state, var_name)) push!(variational.which_variable, var_name) push!(variational.which_index, i) - end end end function sample_iid!(variational::DenseGaussianReference, replica, shared) - dim = length(variational.mean) - identity_gaussian = MvNormal(zeros(dim), I) - z = rand(identity_gaussian) - + z = rand(variational.identity_gaussian) sample = variational.mean + variational.cholesky * z - for i in 1:dim + for i = 1:length(variational.mean) update_state!(replica.state, variational.which_variable[i], variational.which_index[i], sample[i]) end end @@ -59,7 +62,7 @@ end function (variational::DenseGaussianReference)(state) flattened_state = Vector{Float64}() - for i in 1:length(variational.mean) + for i = 1:length(variational.mean) name = variational.which_variable[i] index = variational.which_index[i] push!(flattened_state, Pigeons.variable(state, name)[index]) @@ -87,6 +90,6 @@ function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{De buffer = log_potential.buffer mean = variational.mean precision = variational.precision - @. buffer = -precision * (x - mean) + buffer .= -precision * (x - mean) return LogDensityProblems.logdensity(variational, x), buffer end \ No newline at end of file From 119c724891c9010dc272ca89e0f8ec51c12448a7 Mon Sep 17 00:00:00 2001 From: Jack Date: Wed, 6 Aug 2025 13:18:17 -0700 Subject: [PATCH 21/27] Refactored update for improved modularity --- src/variational/TreeReference.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/variational/TreeReference.jl b/src/variational/TreeReference.jl index d0f80030f..79cb4a606 100644 --- a/src/variational/TreeReference.jl +++ b/src/variational/TreeReference.jl @@ -50,19 +50,23 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) end end @assert length(variational.mean) == length(variational.standard_deviation) - total_number_of_nodes = length(variational.mean) - variational.iid_sample_set = zeros(total_number_of_nodes) + variational.iid_sample_set = zeros(length(variational.mean)) + variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.edge_set = build_tree(variational) +end + + +function build_tree(variational::TreeReference) + dim = length(variational.mean) adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() - for i in 1:total_number_of_nodes + for i in 1:dim adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() end - variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) - - for i = 1:total_number_of_nodes - for j = (i+1):total_number_of_nodes + for i = 1:dim + for j = (i+1):dim normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) rho = variational.covariance_matrix[i,j] / normalization rho = clamp(rho, -0.99, 0.99) @@ -73,13 +77,10 @@ function update_reference!(reduced_recorders, variational::TreeReference, state) end end root = 1 - variational.edge_set = directed_max_tree(adjacency_list, root) - - empty!(adjacency_list) + return directed_max_tree(adjacency_list, root) end - function directed_max_tree(adjacency_list, root) total_number_of_nodes = length(keys(adjacency_list)) mst = Vector{Tuple{Float64, Float64, Int, Int}}() From 23fc1f9a54acbc47d42ae5737afb9519ddc62cdd Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 12 Aug 2025 14:16:43 -0700 Subject: [PATCH 22/27] Updated dependencies --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 0ebd5bcff..f90cd62e9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ ArgMacros = "dbc42088-9de8-42a0-8ec8-2cd114e1ea3e" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" From 823ae31e5e8dbb95df2bc562c55f935527158044 Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 12 Aug 2025 14:19:49 -0700 Subject: [PATCH 23/27] Added test case for tree reference --- test/test_tree.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 test/test_tree.jl diff --git a/test/test_tree.jl b/test/test_tree.jl new file mode 100644 index 000000000..7fe07caad --- /dev/null +++ b/test/test_tree.jl @@ -0,0 +1,34 @@ +import Pigeons: TreeReference, directed_max_tree +using DataStructures + +dummy = 999 + +## Generate a complete graph with 5 vertices +function generate_test_tree() + dim = 5 + adjacency_list = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() + + for i = 1:dim + adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() + end + + append!(adjacency_list[1], [(1, dummy, 1,2), (2, dummy, 1,3), (3, dummy, 1,4), (4, dummy, 1,5)]) + append!(adjacency_list[2], [(1, dummy, 2,1), (5, dummy, 2,3), (8, dummy, 2,4), (6, dummy, 2,5)]) + append!(adjacency_list[3], [(2, dummy, 3,1), (5, dummy, 3,2), (7, dummy, 3,4), (9, dummy, 3,5)]) + append!(adjacency_list[4], [(3, dummy, 4,1), (8, dummy, 4,2), (7, dummy, 4,3), (10, dummy, 4,5)]) + append!(adjacency_list[5], [(4, dummy, 5,1), (6, dummy, 5,2), (9, dummy, 5,3), (10, dummy, 5,4)]) + + return adjacency_list +end + +@testset "Spanning tree maximality" begin + root = 1 + tree = directed_max_tree(generate_test_tree(), root) + + @test length(tree) == 4 + + expected = [(4, dummy, 1,5), (10, dummy, 5, 4), (8, dummy, 4, 2), (9, dummy, 5, 3)] + for edge in expected + @test edge in expected + end +end \ No newline at end of file From 16b4c517428583f7fcb7bfaed8c5aeb16104d855 Mon Sep 17 00:00:00 2001 From: Jack Date: Wed, 13 Aug 2025 18:58:37 -0700 Subject: [PATCH 24/27] init mixed tree reference class --- src/Pigeons.jl | 2 +- src/includes.jl | 1 + src/variational/MixedTreeReference.jl | 216 ++++++++++++++++++++++++ src/variational/VariationalReference.jl | 2 +- 4 files changed, 219 insertions(+), 2 deletions(-) create mode 100644 src/variational/MixedTreeReference.jl diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 2a9e36d57..15ecf1b0b 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -69,7 +69,7 @@ export pigeons, Inputs, PT, # getting information out of an execution: stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample, # variational references: - GaussianReference, TreeReference, DenseGaussianReference, + GaussianReference, TreeReference, MixedTreeReference, DenseGaussianReference, # samplers SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix diff --git a/src/includes.jl b/src/includes.jl index e755b0036..3670260c1 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -80,6 +80,7 @@ include("pt/checks.jl") include("explorers/BufferedAD.jl") include("variational/GaussianReference.jl") include("variational/TreeReference.jl") +include("variational/MixedTreeReference.jl") include("variational/DenseGaussianReference.jl") include("variational/VariationalReference.jl") include("paths/ScaledPrecisionNormalPath.jl") diff --git a/src/variational/MixedTreeReference.jl b/src/variational/MixedTreeReference.jl new file mode 100644 index 000000000..f4e1fd179 --- /dev/null +++ b/src/variational/MixedTreeReference.jl @@ -0,0 +1,216 @@ +""" +A Gaussian tree variational reference +""" + +@kwdef mutable struct MixedTreeReference + edge_set::Vector{Tuple{Float64, Float64, Int, Int}} = Vector{Tuple{Float64, Float64, Int, Int}}() + mean::Vector{Float64} = Vector{Float64}() + standard_deviation::Vector{Float64} = Vector{Float64}() + which_variable::Vector{Symbol} = Vector{Symbol}() + which_index::Vector{Int} = Vector{Int}() + iid_sample_set::Vector{Float64} = Vector{Float64}() + covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) + first_tuning_round::Int = 10 + + function MixedTreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + @assert first_tuning_round ≥ 1 + new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + end +end + + +dim(variational::MixedTreeReference) = length(variational.mean) +function activate_variational(variational::MixedTreeReference, iterators) + iterators.round ≥ variational.first_tuning_round ? true : false +end + +variational_recorder_builders(::MixedTreeReference) = [_transformed_online_full] + + +function update_reference!(reduced_recorders, variational::MixedTreeReference, state) + isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") + + empty!(variational.mean) + empty!(variational.standard_deviation) + empty!(variational.which_variable) + empty!(variational.which_index) + variational.edge_set = [] + + for var_name in continuous_variables(state) + temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) + temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) + + dimension = length(temp_mean) + for i = 1:dimension + push!(variational.mean, temp_mean[i]) + push!(variational.standard_deviation, temp_std[i]) + + push!(variational.which_variable, var_name) + push!(variational.which_index, i) + end + end + @assert length(variational.mean) == length(variational.standard_deviation) + + variational.iid_sample_set = zeros(length(variational.mean)) + variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) + variational.edge_set = build_tree(variational) +end + + +function build_tree(variational::MixedTreeReference) + dim = length(variational.mean) + + adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() + for i in 1:dim + adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() + end + + for i = 1:dim + for j = (i+1):dim + normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) + rho = variational.covariance_matrix[i,j] / normalization + rho = clamp(rho, -0.99, 0.99) + I = -0.5*log(1-rho^2) + + push!(adjacency_list[i], (I, rho, i, j)) + push!(adjacency_list[j], (I, rho, j, i)) + end + end + root = 1 + return directed_max_tree(adjacency_list, root) +end + + +function directed_max_tree(adjacency_list, root) + total_number_of_nodes = length(keys(adjacency_list)) + mst = Vector{Tuple{Float64, Float64, Int, Int}}() + pq = BinaryMaxHeap{Tuple{Float64, Float64, Int, Int}}() + visited_nodes = Set{Int}() + + push!(visited_nodes, root) + for edge in adjacency_list[root] + push!(pq, edge) + end + + while !isempty(pq) && length(mst) Date: Wed, 13 Aug 2025 23:02:21 -0700 Subject: [PATCH 25/27] renaming for clarity --- src/variational/MixedTreeReference.jl | 63 +++++++++++++-------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/src/variational/MixedTreeReference.jl b/src/variational/MixedTreeReference.jl index f4e1fd179..db8d5850e 100644 --- a/src/variational/MixedTreeReference.jl +++ b/src/variational/MixedTreeReference.jl @@ -4,22 +4,23 @@ A Gaussian tree variational reference @kwdef mutable struct MixedTreeReference edge_set::Vector{Tuple{Float64, Float64, Int, Int}} = Vector{Tuple{Float64, Float64, Int, Int}}() - mean::Vector{Float64} = Vector{Float64}() - standard_deviation::Vector{Float64} = Vector{Float64}() + continuous_mean::Vector{Float64} = Vector{Float64}() + continuous_std::Vector{Float64} = Vector{Float64}() which_variable::Vector{Symbol} = Vector{Symbol}() which_index::Vector{Int} = Vector{Int}() iid_sample_set::Vector{Float64} = Vector{Float64}() covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) + discrete_map::Dict{Tuple{Symbol, Int, Symbol}, Tuple{Float64, Float64}} = Dict{Tuple{Symbol, Int, Symbol}, Tuple{Float64, Float64}}() first_tuning_round::Int = 10 - function MixedTreeReference(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + function MixedTreeReference(edge_set, continuous_mean, continuous_std, which_variable, which_index, iid_sample_set, covariance_matrix, discrete_map, first_tuning_round) @assert first_tuning_round ≥ 1 - new(edge_set, mean, standard_deviation, which_variable, which_index, iid_sample_set, covariance_matrix, first_tuning_round) + new(edge_set, continuous_mean, continuous_std, which_variable, which_index, iid_sample_set, covariance_matrix, discrete_map, first_tuning_round) end end -dim(variational::MixedTreeReference) = length(variational.mean) +dim(variational::MixedTreeReference) = length(variational.continuous_mean) + length(unique(first.(keys(variational.discrete_map)))) function activate_variational(variational::MixedTreeReference, iterators) iterators.round ≥ variational.first_tuning_round ? true : false end @@ -28,10 +29,8 @@ variational_recorder_builders(::MixedTreeReference) = [_transformed_online_full] function update_reference!(reduced_recorders, variational::MixedTreeReference, state) - isempty(discrete_variables(state)) || error("Updating a Gaussian reference with discrete variables.") - - empty!(variational.mean) - empty!(variational.standard_deviation) + empty!(variational.continuous_mean) + empty!(variational.continuous_std) empty!(variational.which_variable) empty!(variational.which_index) variational.edge_set = [] @@ -40,34 +39,34 @@ function update_reference!(reduced_recorders, variational::MixedTreeReference, s temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) - dimension = length(temp_mean) - for i = 1:dimension - push!(variational.mean, temp_mean[i]) - push!(variational.standard_deviation, temp_std[i]) + temp_dim = length(temp_mean) + for i = 1:temp_dim + push!(variational.continuous_mean, temp_mean[i]) + push!(variational.continuous_std, temp_std[i]) push!(variational.which_variable, var_name) push!(variational.which_index, i) end end - @assert length(variational.mean) == length(variational.standard_deviation) + @assert length(variational.continuous_mean) == length(variational.continuous_std) - variational.iid_sample_set = zeros(length(variational.mean)) + variational.iid_sample_set = zeros(length(variational.continuous_mean)) variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) variational.edge_set = build_tree(variational) end function build_tree(variational::MixedTreeReference) - dim = length(variational.mean) + continuous_dim = length(variational.continuous_mean) adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() - for i in 1:dim + for i in 1:continuous_dim adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() end - for i = 1:dim - for j = (i+1):dim - normalization = (variational.standard_deviation[i] * variational.standard_deviation[j]) + for i = 1:continuous_dim + for j = (i+1):continuous_dim + normalization = (variational.continuous_std[i] * variational.continuous_std[j]) rho = variational.covariance_matrix[i,j] / normalization rho = clamp(rho, -0.99, 0.99) I = -0.5*log(1-rho^2) @@ -112,7 +111,7 @@ end function sample_iid!(variational::MixedTreeReference, replica, shared) - marginal_val = randn(replica.rng) * variational.standard_deviation[1] + variational.mean[1] + marginal_val = randn(replica.rng) * variational.continuous_std[1] + variational.continuous_mean[1] variational.iid_sample_set[1] = marginal_val update_state!(replica.state, variational.which_variable[1], 1, marginal_val) @@ -133,8 +132,8 @@ function (variational::MixedTreeReference)(state) log_pdf = 0.0 marginal_state = variable(state, variational.which_variable[1])[1] - marginal_mean = variational.mean[1] - marginal_standard_deviation = variational.standard_deviation[1] + marginal_mean = variational.continuous_mean[1] + marginal_standard_deviation = variational.continuous_std[1] log_pdf += Distributions.logpdf(Distributions.Normal(marginal_mean, marginal_standard_deviation), marginal_state) for edge in variational.edge_set @@ -155,10 +154,10 @@ end function tree_logdensity(variational::MixedTreeReference, child_num, parent_num, state_at_parent, rho) - child_mean = variational.mean[child_num] - parent_mean = variational.mean[parent_num] - child_standard_deviation = variational.standard_deviation[child_num] - parent_standard_deviation = variational.standard_deviation[parent_num] + child_mean = variational.continuous_mean[child_num] + parent_mean = variational.continuous_mean[parent_num] + child_standard_deviation = variational.continuous_std[child_num] + parent_standard_deviation = variational.continuous_std[parent_num] new_mu = child_mean + rho * (child_standard_deviation / parent_standard_deviation) * (state_at_parent - parent_mean) new_sigma = sqrt((1-rho^2) * (child_standard_deviation)^2) @@ -167,11 +166,11 @@ function tree_logdensity(variational::MixedTreeReference, child_num, parent_num, end function tree_gradient(variational::MixedTreeReference, state) - gradient = zeros(length(variational.mean)) + gradient = zeros(length(variational.continuous_mean)) marginal_state = variable(state, variational.which_variable[1])[1] - marginal_mean = variational.mean[1] - marginal_standard_deviation = variational.standard_deviation[1] + marginal_mean = variational.continuous_mean[1] + marginal_standard_deviation = variational.continuous_std[1] gradient[1] = -(marginal_state - marginal_mean) / marginal_standard_deviation^2 for edge in variational.edge_set @@ -188,7 +187,7 @@ function tree_gradient(variational::MixedTreeReference, state) delta = -(state_at_child - mu) / sigma^2 gradient[child_idx] += delta - gradient[parent_idx] += delta * (edge[2] * variational.standard_deviation[child_idx] / variational.standard_deviation[parent_idx]) + gradient[parent_idx] += delta * (edge[2] * variational.continuous_std[child_idx] / variational.continuous_std[parent_idx]) end return gradient end @@ -201,7 +200,7 @@ LogDensityProblems.logdensity(log_potential::MixedTreeReference, x) = log_potential(x) function LogDensityProblems.dimension(log_potential::MixedTreeReference) - dim = length(log_potential.edge_set) + 1 + dim = length(variational.continuous_mean) + length(unique(first.(keys(variational.discrete_map)))) return dim end From 787e368a69b9fffa5230ca67e9fa9323c9378a08 Mon Sep 17 00:00:00 2001 From: Jack Date: Wed, 20 Aug 2025 17:31:50 -0700 Subject: [PATCH 26/27] update test --- test/test_tree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_tree.jl b/test/test_tree.jl index 7fe07caad..d784fe6fa 100644 --- a/test/test_tree.jl +++ b/test/test_tree.jl @@ -28,7 +28,7 @@ end @test length(tree) == 4 expected = [(4, dummy, 1,5), (10, dummy, 5, 4), (8, dummy, 4, 2), (9, dummy, 5, 3)] - for edge in expected + for edge in tree @test edge in expected end end \ No newline at end of file From 7346e934826626ecb610294cae93c3f772da2d10 Mon Sep 17 00:00:00 2001 From: Jack Date: Wed, 20 Aug 2025 17:32:23 -0700 Subject: [PATCH 27/27] removed mixed tree --- src/Pigeons.jl | 2 +- src/includes.jl | 1 - src/variational/MixedTreeReference.jl | 215 ------------------------ src/variational/VariationalReference.jl | 2 +- 4 files changed, 2 insertions(+), 218 deletions(-) delete mode 100644 src/variational/MixedTreeReference.jl diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 15ecf1b0b..2a9e36d57 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -69,7 +69,7 @@ export pigeons, Inputs, PT, # getting information out of an execution: stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample, # variational references: - GaussianReference, TreeReference, MixedTreeReference, DenseGaussianReference, + GaussianReference, TreeReference, DenseGaussianReference, # samplers SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix diff --git a/src/includes.jl b/src/includes.jl index 3670260c1..e755b0036 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -80,7 +80,6 @@ include("pt/checks.jl") include("explorers/BufferedAD.jl") include("variational/GaussianReference.jl") include("variational/TreeReference.jl") -include("variational/MixedTreeReference.jl") include("variational/DenseGaussianReference.jl") include("variational/VariationalReference.jl") include("paths/ScaledPrecisionNormalPath.jl") diff --git a/src/variational/MixedTreeReference.jl b/src/variational/MixedTreeReference.jl deleted file mode 100644 index db8d5850e..000000000 --- a/src/variational/MixedTreeReference.jl +++ /dev/null @@ -1,215 +0,0 @@ -""" -A Gaussian tree variational reference -""" - -@kwdef mutable struct MixedTreeReference - edge_set::Vector{Tuple{Float64, Float64, Int, Int}} = Vector{Tuple{Float64, Float64, Int, Int}}() - continuous_mean::Vector{Float64} = Vector{Float64}() - continuous_std::Vector{Float64} = Vector{Float64}() - which_variable::Vector{Symbol} = Vector{Symbol}() - which_index::Vector{Int} = Vector{Int}() - iid_sample_set::Vector{Float64} = Vector{Float64}() - covariance_matrix::Matrix{Float64} = zeros(Float64, 0, 0) - discrete_map::Dict{Tuple{Symbol, Int, Symbol}, Tuple{Float64, Float64}} = Dict{Tuple{Symbol, Int, Symbol}, Tuple{Float64, Float64}}() - first_tuning_round::Int = 10 - - function MixedTreeReference(edge_set, continuous_mean, continuous_std, which_variable, which_index, iid_sample_set, covariance_matrix, discrete_map, first_tuning_round) - @assert first_tuning_round ≥ 1 - new(edge_set, continuous_mean, continuous_std, which_variable, which_index, iid_sample_set, covariance_matrix, discrete_map, first_tuning_round) - end -end - - -dim(variational::MixedTreeReference) = length(variational.continuous_mean) + length(unique(first.(keys(variational.discrete_map)))) -function activate_variational(variational::MixedTreeReference, iterators) - iterators.round ≥ variational.first_tuning_round ? true : false -end - -variational_recorder_builders(::MixedTreeReference) = [_transformed_online_full] - - -function update_reference!(reduced_recorders, variational::MixedTreeReference, state) - empty!(variational.continuous_mean) - empty!(variational.continuous_std) - empty!(variational.which_variable) - empty!(variational.which_index) - variational.edge_set = [] - - for var_name in continuous_variables(state) - temp_mean = get_transformed_statistic(reduced_recorders, var_name, Mean) - temp_std = sqrt.(get_transformed_statistic(reduced_recorders, var_name, Variance)) - - temp_dim = length(temp_mean) - for i = 1:temp_dim - push!(variational.continuous_mean, temp_mean[i]) - push!(variational.continuous_std, temp_std[i]) - - push!(variational.which_variable, var_name) - push!(variational.which_index, i) - end - end - @assert length(variational.continuous_mean) == length(variational.continuous_std) - - variational.iid_sample_set = zeros(length(variational.continuous_mean)) - variational.covariance_matrix = get_transformed_statistic(reduced_recorders, :singleton_variable, CovMatrix) - variational.edge_set = build_tree(variational) -end - - -function build_tree(variational::MixedTreeReference) - continuous_dim = length(variational.continuous_mean) - - adjacency_list::Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}} = Dict{Int, Vector{Tuple{Float64, Float64, Int, Int}}}() - for i in 1:continuous_dim - adjacency_list[i] = Vector{Tuple{Float64, Float64, Int, Int}}() - end - - for i = 1:continuous_dim - for j = (i+1):continuous_dim - normalization = (variational.continuous_std[i] * variational.continuous_std[j]) - rho = variational.covariance_matrix[i,j] / normalization - rho = clamp(rho, -0.99, 0.99) - I = -0.5*log(1-rho^2) - - push!(adjacency_list[i], (I, rho, i, j)) - push!(adjacency_list[j], (I, rho, j, i)) - end - end - root = 1 - return directed_max_tree(adjacency_list, root) -end - - -function directed_max_tree(adjacency_list, root) - total_number_of_nodes = length(keys(adjacency_list)) - mst = Vector{Tuple{Float64, Float64, Int, Int}}() - pq = BinaryMaxHeap{Tuple{Float64, Float64, Int, Int}}() - visited_nodes = Set{Int}() - - push!(visited_nodes, root) - for edge in adjacency_list[root] - push!(pq, edge) - end - - while !isempty(pq) && length(mst)