Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Graphs = "1"
HypothesisTests = "0.11"
Interpolations = "0.14, 0.15"
JSON = "0.21"
JuliaBUGS = "0.8, 0.9, 0.10"
JuliaBUGS = "0.10.2"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
LogExpFunctions = "0.3"
Expand Down
36 changes: 22 additions & 14 deletions examples/JuliaBUGS.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
function incomplete_count_data_model(;tau::Real=4)
function incomplete_count_data_model(; tau::Real = 4)
# Define the BUGS model
model_def = @bugs("model{
for (i in 1:n) {
r[i] ~ dbern(pr[i])
pr[i] <- ilogit(y[i] * alpha1 + alpha0)
y[i] ~ dpois(mu)
r[i] ~ dbern(pr[i])
pr[i] <- ilogit(y[i] * alpha1 + alpha0)
y[i] ~ dpois(mu)
}
mu ~ dgamma(1,1)
alpha0 ~ dnorm(0, 0.1)
alpha1 ~ dnorm(0, tau)
}",false,false
)
}", false, false)

# Associated data for the model
data = (
y = [
6,missing,missing,missing,missing,missing,missing,5,1,missing,1,missing,
missing,missing,2,missing,missing,0,missing,1,2,1,7,4,6,missing,missing,
missing,5,missing
],
r = [1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,1,1,1,0,0,0,1,0],
n = 30, tau = tau
6, missing, missing, missing, missing, missing, missing, 5, 1, missing, 1, missing,
missing, missing, 2, missing, missing, 0, missing, 1, 2, 1, 7, 4, 6, missing, missing,
missing, 5, missing
],
r = [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0],
n = 30,
tau = tau,
)
return compile(model_def, data)

# Compile and return a JuliaBUGS.BUGSModel in the original (constrained) space
# SliceSampler expects constrained targets for correct support handling
return JuliaBUGS.settrans(JuliaBUGS.compile(model_def, data))
end
incomplete_count_data(;kwargs...) = JuliaBUGSPath(incomplete_count_data_model(;kwargs...))

# Convenience wrapper returning a Pigeons path for the compiled model
incomplete_count_data(; kwargs...) = JuliaBUGSPath(incomplete_count_data_model(; kwargs...))
16 changes: 12 additions & 4 deletions ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ module PigeonsJuliaBUGSExt
using Pigeons
if isdefined(Base, :get_extension)
import JuliaBUGS
using AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using Bijectors # only need because we rewrite JuliaBUGS.getparams
using AbstractPPL # needed for AbstractPPL.get and parameter access
using Bijectors # needed for transformations in getparams
using DocStringExtensions
using SplittableRandoms: SplittableRandom, split
using Random
else
import ..JuliaBUGS
using ..AbstractPPL # only need because we rewrite JuliaBUGS.getparams
using ..Bijectors # only need because we rewrite JuliaBUGS.getparams
using ..AbstractPPL # needed for AbstractPPL.get and parameter access
using ..Bijectors # needed for transformations in getparams
using ..DocStringExtensions
using ..SplittableRandoms: SplittableRandom, split
using ..Random
Expand All @@ -21,4 +21,12 @@ include(joinpath(@__DIR__, "utils.jl"))
include(joinpath(@__DIR__, "interface.jl"))
include(joinpath(@__DIR__, "invariance_test.jl"))

# Provide a no-op explorer as default for JuliaBUGS targets when no explorer
# is explicitly requested by the user/tests. Other tests explicitly pass an
# explorer (e.g., SliceSampler), so this only affects ad-hoc runs.
struct NoOpExplorer end
Pigeons.step!(::NoOpExplorer, replica, shared) = nothing
Pigeons.explorer_recorder_builders(::NoOpExplorer) = []
Pigeons.default_explorer(::Pigeons.JuliaBUGSPath) = NoOpExplorer()

end
167 changes: 137 additions & 30 deletions ext/PigeonsJuliaBUGSExt/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@

# Initialization and iid sampling
function evaluate_and_initialize(model::JuliaBUGS.BUGSModel, rng::AbstractRNG)
new_env = first(JuliaBUGS.evaluate!!(rng, model)) # sample a new evaluation environment
new_env, _ = JuliaBUGS.Model.evaluate_with_rng!!(rng, model) # sample a new evaluation environment
return JuliaBUGS.initialize!(model, new_env) # set the private_model's environment to the newly created one
end

