Skip to content

Commit b4d9ad8

Browse files
committed
fix add missing extensions
1 parent 4cebb35 commit b4d9ad8

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

ext/AdvancedVIEnzymeExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module AdvancedVIEnzymeExt
2+
3+
using AdvancedVI
4+
using LogDensityProblems
5+
using Enzyme
6+
7+
Enzyme.@import_rrule(
8+
typeof(LogDensityProblems.logdensity),
9+
AdvancedVI.MixedADLogDensityProblem,
10+
AbstractVector
11+
)
12+
13+
end

ext/AdvancedVIMooncakeExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module AdvancedVIMooncakeExt
2+
3+
using AdvancedVI
4+
using Base: IEEEFloat
5+
using LogDensityProblems
6+
using Mooncake
7+
8+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
9+
typeof(LogDensityProblems.logdensity),
10+
AdvancedVI.MixedADLogDensityProblem,
11+
Array{<:IEEEFloat,1},
12+
}
13+
14+
end

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module AdvancedVIReverseDiffExt
2+
3+
using AdvancedVI
4+
using LogDensityProblems
5+
using ReverseDiff
6+
7+
ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
8+
prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray
9+
)
10+
11+
end

0 commit comments

Comments
 (0)