Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Pigeons"
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31"
authors = ["Alexandre Bouchard-Côté <[email protected]>, Nikola Surjanovic <[email protected]>, Paul Tiede <[email protected]>, Trevor Campbell <[email protected]>, Miguel Biron-Lattes <[email protected]>, Saifuddin Syed <[email protected]>"]
version = "0.4.9"
version = "0.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
53 changes: 27 additions & 26 deletions src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ include("swap/VariationalDEO.jl")
include("submission/presets.jl")
include("submission/Submission.jl")
include("submission/Result.jl")
include("submission/MPISettings.jl")
include("schedules/Schedule.jl")
include("tempering/adaptation.jl")
include("schedules/discretize.jl")
include("tempering/adaptation.jl")
include("recorders/NonReproducible.jl")
include("recorders/LogSum.jl")
include("recorders/LogPotentialExtractor.jl")
Expand All @@ -25,43 +26,45 @@ include("pt/report.jl")
include("pt/plots.jl")
include("pt/Iterators.jl")
include("pt/Inputs.jl")
include("targets/StreamTarget.jl")
include("pt/Shared.jl")
include("swap/swap_graphs.jl")
include("pt/PT.jl")
include("tempering/NonReversiblePT.jl")
include("tempering/StabilizedPT.jl")
include("tempering/tempering.jl")
include("swap/swap_graph.jl")
include("submission/MPIProcesses.jl")
include("targets/BlangTarget.jl")
include("submission/MPISettings.jl")
include("submission/ChildProcess.jl")
include("targets/LazyTarget.jl")
include("submission/submission_utils.jl")
include("replicas/Replica.jl")
include("swap/pair_swapper.jl")
include("recorders/RoundTripRecorder.jl")
include("recorders/OnlineStateRecorder.jl")
include("recorders/RoundTripRecorder.jl")
include("recorders/recorder.jl")
include("pt/state.jl")
include("pt/process_sample.jl")
include("replicas/Replica.jl")
include("targets/DistributionLogPotential.jl")
include("submission/ChildProcess.jl")
include("submission/MPIProcesses.jl")
include("submission/submission_utils.jl")
include("swap/swap_graph.jl")
include("pt/pigeons.jl")
include("swap/swap_graphs.jl")
include("pt/checkpoint.jl")
include("pt/state.jl")
include("targets/LazyTarget.jl")
include("targets/StreamTarget.jl")
include("targets/BlangTarget.jl")
include("tempering/NonReversiblePT.jl")
include("tempering/StabilizedPT.jl")
include("pt/process_sample.jl")
include("swap/pair_swapper.jl")
include("tempering/tempering.jl")
include("paths/path.jl")
include("paths/JuliaBUGSPath.jl")
include("paths/InterpolatedLogPotential.jl")
include("targets/TuringLogPotential.jl")
include("targets/StanLogPotential.jl")
include("paths/InterpolatingPath.jl")
include("variational/variational.jl")
include("paths/MultiStepsInterpolatingPath.jl")
include("targets/StanLogPotential.jl")
include("targets/TuringLogPotential.jl")
include("mpi_utils/misc_mpi_utils.jl")
include("mpi_utils/LoadBalance.jl")
include("mpi_utils/Entangler.jl")
include("mpi_utils/PermutedDistributedArray.jl")
include("replicas/EntangledReplicas.jl")
include("recorders/recorders.jl")
include("swap/swap.jl")
include("replicas/replicas.jl")
include("recorders/recorders.jl")
include("log_potentials/log_potentials.jl")
include("log_potentials/log_potential.jl")
include("explorers/hamiltonian_dynamics.jl")
Expand All @@ -75,16 +78,14 @@ include("explorers/MALA.jl")
include("explorers/AutoMALA.jl")
include("explorers/Compose.jl")
include("explorers/Augmentation.jl")
include("targets/DistributionLogPotential.jl")
include("pt/checks.jl")
include("explorers/BufferedAD.jl")
include("variational/GaussianReference.jl")
include("variational/VariationalReference.jl")
include("paths/ScaledPrecisionNormalPath.jl")
include("explorers/invariance_test.jl")
include("targets/toy_mvn_target.jl")
include("variational/GaussianReference.jl")
include("pt/checks.jl")
include("variational/VariationalReference.jl")
include("explorers/AAPS.jl")
include("explorers/GradientBasedSampler.jl")
include("evidence/stepping_stone.jl")
include("api.jl")
include("paths/JuliaBUGSPath.jl")
23 changes: 23 additions & 0 deletions src/paths/InterpolatingPath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,26 @@
interpolate(::LinearInterpolator, ref_log_potential, target_log_potential, beta) =
(1.0 - beta) * ref_log_potential + beta * target_log_potential


"""
$SIGNATURES

[`InterpolatingPath`](@ref) is already a [`path`](@ref), so return it.
"""
create_path(target::InterpolatingPath, inputs::Inputs) = target_is_already_a_path(target, inputs)

Check warning on line 35 in src/paths/InterpolatingPath.jl

View check run for this annotation

Codecov / codecov/patch

src/paths/InterpolatingPath.jl#L35

