diff --git a/Project.toml b/Project.toml index 83d89a3..9bb9ac1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedMH" uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" -version = "0.5.5" +version = "0.5.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/mh-core.jl b/src/mh-core.jl index cfeb24e..261dd9f 100644 --- a/src/mh-core.jl +++ b/src/mh-core.jl @@ -191,6 +191,31 @@ function AbstractMCMC.step( return transition, transition end +""" + is_symmetric_proposal(proposal)::Bool + +Implementing this for a custom proposal will allow `AbstractMCMC.step` to avoid +computing the "Hastings" part of the Metropolis-Hasting log acceptance +probability (if the proposal is indeed symmetric). By default, +`is_symmetric_proposal(proposal)` returns `false`. The user is responsible for +determining whether a custom proposal distribution is indeed symmetric. As +noted in `MetropolisHastings`, `proposal` is a `Proposal`, `NamedTuple` of +`Proposal`, or `Array{Proposal}` in the shape of your data. +""" +is_symmetric_proposal(proposal) = false + +# The following univariate random walk proposals are symmetric. +is_symmetric_proposal(::RandomWalkProposal{<:Normal}) = true +is_symmetric_proposal(::RandomWalkProposal{<:MvNormal}) = true +is_symmetric_proposal(::RandomWalkProposal{<:TDist}) = true +is_symmetric_proposal(::RandomWalkProposal{<:Cauchy}) = true + +# The following multivariate random walk proposals are symmetric. +is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Normal}}) = true +is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:MvNormal}}) = true +is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:TDist}}) = true +is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Cauchy}}) = true + # Define the other sampling steps. # Return a 2-tuple consisting of the next sample and the the next state. # In this case they are identical, and either a new proposal (if accepted) @@ -206,8 +231,12 @@ function AbstractMCMC.step( params = propose(rng, spl, model, params_prev) # Calculate the log acceptance probability. - logα = logdensity(model, params) - logdensity(model, params_prev) + - q(spl, params_prev, params) - q(spl, params, params_prev) + logα = logdensity(model, params) - logdensity(model, params_prev) + + # Compute Hastings portion of ratio if proposal is not symmetric. + if !is_symmetric_proposal(spl.proposal) + logα += q(spl, params_prev, params) - q(spl, params, params_prev) + end # Decide whether to return the previous params or the new one. if -Random.randexp(rng) < logα diff --git a/test/runtests.jl b/test/runtests.jl index 2a4a05d..49d0ce9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,8 @@ using Test using DiffResults using ForwardDiff +include("util.jl") + @testset "AdvancedMH" begin # Set a seed Random.seed!(1234) @@ -104,7 +106,31 @@ using ForwardDiff @test chain1[1].params == val end - + + @testset "is_symmetric_proposal" begin + # True distributions + d1 = Normal(5, .7) + + # Model definition. + m1 = DensityModel(x -> logpdf(d1, x)) + + # Set up the proposal (StandardNormal is a custom distribution in "util.jl"). + p1 = RandomWalkProposal(StandardNormal()) + + # Implement `is_symmetric_proposal` for StandardNormal random walk proposal. + AdvancedMH.is_symmetric_proposal(::RandomWalkProposal{<:StandardNormal}) = true + + # Make sure `is_symmetric_proposal` behaves correctly. + @test AdvancedMH.is_symmetric_proposal(p1) + + # Sample from the posterior with initial parameters. + chain1 = sample(m1, MetropolisHastings(p1), 100000; + chain_type=StructArray, param_names=["x"]) + + @test mean(chain1.x) ≈ mean(d1) atol=0.05 + @test std(chain1.x) ≈ std(d1) atol=0.05 + end + @testset "MALA" begin # Set up the sampler. diff --git a/test/util.jl b/test/util.jl new file mode 100644 index 0000000..a2e0353 --- /dev/null +++ b/test/util.jl @@ -0,0 +1,4 @@ +# Define a (custom) Standard Normal distribution, for illustrative puspose. +struct StandardNormal <: Distributions.ContinuousUnivariateDistribution end +Distributions.logpdf(::StandardNormal, x::Real) = -(x ^ 2 + log(2 * pi)) / 2 +Distributions.rand(rng::AbstractRNG, ::StandardNormal) = randn(Random.GLOBAL_RNG)