Skip to content

Commit 6469ddf

Browse files
committed
fix add missing files
1 parent b99f0cd commit 6469ddf

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

src/mixedad_logdensity.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
"""
3+
MixedADLogDensityProblem(problem)
4+
5+
A `LogDensityProblem` wrapper for mixing AD frameworks.
6+
Whenever the outer AD framework attempts to differentiate through `logdensity(problem)`
7+
the pullback calls `logdensity_and_gradient`, which invokes the inner AD framework.
8+
"""
9+
struct MixedADLogDensityProblem{Problem}
10+
problem::Problem
11+
end
12+
13+
function LogDensityProblems.dimension(mixedad_prob::MixedADLogDensityProblem)
14+
return LogDensityProblems.dimension(mixedad_prob.problem)
15+
end
16+
17+
function LogDensityProblems.logdensity(
18+
mixedad_prob::MixedADLogDensityProblem, x::AbstractArray
19+
)
20+
return LogDensityProblems.logdensity(mixedad_prob.problem, x)
21+
end
22+
23+
function LogDensityProblems.logdensity_and_gradient(
24+
mixedad_prob::MixedADLogDensityProblem, x::AbstractArray
25+
)
26+
return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x)
27+
end
28+
29+
function LogDensityProblems.logdensity_gradient_and_hessian(
30+
mixedad_prob::MixedADLogDensityProblem, x::AbstractArray
31+
)
32+
return LogDensityProblems.logdensity_gradient_and_hessian(mixedad_prob.problem, x)
33+
end
34+
35+
function LogDensityProblems.capabilities(
36+
::Type{MixedADLogDensityProblem{Prob}}
37+
) where {Prob}
38+
return LogDensityProblems.capabilities(Prob)
39+
end
40+
41+
function ChainRulesCore.rrule(
42+
::typeof(LogDensityProblems.logdensity),
43+
mixedad_prob::MixedADLogDensityProblem,
44+
x::AbstractArray,
45+
)
46+
ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x)
47+
function logdensity_pullback(∂y)
48+
∂x = @thunk(∂y' * ∇ℓπ)
49+
return NoTangent(), NoTangent(), ∂x
50+
end
51+
return ℓπ, logdensity_pullback
52+
end

test/general/mixedad_logdensity.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
2+
struct MixedADTestModel end
3+
4+
function LogDensityProblems.logdensity(::MixedADTestModel, θ)
5+
return Float64(ℯ)
6+
end
7+
8+
function LogDensityProblems.dimension(::MixedADTestModel)
9+
return 3
10+
end
11+
12+
function LogDensityProblems.capabilities(::Type{<:MixedADTestModel})
13+
return LogDensityProblems.LogDensityOrder{2}()
14+
end
15+
16+
function LogDensityProblems.logdensity_and_gradient(::MixedADTestModel, θ)
17+
return (Float64(ℯ), [1.0, 2.0, 3.0])
18+
end
19+
20+
function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, θ)
21+
return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0])
22+
end
23+
24+
@testset "interface MixedADLogDensityProblem" begin
25+
model = MixedADTestModel()
26+
model_ad = AdvancedVI.MixedADLogDensityProblem(model)
27+
28+
d = 3
29+
x = ones(Float64, d)
30+
31+
@test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad)
32+
@test LogDensityProblems.capabilities(typeof(model)) ==
33+
LogDensityProblems.capabilities(typeof(model_ad))
34+
@test last(LogDensityProblems.logdensity(model, x))
35+
last(LogDensityProblems.logdensity(model_ad, x))
36+
@test last(LogDensityProblems.logdensity_and_gradient(model, x))
37+
last(LogDensityProblems.logdensity_and_gradient(model_ad, x))
38+
@test last(LogDensityProblems.logdensity_gradient_and_hessian(model, x))
39+
last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x))
40+
end

0 commit comments

Comments
 (0)