Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a1f4b97
init TreeReference.jl class, added todos
jack-wu05 Jun 26, 2025
45d3510
added constructor
jack-wu05 Jun 28, 2025
92c0ab6
added logdensity evaluation for 1D variables
jack-wu05 Jun 29, 2025
7975cfd
refactored to use std instead of variance
jack-wu05 Jun 29, 2025
7cf863b
added greedy search algorithm
jack-wu05 Jun 30, 2025
61114a5
added update_ref! mechanism, light refactoring
jack-wu05 Jul 4, 2025
987b85e
added iid reference sampling mechanism
jack-wu05 Jul 4, 2025
e612b3a
added partial singleton ad implementation
jack-wu05 Jul 4, 2025
8794ac3
integrated treeref into codebase dependencies
jack-wu05 Jul 5, 2025
5bd2738
fixed compilation bugs
jack-wu05 Jul 6, 2025
e2f764e
refactored for dense/diag dispatch
jack-wu05 Jul 11, 2025
a73b5df
added covariance mechanism, improved code efficiency
jack-wu05 Jul 27, 2025
0723889
migrated dense reference from other branch
jack-wu05 Jul 27, 2025
f4f9270
incorporated references, covariance mechanism into dependencies
jack-wu05 Jul 27, 2025
d1177ae
added HMM problem for testing new references
jack-wu05 Jul 27, 2025
7cce4d0
added the new full recorder
jack-wu05 Jul 27, 2025
4842120
added AD gradient mechanism
jack-wu05 Jul 27, 2025
f5ce66e
Changed dimension for toy example
jack-wu05 Aug 6, 2025
858fccc
corrected gradient, fixed bugs
jack-wu05 Aug 6, 2025
5ce40d6
added jitter term for cholesky stability, refactored
jack-wu05 Aug 6, 2025
119c724
Refactored update for improved modularity
jack-wu05 Aug 6, 2025
23fc1f9
Updated dependencies
jack-wu05 Aug 12, 2025
823ae31
Added test case for tree reference
jack-wu05 Aug 12, 2025
16b4c51
init mixed tree reference class
jack-wu05 Aug 14, 2025
9a74755
renaming for clarity
jack-wu05 Aug 14, 2025
787e368
update test
jack-wu05 Aug 21, 2025
7346e93
removed mixed tree
jack-wu05 Aug 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.4.9"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -66,6 +67,7 @@ AbstractPPL = "0.8.4, 0.9, 0.10, 0.11"
Bijectors = "0.13, 0.14, 0.15"
BridgeStan = "2"
DataFrames = "1"
DataStructures = "0.18.22"
DifferentiationInterface = "0.6.48"
Distributions = "0.25"
DocStringExtensions = "0.9"
Expand Down
153 changes: 153 additions & 0 deletions examples/stan/601-hmm.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
{
"N": 600,
"observations": [2.20041627e+00, 1.95769524e+00, -8.52087759e-01, -3.49801576e+00,
-2.75466819e+00, -7.84143831e+00, -7.00086086e+00, -1.00700985e+01,
-1.12625951e+01, -1.14712207e+01, -7.63410793e+00, -6.82938344e+00,
-8.94909208e+00, -1.05940554e+01, -1.01677096e+01, -1.07991865e+01,
-1.21380587e+01, -1.26044233e+01, -9.45156068e+00, -1.01137595e+01,
-1.20918822e+01, -8.58104339e+00, -9.46199003e+00, -6.16620129e+00,
-4.99686144e+00, -7.00417901e+00, -6.63052420e+00, -4.64233530e+00,
-3.70324733e+00, -3.32442030e+00, -4.88425406e-02, -5.51073794e-01,
-2.90981188e-02, 3.46099579e+00, -2.40716819e+00, -1.61163678e+00,
-5.16084106e-02, -3.91487830e+00, -2.05140073e+00, -1.28237909e+00,
-1.19528909e+00, -5.26269931e-01, -1.20447035e+00, 1.93218523e-01,
3.08078980e+00, 1.80337827e+00, 1.57832252e+00, 9.34944592e-01,
7.28679787e-01, 2.02044301e+00, 4.83385362e+00, 4.28691024e+00,
6.39388537e+00, 3.85840269e+00, 2.80378245e+00, 5.52823636e+00,
3.65488463e+00, 1.16215512e+00, -2.67688803e-01, -7.58287776e-01,
-7.02438545e-01, -6.81065563e-01, 4.78604164e+00, 2.50092926e+00,
4.90315224e+00, 2.88597671e+00, 5.79131041e+00, 4.51960872e+00,
2.96660337e+00, 2.05281631e+00, 2.94518154e+00, 3.80861207e+00,
3.57769218e+00, 4.72358097e+00, 5.27504634e+00, 2.72113257e+00,
5.62923032e+00, 4.89088010e+00, 4.85389967e+00, 2.52088724e+00,
1.68862456e+00, 1.45954408e+00, 1.67063240e+00, 3.21727514e+00,
-2.66006003e+00, 1.96019559e+00, -9.23003292e-01, 2.02073660e-03,
-1.10989865e+00, -3.57514428e+00, -1.29680815e+00, 1.62500296e-01,
-1.53874849e-01, -5.00880738e-02, 7.12391454e-01, 2.04897962e+00,
2.58430799e+00, -8.10793344e-02, 8.01232671e-01, 1.26356208e+00,
2.42295161e+00, 4.78682256e+00, 1.64580644e+00, 7.49921116e-01,
4.67797903e+00, 3.96665687e+00, 6.75268362e-01, 3.93938545e+00,
4.02636161e+00, 4.44271098e+00, 5.19042137e+00, 2.58723543e+00,
7.18735268e+00, 6.71400675e+00, 7.33104925e+00, 8.07291814e+00,
1.01641232e+01, 7.12567322e+00, 5.56664294e+00, 8.06508369e+00,
8.06292109e+00, 4.96482253e+00, 7.52618800e+00, 8.00781373e+00,
3.75931312e+00, 9.41197214e-01, 7.20632190e+00, 7.21676797e+00,
6.53389120e+00, 5.23761666e+00, 3.48087341e+00, 2.21970007e+00,
4.29720023e+00, 4.70966628e+00, 4.09379447e+00, 6.49047977e+00,
6.27283245e+00, 4.31824266e+00, 7.48014444e+00, 9.19625336e+00,
9.37229733e+00, 1.08146351e+01, 1.11088290e+01, 1.02598495e+01,
1.13194604e+01, 1.19072907e+01, 1.02393721e+01, 1.02645342e+01,
1.01497331e+01, 1.41101105e+01, 8.85600537e+00, 7.99558228e+00,
1.08893333e+01, 1.03842444e+01, 9.92441316e+00, 1.06790997e+01,
1.12610941e+01, 1.02065620e+01, 8.74433367e+00, 7.94298791e+00,
7.08367963e+00, 6.09791682e+00, 6.30358218e+00, 8.68326355e+00,
1.12112083e+01, 1.15750269e+01, 1.38522936e+01, 9.37299819e+00,
5.78993972e+00, 4.49590123e+00, 6.51469197e+00, 7.36551889e+00,
7.32101484e+00, 7.67118026e+00, 6.33513478e+00, 6.60466730e+00,
7.38936309e+00, 7.38584474e+00, 6.47593338e+00, 5.28929490e+00,
5.40888051e+00, 7.68078795e+00, 9.59892683e+00, 5.20831396e+00,
5.16531652e+00, 5.24548731e+00, 5.56486992e+00, 6.77264209e+00,
1.61973862e+00, 5.37412033e+00, 3.03930250e+00, 4.47814021e+00,
3.44894920e+00, 2.34977267e+00, 1.00274187e+00, 4.78782641e+00,
5.00688524e+00, 5.42917704e+00, 8.74264437e+00, 1.08546136e+01,
1.26508115e+01, 1.38369640e+01, 1.22559499e+01, 1.35193366e+01,
1.45558898e+01, 1.98940529e+01, 1.56230795e+01, 1.44094233e+01,
1.72861607e+01, 1.33902066e+01, 1.58302603e+01, 1.70575462e+01,
1.46093763e+01, 1.56669888e+01, 1.61703988e+01, 1.75911871e+01,
1.77323057e+01, 1.69723792e+01, 1.89714998e+01, 2.14827767e+01,
1.71719057e+01, 2.01649086e+01, 2.02643282e+01, 2.23352866e+01,
2.45251186e+01, 2.75240294e+01, 2.67906918e+01, 2.34239784e+01,
2.47779506e+01, 2.41075077e+01, 2.38932096e+01, 2.37542881e+01,
2.41000283e+01, 2.73505371e+01, 2.64419694e+01, 2.52751711e+01,
2.80470395e+01, 2.57457760e+01, 2.59011799e+01, 2.48515667e+01,
2.17817917e+01, 2.50912782e+01, 2.57779044e+01, 2.83524943e+01,
2.57348726e+01, 2.64355250e+01, 2.59956499e+01, 2.67754598e+01,
2.15529923e+01, 2.09129092e+01, 2.19243167e+01, 2.19974779e+01,
2.19960943e+01, 2.21024865e+01, 2.52887950e+01, 2.40804220e+01,
2.27381049e+01, 2.00207687e+01, 2.19041866e+01, 2.49407407e+01,
2.27611733e+01, 2.64211439e+01, 2.79037728e+01, 2.62650187e+01,
2.95082709e+01, 2.63434330e+01, 2.69195724e+01, 2.72645202e+01,
2.42261173e+01, 2.61695505e+01, 2.38557180e+01, 2.48194219e+01,
2.40467891e+01, 2.76022611e+01, 2.42635759e+01, 2.30312713e+01,
1.95100536e+01, 2.10866278e+01, 2.27955861e+01, 2.51106600e+01,
2.54197451e+01, 2.32745902e+01, 3.19060151e+01, 3.00306215e+01,
3.11005590e+01, 3.29650208e+01, 3.57058316e+01, 3.67591464e+01,
3.59652717e+01, 3.86935166e+01, 3.57911075e+01, 3.79579306e+01,
3.84987982e+01, 3.56516264e+01, 3.08808791e+01, 3.14919873e+01,
3.37580507e+01, 3.67153048e+01, 3.71685145e+01, 3.87817383e+01,
3.96963411e+01, 3.83564527e+01, 3.93413077e+01, 3.75321922e+01,
4.15746594e+01, 4.36385157e+01, 3.98425106e+01, 4.15453595e+01,
3.94474951e+01, 3.80429200e+01, 3.58932846e+01, 3.55484494e+01,
3.57650208e+01, 3.55469244e+01, 3.13031299e+01, 3.35075406e+01,
3.30004314e+01, 3.33909262e+01, 3.38864614e+01, 3.51344117e+01,
3.47694398e+01, 3.61716779e+01, 3.64870273e+01, 3.60862257e+01,
3.77519484e+01, 3.93528421e+01, 3.94286091e+01, 4.03535430e+01,
3.91922900e+01, 3.71025976e+01, 3.22791905e+01, 3.18571947e+01,
3.00682943e+01, 3.15075832e+01, 3.11689635e+01, 3.36183250e+01,
3.36876822e+01, 3.54169048e+01, 3.52179400e+01, 3.62088136e+01,
3.30185438e+01, 4.01510887e+01, 3.64507373e+01, 3.77575895e+01,
3.63044233e+01, 3.64477644e+01, 3.68859079e+01, 3.79617659e+01,
3.77345991e+01, 3.73876702e+01, 4.00806336e+01, 4.19850377e+01,
4.11555981e+01, 4.42295454e+01, 4.16191062e+01, 3.84851625e+01,
3.89853669e+01, 4.07976633e+01, 4.27051670e+01, 4.00062468e+01,
4.30978454e+01, 4.23601404e+01, 4.41922518e+01, 4.39965820e+01,
4.67794112e+01, 5.02978572e+01, 4.56351465e+01, 4.48239990e+01,
4.05876623e+01, 4.14133895e+01, 4.14028065e+01, 4.18386992e+01,
4.13414260e+01, 4.11379412e+01, 4.11907486e+01, 4.09471608e+01,
3.95399716e+01, 3.69100236e+01, 3.66860475e+01, 3.79205085e+01,
3.34496798e+01, 3.78740709e+01, 3.47468014e+01, 3.24855140e+01,
3.40201234e+01, 3.61644552e+01, 3.63968850e+01, 3.42714727e+01,
3.41741941e+01, 3.11705640e+01, 2.83597819e+01, 3.07470347e+01,
2.56121553e+01, 2.59008892e+01, 2.57916997e+01, 2.88839503e+01,
2.96389389e+01, 2.76744072e+01, 2.61646143e+01, 2.57337329e+01,
2.47063686e+01, 2.58303525e+01, 2.47858266e+01, 2.87182967e+01,
3.04326208e+01, 2.94320397e+01, 3.07871136e+01, 2.99429604e+01,
2.81048675e+01, 2.87370975e+01, 2.96210477e+01, 2.75839731e+01,
2.91019014e+01, 2.92451746e+01, 2.84477817e+01, 2.91807574e+01,
2.92471887e+01, 3.04269667e+01, 3.15569941e+01, 3.38214070e+01,
3.40074353e+01, 3.16579505e+01, 3.38981861e+01, 3.28793324e+01,
3.52557179e+01, 3.39301843e+01, 3.63010869e+01, 4.05274543e+01,
4.07032346e+01, 3.85513408e+01, 4.02953088e+01, 3.70458613e+01,
3.72419718e+01, 3.70039193e+01, 3.77764164e+01, 3.56303451e+01,
3.62959585e+01, 3.78946072e+01, 3.62709061e+01, 3.76634104e+01,
3.91337749e+01, 3.84004126e+01, 3.81241726e+01, 4.15581530e+01,
4.01142431e+01, 3.63541133e+01, 3.86243449e+01, 4.01430134e+01,
3.92610382e+01, 3.59249212e+01, 3.73114540e+01, 3.78569320e+01,
3.84827639e+01, 3.73911966e+01, 3.60235311e+01, 3.55317825e+01,
3.79022022e+01, 3.80714924e+01, 3.83597586e+01, 3.78601350e+01,
3.73587552e+01, 3.60520322e+01, 3.63999056e+01, 3.57917551e+01,
3.61548360e+01, 3.50228969e+01, 3.56458695e+01, 3.42961012e+01,
3.58128162e+01, 3.52237851e+01, 3.47537293e+01, 3.68640322e+01,
3.41322362e+01, 3.66964622e+01, 3.34686593e+01, 3.44682303e+01,
3.45441304e+01, 3.58217338e+01, 3.58910303e+01, 3.73505089e+01,
3.77782400e+01, 3.46415997e+01, 3.58564269e+01, 3.53247241e+01,
3.36910112e+01, 3.27083918e+01, 3.13210632e+01, 2.77079547e+01,
3.00823366e+01, 2.99031872e+01, 2.79379527e+01, 2.74939928e+01,
2.75225223e+01, 2.82877226e+01, 2.83256334e+01, 2.83580797e+01,
3.04150303e+01, 2.66598925e+01, 2.91794583e+01, 2.83726774e+01,
2.51347763e+01, 2.25963168e+01, 2.18457418e+01, 2.42479872e+01,
2.67623007e+01, 2.25385780e+01, 2.73269381e+01, 2.59379632e+01,
2.10230150e+01, 2.07278808e+01, 2.00192286e+01, 2.01415859e+01,
2.17867287e+01, 2.43700543e+01, 2.45504939e+01, 2.47710967e+01,
2.38890926e+01, 2.54430988e+01, 2.19310121e+01, 2.01861465e+01,
2.38134047e+01, 2.19633916e+01, 2.16153891e+01, 2.45069550e+01,
2.04463694e+01, 1.90655406e+01, 2.09490496e+01, 2.29965558e+01,
2.57579529e+01, 2.54823802e+01, 2.56986140e+01, 2.50657061e+01,
2.61572543e+01, 2.35304540e+01, 2.03513795e+01, 2.47338959e+01,
2.85361200e+01, 2.67531185e+01, 2.53851547e+01, 2.93807397e+01,
2.93031483e+01, 3.07399296e+01, 2.82640182e+01, 2.87514656e+01,
2.70374233e+01, 2.65811547e+01, 2.61418710e+01, 2.58550020e+01,
2.71578776e+01, 2.64238308e+01, 2.84622023e+01, 2.79479797e+01,
2.54783638e+01, 2.60418242e+01, 2.79443392e+01, 3.07094261e+01,
2.99400838e+01, 2.82979800e+01, 3.07511641e+01, 3.33209394e+01,
3.33287246e+01, 3.74255324e+01, 3.66325921e+01, 3.67588263e+01,
3.43328954e+01, 3.80575146e+01, 3.61886716e+01, 3.53697709e+01,
3.52197474e+01, 3.59416153e+01, 3.71610542e+01, 3.50041339e+01,
3.61407066e+01, 3.55777199e+01, 3.57989925e+01, 3.61486268e+01,
3.79354584e+01, 3.87753346e+01, 4.04170961e+01, 3.85544136e+01,
3.75831990e+01, 3.90366313e+01, 3.49632852e+01, 3.27692015e+01,
3.66075790e+01, 3.82664250e+01, 4.10635836e+01, 4.01232858e+01,
3.82659471e+01, 3.86222563e+01, 3.95396663e+01, 4.07177845e+01,
4.16468646e+01, 4.14681492e+01, 3.87589157e+01, 4.25642139e+01,
4.28169508e+01, 4.05391734e+01, 3.96425358e+01, 3.65827556e+01]
}
25 changes: 25 additions & 0 deletions examples/stan/601-hmm.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
data {
int<lower=1> N;
vector[N] observations;
}