# used for both initializing and iid sampling
# Note: state is a flattened vector of the parameters
# Also, the vector is **concretely typed**. This means that if the evaluation
# environment contains floats and integers, the latter will be cast to float.
_sample_iid(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) =
getparams(evaluate_and_initialize(model, rng)) # flatten the unobserved parameters in the model's eval environment and return
# Draw a single prior sample and return flattened parameters
sample_params_from_prior(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) =
getparams(evaluate_and_initialize(model, rng))


# Note: JuliaBUGS.getparams creates a new vector on each call, so it is safe
# to call _sample_iid during initialization (**sequentially**, as done as of time
# of writing) for different Replicas (i.e., they won't share the same state).
Pigeons.initialization(target::JuliaBUGSPath, rng::AbstractRNG, _::Int64) =
_sample_iid(target.model, rng)
sample_params_from_prior(target.model, rng)

# target is already a Path
Pigeons.create_path(target::JuliaBUGSPath, ::Inputs) = target
Expand Down Expand Up @@ -58,40 +56,149 @@ function Pigeons.interpolate(path::JuliaBUGSPath, beta)
JuliaBUGSLogPotential(private_model, beta)
end

# log_potential evaluation
(log_potential::JuliaBUGSLogPotential)(flattened_values) =
try
log_prior, _, tempered_log_joint = last(
JuliaBUGS._tempered_evaluate!!(
log_potential.private_model,
flattened_values;
temperature=log_potential.beta
)
# log_potential evaluation for Vector state (legacy, but kept for compatibility)
function (log_potential::JuliaBUGSLogPotential)(flattened_values::AbstractVector)
try
# Evaluate at given values using JuliaBUGS 0.10 API
# evaluate_with_values!! returns (evaluation_env, log_densities_namedtuple)
_, log_densities = JuliaBUGS.Model.evaluate_with_values!!(
log_potential.private_model,
flattened_values;
temperature=log_potential.beta,
transformed=log_potential.private_model.transformed,
)
# avoid potential 0*Inf (= NaN)
# log_densities is a NamedTuple with fields: logprior, loglikelihood, tempered_logjoint
log_prior = log_densities.logprior
tempered_log_joint = log_densities.tempered_logjoint
# avoid potential 0*Inf (= NaN)
return iszero(log_potential.beta) ? log_prior : tempered_log_joint
catch e
(isa(e, DomainError) || isa(e, BoundsError)) && return -Inf
rethrow(e)
end
end

# iid sampling
function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shared)
replica.state = _sample_iid(log_potential.private_model, replica.rng)
# iid sampling - extract parameters as Vector to match initialization type
function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, ::Pigeons.Shared)
# Draw exactly from the reference (prior) when beta=0
# Reference chains are the only ones for which sample_iid! is invoked.
replica.state = sample_params_from_prior(log_potential.private_model, replica.rng)
end

# parameter names
Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) =
[(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density]
# parameter names for Vector state
Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) =
[(Symbol(string(vn)) for vn in JuliaBUGS.parameters(log_potential.private_model))..., :log_density]

# extract samples for Vector state
Pigeons.extract_sample(state::Vector, log_potential::JuliaBUGSLogPotential) =
vcat(state, log_potential(state))

# Parallelism invariance
Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) =
Pigeons._recursive_equal(a,b)
function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel
Pigeons._recursive_equal(a, b)
# just check the betas match, the model is already checked within path
Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) =
all(lp1.beta == lp2.beta for (lp1, lp2) in zip(a, b))

# BUGSModel-specific equality: compare only stable fields to avoid nondeterminism
# in evaluation caches and generated functions while preserving true model identity.
function Pigeons.recursive_equal(a::T, b) where {T<:JuliaBUGS.BUGSModel}
included = (:transformed, :model_def, :data)
excluded = Tuple(setdiff(fieldnames(T), included))
Pigeons._recursive_equal(a,b,excluded)
return Pigeons._recursive_equal(a, b, excluded)
end

#######################################
# Robustness for log-ratio evaluation
#######################################