Added line #L35 was not covered by tests

"""
$SIGNATURES

Utility to allow the user to specify a path of distribution in inputs.target.
In that case, we verify that inputs.reference was not specified.
"""
function target_is_already_a_path(target, inputs)
if !isnothing(inputs.reference)
error("Conflicting options: inputs.target is already a path, but an explicit inputs.reference was provided.")
end
return target
end

initialization(target::InterpolatingPath, rng::AbstractRNG, replica_index::Int64) =

Check warning on line 50 in src/paths/InterpolatingPath.jl

View check run for this annotation

Codecov / codecov/patch

src/paths/InterpolatingPath.jl#L50

Added line #L50 was not covered by tests
initialization(target.target, rng, replica_index)
54 changes: 54 additions & 0 deletions src/paths/MultiStepsInterpolatingPath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Paths based on interpolating an arbitrary list of log_potential's

@auto struct MultiStepsInterpolatingPath
knots
interpolator
end

"""
$SIGNATURES

Given a list of [`log_potential`](@ref) ("knots"),
return a [`path`](@ref) interpolating between them.

By default, the `interpolator` is a `LinearInterpolator`, i.e.
standard annealing.
"""
@provides path MultiStepsInterpolatingPath(knots) = MultiStepsInterpolatingPath(knots, LinearInterpolator())

function interpolate(path::MultiStepsInterpolatingPath, beta)

# The task here is to reduce the interpolation problem to a "segment" between
# two adjacent knots. For example if n_knots = 3, and beta = 1/3, then in the following
# we would grab the fist and second knots, and rescale beta to 2/3.

n_knots = length(path.knots)

# First, find the two end points of the segment
ref_index = ref_knot_index(beta, n_knots)
target_index = ref_index + 1

# Second, rescale beta to be interpolating between the segment's endpoints
delta = 1.0 / (n_knots - 1) # difference in original beta param between knots
baseline_beta = (ref_index - 1) * delta
@assert baseline_beta ≤ beta "$baseline_beta $beta "
rescaled_beta = (beta - baseline_beta) / delta

# The problem is now reduced to a segment, i.e. a path with only 2 knots:
segment = InterpolatingPath(path.knots[ref_index], path.knots[target_index], path.interpolator)

# We can now invoke the machinery constructed for the special case with 2 knots:
return InterpolatedLogPotential(segment, rescaled_beta)
end

ref_knot_index(beta, n_knots) = max(1, ceil(Int, beta * (n_knots - 1)))

"""
$SIGNATURES

[`MultiStepsInterpolatingPath`](@ref) is already a [`path`](@ref), so return it.
"""
create_path(target::MultiStepsInterpolatingPath, inputs::Inputs) = target_is_already_a_path(target, inputs)

initialization(target::MultiStepsInterpolatingPath, rng::AbstractRNG, replica_index::Int64) =
initialization(target.knots[end], rng, replica_index)
5 changes: 3 additions & 2 deletions src/paths/ScaledPrecisionNormalPath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ analytic_lognormalization(path::ScaledPrecisionNormalPath) =
"""
$SIGNATURES

In this case, the target is already a [`path`](@ref), so return it.
[`ScaledPrecisionNormalPath`](@ref) is already a [`path`](@ref), so return it.
"""
create_path(target::ScaledPrecisionNormalPath, inputs::Inputs) = target
create_path(target::ScaledPrecisionNormalPath, inputs::Inputs) = target_is_already_a_path(target, inputs)

4 changes: 4 additions & 0 deletions src/pt/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ function preflight_checks(inputs::Inputs)
To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
"""
end
if !isnothing(inputs.variational) && inputs.target isa MultiStepsInterpolatingPath
# would need to fit variational to closest knot instead of target
throw(ArgumentError("Variational inference not currently supported with MultiStepsInterpolatingPath"))
end
if mpi_active() && !inputs.checkpoint
@warn "To be able to call Pigeons.load() to retrieve samples in-memory, use pigeons(target = ..., checkpoint = true)"
end
Expand Down
7 changes: 5 additions & 2 deletions src/targets/toy_mvn_target.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ such that i.i.d. sampling is possible at all chains (via [`ToyExplorer`](@ref)).
"""
@provides target toy_mvn_target(dim::Int) = ScaledPrecisionNormalPath(dim)

initialization(target::ScaledPrecisionNormalPath, rng::AbstractRNG, _::Int64) =
randn(rng, target.dim) / sqrt(target.precision1)
initialization(target::ScaledPrecisionNormalLogPotential, _::AbstractRNG, _::Int64) =
zeros(target.dim)

initialization(target::ScaledPrecisionNormalPath, _::AbstractRNG, _::Int64) =
zeros(target.dim)

default_explorer(::ScaledPrecisionNormalPath) = ToyExplorer()

