diff --git a/Project.toml b/Project.toml index ea41cd03..178e4c3a 100644 --- a/Project.toml +++ b/Project.toml @@ -17,34 +17,38 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" [extensions] +AdvancedHMCADTypesExt = "ADTypes" AdvancedHMCCUDAExt = "CUDA" AdvancedHMCMCMCChainsExt = "MCMCChains" AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" [compat] +ADTypes = "1" AbstractMCMC = "5.6" ArgCheck = "1, 2" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" +LinearAlgebra = "<0.1, 1" LogDensityProblems = "2" LogDensityProblemsAD = "1" MCMCChains = "5, 6" OrdinaryDiffEq = "6" ProgressMeter = "1" +Random = "<0.1, 1" Setfield = "0.7, 0.8, 1" Statistics = "1.6" StatsBase = "0.31, 0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -LinearAlgebra = "<0.1, 1" -Random = "<0.1, 1" julia = "1.10" [extras] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" diff --git a/ext/AdvancedHMCADTypesExt.jl b/ext/AdvancedHMCADTypesExt.jl new file mode 100644 index 00000000..09159726 --- /dev/null +++ b/ext/AdvancedHMCADTypesExt.jl @@ -0,0 +1,24 @@ +module AdvancedHMCADTypesExt + +using AdvancedHMC: + AbstractMetric, LogDensityModel, Hamiltonian, LogDensityProblems, LogDensityProblemsAD +using ADTypes: AbstractADType + +function Hamiltonian( + metric::AbstractMetric, ℓπ::LogDensityModel, kind::AbstractADType; kwargs... +) + return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...) +end +function Hamiltonian(metric::AbstractMetric, ℓπ, kind::AbstractADType; kwargs...) + if LogDensityProblems.capabilities(ℓπ) === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓ = LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...) + return Hamiltonian(metric, ℓ) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 144952ea..9d7ee347 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" diff --git a/test/demo.jl b/test/demo.jl index 5c900a74..5fe3cf08 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -1,6 +1,6 @@ using ReTest using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC -using LinearAlgebra +using LinearAlgebra, ADTypes @testset "Demo" begin # Define the target distribution using the `LogDensityProblem` interface @@ -105,3 +105,40 @@ end @test "μ" ∈ labels @test "σ" ∈ labels end + +@testset "ADTypes" begin + # Set the number of samples to draw and warmup iterations + n_samples, n_adapts = 2_000, 1_000 + initial_θ = rand(D) + # Define a Hamiltonian system + metric = DiagEuclideanMetric(D) + + hamiltonian_ldp = Hamiltonian(metric, ℓπ_gdemo, AutoForwardDiff()) + + model = AbstractMCMC.LogDensityModel(ℓπ_gdemo) + hamiltonian_ldm = Hamiltonian(metric, model, AutoForwardDiff()) + + for hamiltonian in (hamiltonian_ldp, hamiltonian_ldm) + initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) + integrator = Leapfrog(initial_ϵ) + + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + adaptor = StanHMCAdaptor( + MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator) + ) + + samples, stats = sample( + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, + verbose=false, + ) + + @test length(samples) == n_samples + @test length(stats) == n_samples + end +end