diff --git a/Project.toml b/Project.toml index 178e4c3a..dd32ed55 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,10 @@ uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" version = "0.7.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @@ -17,13 +19,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" [extensions] -AdvancedHMCADTypesExt = "ADTypes" AdvancedHMCCUDAExt = "CUDA" AdvancedHMCMCMCChainsExt = "MCMCChains" AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" @@ -33,6 +33,7 @@ ADTypes = "1" AbstractMCMC = "5.6" ArgCheck = "1, 2" CUDA = "3, 4, 5" +DifferentiationInterface = "0.6.49" DocStringExtensions = "0.8, 0.9" LinearAlgebra = "<0.1, 1" LogDensityProblems = "2" diff --git a/ext/AdvancedHMCADTypesExt.jl b/ext/AdvancedHMCADTypesExt.jl deleted file mode 100644 index 09159726..00000000 --- a/ext/AdvancedHMCADTypesExt.jl +++ /dev/null @@ -1,24 +0,0 @@ -module AdvancedHMCADTypesExt - -using AdvancedHMC: - AbstractMetric, LogDensityModel, Hamiltonian, LogDensityProblems, LogDensityProblemsAD -using ADTypes: AbstractADType - -function Hamiltonian( - metric::AbstractMetric, ℓπ::LogDensityModel, kind::AbstractADType; kwargs... -) - return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...) -end -function Hamiltonian(metric::AbstractMetric, ℓπ, kind::AbstractADType; kwargs...) - if LogDensityProblems.capabilities(ℓπ) === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - ℓ = LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...) - return Hamiltonian(metric, ℓ) -end - -end diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 51cf5f57..05548c7b 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -1,5 +1,6 @@ module AdvancedHMC +using ADTypes: AbstractADType, AutoForwardDiff using Statistics: mean, var, middle using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling @@ -21,7 +22,10 @@ using AbstractMCMC: AbstractMCMC, LogDensityModel import StatsBase: sample +import DifferentiationInterface: DifferentiationInterface as DI + const DEFAULT_FLOAT_TYPE = typeof(float(0)) +const DEFAULT_ADTYPE = AutoForwardDiff() include("utilities.jl") @@ -134,48 +138,32 @@ function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...) return Hamiltonian(metric, ℓ.logdensity; kwargs...) end function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...) - cap = LogDensityProblems.capabilities(ℓ) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - # Check if we're capable of computing gradients. - ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}() - # In this case ℓ does not support evaluation of the gradient of the log density function - # We use ForwardDiff to compute the gradient - LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...) - else - # In this case ℓ already supports evaluation of the gradient of the log density function - ℓ - end - return Hamiltonian( - metric, - Base.Fix1(LogDensityProblems.logdensity, ℓπ), - Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ), - ) + return Hamiltonian(metric, ℓ, DEFAULT_ADTYPE; kwargs...) end ## With explicit AD specification function Hamiltonian( - metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs... + metric::AbstractMetric, ℓπ::LogDensityModel, adtype::AbstractADType; kwargs... ) - return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...) + return Hamiltonian(metric, ℓπ.logdensity, adtype; kwargs...) end -function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val,Module}; kwargs...) - if LogDensityProblems.capabilities(ℓπ) === nothing +function Hamiltonian(metric::AbstractMetric, ℓπ, adtype::AbstractADType; kwargs...) + cap = LogDensityProblems.capabilities(ℓπ) + if cap === nothing throw( ArgumentError( "The log density function does not support the LogDensityProblems.jl interface", ), ) end - ℓ = LogDensityProblemsAD.ADgradient( - kind isa Val ? kind : Val(Symbol(kind)), ℓπ; kwargs... - ) - return Hamiltonian(metric, ℓ) + _logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓπ) + _logdensity_and_gradient = if cap === LogDensityProblems.LogDensityOrder{0}() + # In this case ℓπ does not support evaluation of the gradient of the log density function + x -> DI.value_and_gradient(_logdensity, adtype, x) + else + Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ) + end + return Hamiltonian(metric, _logdensity, _logdensity_and_gradient) end ### Init diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index da448f78..11e220e2 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -1,4 +1,4 @@ -using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC +using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, ADTypes using Statistics: mean @testset "AbstractMCMC w/ gdemo" begin @@ -18,7 +18,7 @@ using Statistics: mean custom = HMCSampler(κ, metric, adaptor) model = AdvancedHMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo) + LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ℓπ_gdemo) ) @testset "getparams and setparams!!" begin diff --git a/test/adaptation.jl b/test/adaptation.jl index 346423ea..1808d493 100644 --- a/test/adaptation.jl +++ b/test/adaptation.jl @@ -1,4 +1,4 @@ -using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff +using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff, ADTypes using AdvancedHMC.Adaptation: WelfordVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset! @@ -9,7 +9,7 @@ function runnuts(ℓπ, metric; n_samples=10_000) rng = MersenneTwister(0) nuts = NUTS(0.8) - h = Hamiltonian(metric, ℓπ, ForwardDiff) + h = Hamiltonian(metric, ℓπ, AutoForwardDiff()) step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init) integrator = AdvancedHMC.make_integrator(nuts, step_size) κ = AdvancedHMC.make_kernel(nuts, integrator) diff --git a/test/contrib.jl b/test/contrib.jl index 38a5023a..92e95af3 100644 --- a/test/contrib.jl +++ b/test/contrib.jl @@ -1,11 +1,11 @@ -using ReTest, AdvancedHMC, ForwardDiff, Zygote +using ReTest, AdvancedHMC, ForwardDiff, Zygote, ADTypes @testset "contrib" begin @testset "ad" begin metric = UnitEuclideanMetric(D) h_hand = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) - h_forwarddiff = Hamiltonian(metric, ℓπ, ForwardDiff) - h_zygote = Hamiltonian(metric, ℓπ, Zygote) + h_forwarddiff = Hamiltonian(metric, ℓπ, AutoForwardDiff()) + h_zygote = Hamiltonian(metric, ℓπ, AutoZygote()) for x in [rand(D), rand(D, 10)] v_hand, g_hand = h_hand.∂ℓπ∂θ(x) v_forwarddiff, g_forwarddiff = h_forwarddiff.∂ℓπ∂θ(x) diff --git a/test/demo.jl b/test/demo.jl index 5fe3cf08..35dc319a 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -1,5 +1,5 @@ using ReTest -using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC +using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC, ADTypes using LinearAlgebra, ADTypes @testset "Demo" begin @@ -23,7 +23,7 @@ using LinearAlgebra, ADTypes # Define a Hamiltonian system metric = DiagEuclideanMetric(D) - hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff) + hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff()) # Define a leapfrog solver, with initial step size chosen heuristically initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) @@ -76,7 +76,7 @@ end metric = DiagEuclideanMetric(D) # choose AD framework or provide a function manually - hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff); x=p1) + hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff(); x=p1) # Define a leapfrog solver, with initial step size chosen heuristically initial_ϵ = find_good_stepsize(hamiltonian, p1) diff --git a/test/integrator.jl b/test/integrator.jl index b9eb1407..854a724a 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -1,4 +1,4 @@ -using ReTest, Random, AdvancedHMC, ForwardDiff +using ReTest, Random, AdvancedHMC, ForwardDiff, ADTypes using OrdinaryDiffEq using LinearAlgebra: dot @@ -120,7 +120,7 @@ using Statistics: mean ϵ = 0.01 for lf in [Leapfrog(ϵ), DiffEqIntegrator(ϵ, VerletLeapfrog())] q_init = randn(1) - h = Hamiltonian(UnitEuclideanMetric(1), negU, ForwardDiff) + h = Hamiltonian(UnitEuclideanMetric(1), negU, AutoForwardDiff()) p_init = AdvancedHMC.rand_momentum( Random.default_rng(), h.metric, h.kinetic, q_init ) diff --git a/test/mcmcchains.jl b/test/mcmcchains.jl index 4d5526ae..b50e59db 100644 --- a/test/mcmcchains.jl +++ b/test/mcmcchains.jl @@ -1,4 +1,4 @@ -using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, MCMCChains +using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, MCMCChains, ADTypes using Statistics: mean @testset "MCMCChains w/ gdemo" begin @@ -10,7 +10,7 @@ using Statistics: mean θ_init = randn(rng, 2) model = AdvancedHMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo) + LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ℓπ_gdemo) ) integrator = Leapfrog(1e-3) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) diff --git a/test/models.jl b/test/models.jl index 87c1b214..32c3a9b5 100644 --- a/test/models.jl +++ b/test/models.jl @@ -1,4 +1,4 @@ -using ReTest, Random, AdvancedHMC, ForwardDiff +using ReTest, Random, AdvancedHMC, ForwardDiff, ADTypes using Statistics: mean @testset "Models" begin @@ -11,7 +11,7 @@ using Statistics: mean θ_init = randn(rng, 2) metric = DiagEuclideanMetric(2) - h = Hamiltonian(metric, ℓπ_gdemo, ForwardDiff) + h = Hamiltonian(metric, ℓπ_gdemo, AutoForwardDiff()) integrator = Leapfrog(0.1) κ = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) adaptor = StanHMCAdaptor(