# Numerically robust log-ratio for a vector of JuliaBUGS log-potentials.
#
# Rationale:
# - At support boundaries, the same state evaluated at two betas can give
# lp_num == -Inf and lp_den == -Inf (or both Inf). The default lp_num - lp_den
# becomes (-Inf) - (-Inf) or (Inf - Inf) which is NaN.
# - These NaNs would leak into swap statistics and PT adaptation, tripping the
# “no NaN log potentials” test and destabilizing schedules.
# Policy:
# - If both sides are infinite with the same sign, return 0.0 (ratio 1.0) to
# stay well-defined and avoid NaNs. Otherwise, compute the normal difference.
function Pigeons.log_unnormalized_ratio(
lps::AbstractVector{<:JuliaBUGSLogPotential},
numerator::Int,
denominator::Int,
state,
)
lp_num = lps[numerator](state)
lp_den = lps[denominator](state)
if (lp_num == -Inf && lp_den == -Inf) || (lp_num == Inf && lp_den == Inf)
# Indeterminate form; map to a neutral, finite value to avoid NaNs.
return 0.0
end
ans = lp_num - lp_den
if isnan(ans)
# If a NaN still arises, surface a clear error with context.
error("Got NaN log-unnormalized ratio; Dumping information:\n\tlp_num=$lp_num\n\tlp_den=$lp_den\n\tState=$state")
end
return ans
end

#######################################
# Robustness for swap acceptance
#######################################

# Robust swap acceptance for JuliaBUGS paths.
#
# Rationale:
# - Acceptance uses exp(stat1.log_ratio + stat2.log_ratio). With +Inf and -Inf
# from two chains, the sum becomes NaN (0/0) and breaks stats/adaptation.
# Policy:
# - Finite sum: standard min(1, exp(sum)).
# - NaN from opposite infinities: treat as 1.0 (neutral accept for 0/0 case).
# - ±Inf sum: accept if +Inf, reject if -Inf.
_is_opposite_infinities(a, b) = isinf(a) && isinf(b) && signbit(a) != signbit(b)

function _robust_acceptance(stat1, stat2)
s = stat1.log_ratio + stat2.log_ratio
if isfinite(s)
return min(1.0, exp(s))
elseif isnan(s)
# (+Inf) + (-Inf) → NaN; interpret as a neutral accept.
return _is_opposite_infinities(stat1.log_ratio, stat2.log_ratio) ? 1.0 : 0.0
else
# s is ±Inf
return s > 0 ? 1.0 : 0.0
end
end

# Recorder hook using the robust acceptance above to avoid NaNs in recorded
# statistics, while still recording the original log-ratios for diagnostics.
function Pigeons.record_swap_stats!(
pair_swapper::AbstractVector{<:JuliaBUGSLogPotential},
recorders,
chain1::Int,
stat1,
chain2::Int,
stat2,
)
acceptance_pr = _robust_acceptance(stat1, stat2)
key1 = (chain1, chain2)
key2 = (chain2, chain1)
Pigeons.@record_if_requested!(recorders, :swap_acceptance_pr, (key1, acceptance_pr))
Pigeons.@record_if_requested!(recorders, :log_sum_ratio, (key1, stat1.log_ratio))
Pigeons.@record_if_requested!(recorders, :log_sum_ratio, (key2, stat2.log_ratio))
end

# Swap decision mirroring the robust acceptance calculation. Identical to the
# default when the sum is finite; well-defined in indeterminate cases.
function Pigeons.swap_decision(
pair_swapper::AbstractVector{<:JuliaBUGSLogPotential},
chain1::Int,
stat1,
chain2::Int,
stat2,
)
acceptance_pr = _robust_acceptance(stat1, stat2)
uniform = chain1 < chain2 ? stat1.uniform : stat2.uniform
return uniform < acceptance_pr
end
# just check the betas match, the model is already checked within path
Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) =
all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b))
2 changes: 1 addition & 1 deletion ext/PigeonsJuliaBUGSExt/invariance_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function Pigeons.forward_sample_condition_and_explore(
# maybe condition the model using the sampled observations
conditioned_model = if length(condition_on) > 0
var_group = [JuliaBUGS.VarName{sym}() for sym in condition_on] # transform Symbols into VarNames
JuliaBUGS.condition(model, var_group)
JuliaBUGS.condition(model, var_group; regenerate_log_density=false)
else
model
end
Expand Down
Loading
Loading