parameters {
real log_sigma_transition;
vector[N] latents;
}

transformed parameters {
real<lower=0> sigma_transition = exp(log_sigma_transition);
}

model {
log_sigma_transition ~ normal(1,1);

latents[1] ~ normal(2,0.5);
observations[1] ~ normal(latents[1],1);

for (t in 2:N) {
latents[t] ~ normal(latents[t-1], sigma_transition);
observations[t] ~ normal(latents[t],1);
}
}
3 changes: 2 additions & 1 deletion src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ using StaticArraysCore
using Statistics
using StatsBase
using ZipFile
using DataStructures

import Base: Forward, @kwdef, show, print, merge, keys
import Base.Threads.@threads
Expand Down Expand Up @@ -68,7 +69,7 @@ export pigeons, Inputs, PT,
# getting information out of an execution:
stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample,
# variational references:
GaussianReference,
GaussianReference, TreeReference, DenseGaussianReference,
# samplers
SliceSampler, AutoMALA, Compose, AAPS, MALA, Mix

Expand Down
2 changes: 1 addition & 1 deletion src/explorers/GradientBasedSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ function gradient_based_sampler_recorders!(recorders, explorer::GradientBasedSam
push!(recorders, buffers)
push!(recorders, Pigeons.ad_buffers)
if hasproperty(explorer, :preconditioner) && explorer.preconditioner isa AdaptedDiagonalPreconditioner
push!(recorders, _transformed_online) # for mass matrix adaptation
push!(recorders, _transformed_online_full) # for mass matrix adaptation
end
end
2 changes: 2 additions & 0 deletions src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ include("targets/DistributionLogPotential.jl")
include("pt/checks.jl")
include("explorers/BufferedAD.jl")
include("variational/GaussianReference.jl")
include("variational/TreeReference.jl")
include("variational/DenseGaussianReference.jl")
include("variational/VariationalReference.jl")
include("paths/ScaledPrecisionNormalPath.jl")
include("explorers/invariance_test.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/pt/pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ function explore!(pt, replica, explorer)
# for the online stats, we ignore pt.inputs.extended_traces
# because the recorders do not support grouping by chains
@record_if_requested!(replica.recorders, :online, extract_sample(replica.state, log_potential, pt.inputs.extractor))
@record_if_requested!(replica.recorders, :_transformed_online, replica.state)
@record_if_requested!(replica.recorders, :_transformed_online_full, replica.state)
end
if pt.inputs.extended_traces || is_target(pt.shared.tempering.swap_graphs, replica.chain)
@record_if_requested!(
Expand Down
19 changes: 15 additions & 4 deletions src/recorders/OnlineStateRecorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
See [`online()`](@ref).
"""
@kwdef struct OnlineStateRecorder
full::Bool = false
stats::Dict{Pair{Symbol, Type}, Any} = Dict{Pair{Symbol, Type}, Any}()
end

OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(copy(from_another.stats))
OnlineStateRecorder(from_another::OnlineStateRecorder) = OnlineStateRecorder(from_another.full, copy(from_another.stats))

"""
$SIGNATURES
Expand All @@ -30,9 +31,12 @@ get_transformed_statistic(reduced_recorders, variable_name::Symbol, t::Type{T})
get_statistic(reduced_recorders, variable_name, t, false)

function get_statistic(reduced_recorders, variable_name::Symbol, ::Type{T}, original_param = true) where {T}
recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online
recorder = original_param ? reduced_recorders.online : reduced_recorders._transformed_online_full
key = Pair(variable_name, T)
v = value(recorder.stats[key])
v = value(recorder.stats[key])
if T==CovMatrix
return v
end
return value.(v)
end

Expand Down Expand Up @@ -68,7 +72,7 @@ recorded_continuous_variables(state) = continuous_variables(state)
`OnlineStat` types to be computed when the [`online()`]
recorder is enabled.
"""
const registered_online_types = [Mean, Variance]
const registered_online_types = [Mean, Variance, CovMatrix]

"""
$SIGNATURES
Expand Down Expand Up @@ -101,6 +105,13 @@ initialize_online_state_recorder!(stats, state) =
initialize_online_state_recorder!(stats, state, stat_type)
end

initialize_online_state_recorder!(stats, state, ::Type{CovMatrix}) =
for name in recorded_continuous_variables(state)
var = variable(state, name)
key = Pair(name, CovMatrix)
stats[key] = CovMatrix()
end

initialize_online_state_recorder!(stats, state, ::Type{T}) where {T} =
for name in recorded_continuous_variables(state)
var = variable(state, name)
Expand Down
3 changes: 2 additions & 1 deletion src/recorders/recorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ transformed to be defined on an unconstrained space.
This is used internally by [`explorer`](@ref)'s for adaptation purposes
(in particular, pre-conditioning and variational references).
"""
@provides recorder _transformed_online() = OnlineStateRecorder()
@provides recorder _transformed_online() = OnlineStateRecorder()
@provides recorder _transformed_online_full() = OnlineStateRecorder(full=true)

"""
Restart and round-trip counts.
Expand Down
Loading
Loading