Expand Down
6 changes: 2 additions & 4 deletions src/variational/variational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ function update_path_if_needed(path, reduced_recorders, iterators, variational,
end
end

function update_path_variational(path, reduced_recorders, variational, state)
function update_path_variational(path::InterpolatingPath, reduced_recorders, variational, state)
update_reference!(reduced_recorders, variational, state)
path = InterpolatingPath(variational, path.target)
return path
return InterpolatingPath(variational, path.target, path.interpolator)
end

34 changes: 34 additions & 0 deletions test/supporting/rna-example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
function rna_example(_data_sizes = [0, typemax(Int)])
data_sizes = sort(_data_sizes)
@assert data_sizes[1] == 0
n_knots = length(data_sizes)
@assert n_knots > 1

example_dir = dirname(dirname(pathof(Pigeons))) * "/examples/"
data = readdlm(example_dir * "data/Ballnus_et_al_2017_M1a.csv", ',')
N = size(data, 1)
@assert data_sizes[end] == typemax(Int) || data_sizes[end] == N

# prior
# use a DistributionLogPotential based on the prior to enable iid sampling
# note: need to work on unconstrained_parameters. We use Bijectors.transformed
# to achieve this automatically, because their bijection for scalar Uniforms
# is the same as the one used in Stan (logit <-> logistic)
prior_ref = DistributionLogPotential(product_distribution(
transformed.([Uniform(-2,1), Uniform(-5,5), Uniform(-5,5), Uniform(-5,5), Uniform(-2,2)])
))

knots = (prior_ref, )

model_file = example_dir * "stan/mRNA.stan"
for knot_index in 2:n_knots
cur_N = min(N, data_sizes[knot_index])
@assert cur_N > 0
ts = data[1:cur_N,1]
ys = data[1:cur_N,2]
knot = StanLogPotential(model_file, Pigeons.json(; N = cur_N, ts, ys))
knots = (knots..., knot)
end

return Pigeons.MultiStepsInterpolatingPath(knots)
end
22 changes: 3 additions & 19 deletions test/test_DistributionLogPotential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Bijectors
using DelimitedFiles

include("supporting/mpi_test_utils.jl")
include("supporting/rna-example.jl")
n_mpis = set_n_mpis_to_one_on_windows(2)

@testset "DLP: Multivariate" begin
Expand Down Expand Up @@ -32,32 +33,15 @@ end

@static if !is_windows_in_CI()
@testset "DLP: Stan interface + iid sampling" begin
# recreate Fig. 6 in Ballnus et al. (2017)

# load data
dta_path = joinpath(dirname(@__DIR__), "examples", "data", "Ballnus_et_al_2017_M1a.csv")
dta = readdlm(dta_path, ',')
N = size(dta, 1)
ts = dta[:,1]
ys = dta[:,2]
# Example of non-standard interpolation: first prior to a subsampling of 10 observations, then to full

# create model target
model_file = joinpath(dirname(@__DIR__), "examples", "stan", "mRNA.stan")
mRNA_target = StanLogPotential(model_file, Pigeons.json(; N, ts, ys))

# use a DistributionLogPotential based on the prior to enable iid sampling
# note: need to work on unconstrained_parameters. We use Bijectors.transformed
# to achieve this automatically, because their bijection for scalar Uniforms
# is the same as the one used in Stan (logit <-> logistic)
prior_ref = DistributionLogPotential(product_distribution(
transformed.([Uniform(-2,1), Uniform(-5,5), Uniform(-5,5), Uniform(-5,5), Uniform(-2,2)])
))
mRNA_target = rna_example([0, 10, typemax(Int)])

# run
explorer = Compose(SliceSampler(), AutoMALA())
results = pigeons(
target = mRNA_target,
reference = prior_ref,
record = [round_trip; record_default()],
multithreaded = false,
n_chains = 2, # low to avoid slowing down CI; in reality, Λ ~ 6
Expand Down
26 changes: 24 additions & 2 deletions test/test_cumulative_barrier.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testset "Cumulative barrier" begin
target = toy_mvn_target(2)
pt = pigeons(; target, explorer = SliceSampler(), n_rounds = 15);
pt = pigeons(; target, n_rounds = 15);

truth = Pigeons.analytic_cumulativebarrier(target)
barriers = pt.shared.tempering.communication_barriers
Expand All @@ -9,4 +9,26 @@
for beta in 0.0:0.1:1.0
@test abs(estimated_cum(beta) - truth(beta)) < 0.01
end
end
end

@testset "Cumulative barrier multi knots" begin
target = toy_mvn_target(2)
pt_ref = pigeons(; target, n_rounds = 15)

# here we pick a middle point which happens to be in the linear path
# so that should keep the global barrier invariant in that special case

middle_point = 0.2
knots = map(beta -> Pigeons.interpolate(target, beta), [0.0, middle_point, 1.0])

multistep_target = Pigeons.MultiStepsInterpolatingPath(knots)
pt = pigeons(; target = multistep_target, n_rounds = 15);

@test abs(Pigeons.global_barrier(pt) - Pigeons.global_barrier(pt_ref)) < 0.01
@test abs(stepping_stone(pt) - stepping_stone(pt_ref)) < 0.001

@test_throws "Conflicting options" pigeons(; target = multistep_target, reference = Pigeons.interpolate(target, 0.1))

@test_throws "Variational inference not currently supported" pigeons(; target = multistep_target, variational = variational = GaussianReference())
end

Loading