Support Mixing AD Frameworks for LogDensityProblems and the objective (cleaned-up) #187
Conversation
AdvancedVI.jl documentation for PR #187 is available at:
Hi @wsmoses, it seems the Enzyme tests on this PR are failing. Does anything come to mind?

Nope, nothing that comes to mind. If you can file an issue with an MWE, I can take a look.
Benchmark Results

| Benchmark suite | Current: 1f5b495 | Previous: 630d212 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 4026781873 ns | 4080370618.5 ns | 0.99 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 1115607139 ns | 1150682904 ns | 0.97 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 1205063583 ns | 1241363224 ns | 0.97 |
| normal/RepGradELBO + STL/fullrank/Zygote | 3970155819.5 ns | 4019642435 ns | 0.99 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1597103126.5 ns | 1636522620 ns | 0.98 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 1280192193 ns | 1274526418 ns | 1.00 |
| normal/RepGradELBO/meanfield/Zygote | 2851261688 ns | 2847378761.5 ns | 1.00 |
| normal/RepGradELBO/meanfield/ReverseDiff | 765590505 ns | 771309994 ns | 0.99 |
| normal/RepGradELBO/meanfield/Mooncake | 1109370223 ns | 1091051726 ns | 1.02 |
| normal/RepGradELBO/fullrank/Zygote | 2857329571.5 ns | 2838345735 ns | 1.01 |
| normal/RepGradELBO/fullrank/ReverseDiff | 949336325.5 ns | 962637306 ns | 0.99 |
| normal/RepGradELBO/fullrank/Mooncake | 1151867635 ns | 1143764360 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 5687899208 ns | 5614419426 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 2438116302 ns | 2389714745 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/Mooncake | 4029199467 ns | 3869611785 ns | 1.04 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 5660492832 ns | 5485521005 ns | 1.03 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 3066241035.5 ns | 3016952886.5 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/fullrank/Mooncake | 4140026557 ns | 4004564598.5 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 4365927528.5 ns | 4229780243.5 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 2053732863 ns | 1998304814 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/Mooncake | 3841033310.5 ns | 3735322298.5 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 4467055087.5 ns | 4345273868 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 2333907144 ns | 2273026717 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/Mooncake | 4015987416.5 ns | 3896302067.5 ns | 1.03 |
This comment was automatically generated by a workflow using github-action-benchmark.
I'm happy with the proposed changes, but I will wait for @penelopeysm or @sunxd3 to comment on the software side of things.

@sunxd3 @penelopeysm could you take a look at this PR when you have available bandwidth? Thanks in advance!
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- The target `LogDensityProblem` must have a capability of at least `LogDensityProblems.LogDensityOrder{1}()`.
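For readers of this thread, a minimal sketch of a target meeting both requirements could look like the following (`MyTarget` and the choice of ForwardDiff as the inner gradient engine are illustrative assumptions, not part of this PR):

```julia
using LogDensityProblems, ForwardDiff

struct MyTarget end  # hypothetical target

LogDensityProblems.dimension(::MyTarget) = 2
LogDensityProblems.logdensity(::MyTarget, x) = -sum(abs2, x) / 2

# Declaring first-order capability promises that `logdensity_and_gradient`
# is implemented for this problem.
LogDensityProblems.capabilities(::Type{MyTarget}) =
    LogDensityProblems.LogDensityOrder{1}()

function LogDensityProblems.logdensity_and_gradient(prob::MyTarget, x)
    ℓ = LogDensityProblems.logdensity(prob, x)
    ∇ℓ = ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, prob), x)
    return ℓ, ∇ℓ
end
```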
Can't you still differentiate through it (albeit with a warning)? So technically `prob` only needs to have zeroth-order capabilities.
Also, I wonder if a warning might be a bit harsh. How important is this use-case / do you want to actively discourage people from using a single AD backend for the entire VI process? If the aim is just to let people know rather than to discourage, then I think an info message rather than a warning might be gentler.
> How important is this use-case / do you want to actively discourage people from using a single AD backend for the entire VI process?
I intend to make the Turing layer use mixed-AD backends, so it's the most important use case. In the future, I think we could introduce changes that don't allow using a single backend in an end-to-end manner; thus, I originally thought to simply disallow using a single backend. But we don't currently have hard reasons not to allow it, so I settled on a warning.
```julia
q_stop = restructure(params)
capability = LogDensityProblems.capabilities(Prob)
@assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme}
```
Would it be possible to make this a nicer error message?

Also, is there a reason why ForwardDiff is forbidden at this stage -- is it because you don't import the chain rule for `logdensity_and_gradient` for it? But in that case, you only need to forbid it in the `MixedADLogDensityProblem` branch. It looks to me that the old branch, where the AD backend just directly differentiates through `logdensity`, should still work with ForwardDiff. Am I misunderstanding?
> But in that case, you only need to forbid it in the `MixedADLogDensityProblem` branch. It looks to me that the old branch, where the AD backend just directly differentiates through `logdensity`, should still work with ForwardDiff.

That is true! Let me put a comment in the code clarifying that the backend doesn't matter if we don't rely on first-order differentiability.
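Something along these lines, perhaps (a sketch of what that branch and comment could look like; the variable names are assumed from the snippet above, not the PR's actual code):

```julia
if capability >= LogDensityProblems.LogDensityOrder{1}()
    # First-order capability: wrap the target so the custom rrule makes the
    # outer backend reuse `logdensity_and_gradient`. Only backends that can
    # import ChainRules rrules are supported on this path.
    prob = MixedADLogDensityProblem(prob)
else
    # Zeroth-order capability: the outer backend differentiates `logdensity`
    # directly, so the choice of backend doesn't matter here (ForwardDiff
    # included).
end
```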
> Would it be possible to make this a nicer error message?

I think the assert itself should be clear enough, since all we need to say is that `adtype` must be one of those types?
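If a friendlier message were wanted anyway, an explicit check is only a few lines (a sketch, not code from this PR):

```julia
if !(adtype isa Union{AutoReverseDiff,AutoZygote,AutoMooncake,AutoEnzyme})
    throw(ArgumentError(
        "Mixed-AD differentiation requires `adtype` to be one of " *
        "AutoReverseDiff, AutoZygote, AutoMooncake, or AutoEnzyme; " *
        "got $(typeof(adtype)).",
    ))
end
```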
src/mixedad_logdensity.jl (Outdated)
```julia
function LogDensityProblems.logdensity_and_gradient(
    mixedad_prob::MixedADLogDensityProblem, x::AbstractArray
)
    return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x)
end
```
Maybe I'm not super familiar with the codebase, but where / how does this method get called?

My understanding is that the call stack looks like this:

`optimize` (line 42 at 366c911) -> `step` -> `estimate_gradient!` -> `_value_and_gradient!` (`src/AdvancedVI.jl`, line 51 at 366c911) -> `DifferentiationInterface.value_and_gradient!`

But it's not clear to me that `DI.value_and_gradient!` calls `logdensity_and_gradient`. Does it? It seems to me that it would just differentiate through `logdensity` anyway.

In DynamicPPL it's the other way round: `logdensity_and_gradient` is implemented using `DI.value_and_gradient`.
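For concreteness, the DynamicPPL-style pattern is roughly the following (a sketch; `ADWrappedProblem` and its fields are hypothetical names, not DynamicPPL's actual types):

```julia
using DifferentiationInterface, LogDensityProblems

struct ADWrappedProblem{P,B}  # hypothetical wrapper
    problem::P
    adtype::B
end

# `logdensity_and_gradient` delegates to DI, instead of DI differentiating
# through `logdensity`.
function LogDensityProblems.logdensity_and_gradient(prob::ADWrappedProblem, x)
    return DifferentiationInterface.value_and_gradient(
        Base.Fix1(LogDensityProblems.logdensity, prob.problem), prob.adtype, x
    )
end
```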
I tested out the example given in docs/src/examples.md (https://github.com/TuringLang/AdvancedVI.jl/blob/366c9115edf5f9c675dca44f1a78e23854656215/docs/src/examples.md), changing the implementation of `logdensity_and_gradient` to

```diff
 function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
-    return (
-        LogDensityProblems.logdensity(model, θ),
-        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
-    )
+    error("is this function called?")
 end
```
and VI still ran perfectly fine, so my suspicion is that this isn't doing what you'd like it to do.
> Maybe I'm not super familiar with the codebase, but where / how does this method get called?

You are right, it isn't. Let me remove it.

> But it's not clear to me that `DI.value_and_gradient!` calls `logdensity_and_gradient`. Does it? It seems to me that it would just differentiate through `logdensity` anyway.

The intention of this PR is to hijack the `rrule` for `logdensity` to indeed call `logdensity_and_gradient`. This is done here. The extensions register the `rrule` for backends that don't automatically use `ChainRulesCore`.
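In other words, the mechanism is roughly this (a hedged sketch of the idea, not the PR's exact implementation):

```julia
using ChainRulesCore, LogDensityProblems

struct MixedADLogDensityProblem{P}  # mirrors this PR's wrapper type
    problem::P
end

LogDensityProblems.logdensity(prob::MixedADLogDensityProblem, x) =
    LogDensityProblems.logdensity(prob.problem, x)

function ChainRulesCore.rrule(
    ::typeof(LogDensityProblems.logdensity),
    prob::MixedADLogDensityProblem,
    x::AbstractVector,
)
    # Compute value and gradient once using the problem's own (inner) backend.
    ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(prob.problem, x)
    function logdensity_pullback(Δ)
        # Pull the scalar cotangent back through the precomputed gradient.
        return (NoTangent(), NoTangent(), @thunk(Δ .* ∇ℓ))
    end
    return ℓ, logdensity_pullback
end
```

Any outer backend that imports ChainRules `rrule`s then sees `logdensity` as a primitive whose pullback reuses the inner backend's gradient.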
> and VI still ran perfectly fine, so my suspicion is that this isn't doing what you'd like it to do.

Thanks for catching this. This is actually a bug in the `rrule` import for `ReverseDiff`; I checked that the other AD backends indeed error out. I'll look into fixing this. I should also add a test case that catches this (see the sketch below).
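Something like this might do it (a sketch under the assumption that `MixedADLogDensityProblem` is applied as above; the sentinel target is hypothetical):

```julia
using Test, ADTypes, AdvancedVI, DifferentiationInterface, LogDensityProblems

struct SentinelTarget end  # hypothetical target used only by this test
LogDensityProblems.dimension(::SentinelTarget) = 1
LogDensityProblems.logdensity(::SentinelTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.capabilities(::Type{SentinelTarget}) =
    LogDensityProblems.LogDensityOrder{1}()
# If the imported rrule is actually used, this sentinel must fire.
LogDensityProblems.logdensity_and_gradient(::SentinelTarget, x) =
    error("logdensity_and_gradient was called")

prob = AdvancedVI.MixedADLogDensityProblem(SentinelTarget())
@test_throws ErrorException DifferentiationInterface.value_and_gradient(
    Base.Fix1(LogDensityProblems.logdensity, prob), AutoReverseDiff(), randn(1)
)
```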
@yebai When trying to use Mooncake, the tests fail with the following error:

```
Got exception outside of a @test
MethodError: no method matching increment_and_get_rdata!(::Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}, ::Mooncake.NoRData, ::Vector{Float64})
The function `increment_and_get_rdata!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  increment_and_get_rdata!(::Any, ::Any, ::ChainRulesCore.NoTangent)
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:256
  increment_and_get_rdata!(::Any, ::Any, ::ChainRulesCore.Thunk)
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:257
  increment_and_get_rdata!(::Array{P}, ::Mooncake.NoRData, ::Array{P}) where P<:Union{Float16, Float32, Float64}
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:252
  ...

Stacktrace:
 [1] increment_and_get_rdata!(f::Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}, r::Mooncake.NoRData, t::ChainRulesCore.Thunk{AdvancedVI.var"#1#3"{Float64, Vector{Float64}}})
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:258
 [2] (::Mooncake.var"#289#291")(x::Mooncake.CoDual{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}}, l_rdata::Mooncake.LazyZeroRData{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Nothing}, cr_dx::ChainRulesCore.Thunk{AdvancedVI.var"#1#3"{Float64, Vector{Float64}}})
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:303
 [3] map (repeats 3 times)
   @ ./tuple.jl:406 [inlined]
```

Is there a way to get around this?
```julia
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
    typeof(LogDensityProblems.logdensity),
    AdvancedVI.MixedADLogDensityProblem,
    SubArray{<:IEEEFloat,1},
}
```
@Red-Portal, this looks suspicious to me. Two possible issues:

1. `@from_rrule` might not work well with `SubArray`.
2. `SubArray` might require extra new rules.

cc @willtebbutt who knows this better.
I think it's mostly that `@from_rrule` doesn't work well with `SubArray`s at this point. Would it be possible to produce an MWE and open an issue on Mooncake? Since ChainRules doesn't provide strong promises around what type is used to represent the result of an `rrule`, the translation between ChainRules and Mooncake is always going to be a bit flaky (there's simply no way around that), but we ought to be able to

- provide a better error message, and
- handle this particular case, which really doesn't seem like it ought to be too challenging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you all for chiming in. Filed an issue here.
@Red-Portal I suggest you directly implement this rule for Mooncake here. Mooncake's support for `SubArray`s might be several weeks away, because we are a tiny team and forward mode has higher priority.

Can you produce an MWE for it?
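A possible shape for the MWE, reusing the first-order `MyTarget` sketched earlier in this thread (all names illustrative):

```julia
using ADTypes, AdvancedVI, DifferentiationInterface, LogDensityProblems

prob = AdvancedVI.MixedADLogDensityProblem(MyTarget())
X = randn(LogDensityProblems.dimension(MyTarget()), 4)
xv = view(X, :, 1)  # a SubArray, like the per-sample views VI differentiates

# Expected to reproduce the `increment_and_get_rdata!` MethodError above:
DifferentiationInterface.value_and_gradient(
    Base.Fix1(LogDensityProblems.logdensity, prob),
    AutoMooncake(; config=nothing),
    xv,
)
```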
This PR enables mixing AD frameworks for the target `LogDensityProblem` and differentiating through the variational objective. This is enabled by defining a custom `rrule` through `ChainRulesCore`. As such, we are restricted to AD frameworks that support importing `rrule`s. However, we still support directly differentiating through `LogDensityProblems.logdensity` if the target `LogDensityProblem` only has zeroth-order capability.

There are three motivations for this PR:

- We currently differentiate through `LogDensityProblems.logdensity` instead of relying on `LogDensityProblems.logdensity_and_gradient`. This is inconvenient since we can't check for `capabilities`, and it is inconsistent with the intended usage of `LogDensityProblems` (see the sketch below).
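As a concrete illustration of that first motivation, relying on `logdensity_and_gradient` lets us check the target's capabilities up front (a sketch, not code from this PR; `prob` stands for the user-supplied target):

```julia
using LogDensityProblems

capability = LogDensityProblems.capabilities(typeof(prob))
if capability === nothing || capability < LogDensityProblems.LogDensityOrder{1}()
    # Zeroth-order target: fall back to differentiating `logdensity` directly.
    @warn "Target does not provide gradients; differentiating logdensity directly."
end
```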