
Support Mixing AD Frameworks for LogDensityProblems and the objective (cleaned-up) #187


Open · wants to merge 13 commits into main

Conversation

Red-Portal (Member)

This PR enables mixing AD frameworks: one for the target LogDensityProblem and another for differentiating through the variational objective. This is enabled by defining a custom rrule through ChainRulesCore. As such, we are restricted to AD frameworks that support importing rrules. However, we still support directly differentiating through LogDensityProblems.logdensity if the target LogDensityProblem only has zeroth-order capability.
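For orientation, here is a minimal sketch of that mechanism. The wrapper name and its `problem` field come from the diff discussed below; the rrule body is paraphrased, not the PR's actual code:

```julia
using ChainRulesCore, LogDensityProblems

# Wrapper around a first-order LogDensityProblem (name taken from this PR's diff).
struct MixedADLogDensityProblem{P}
    problem::P
end

# Hijack differentiation of `logdensity`: any rrule-consuming AD backend that
# hits this call uses the problem's own gradient instead of tracing into it.
function ChainRulesCore.rrule(
    ::typeof(LogDensityProblems.logdensity),
    prob::MixedADLogDensityProblem,
    x::AbstractVector,
)
    logdens, grad = LogDensityProblems.logdensity_and_gradient(prob.problem, x)
    function logdensity_pullback(dy)
        # `logdensity` is scalar-valued, so the input cotangent is dy * ∇logdensity(x).
        return NoTangent(), NoTangent(), @thunk(dy * grad)
    end
    return logdens, logdensity_pullback
end
```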

There are three motivations for this PR:

  • Currently, for ADVI, the selected AD framework differentiates all the way through LogDensityProblems.logdensity instead of relying on LogDensityProblems.logdensity_and_gradient. This is inconvenient, since we can't check for capabilities, and is inconsistent with the intended usage of LogDensityProblems.
  • The second reason is that we can now decouple AD within AdvancedVI and within the target. This would be useful in the future when we move to GPUs: we could use GPU arrays within AdvancedVI and CPU arrays for the log-target, for example.
  • The last reason, which is related to the previous one, is that the log-target no longer needs to be differentiable by the selected AD framework. So we can cover models coming from outside of Julia that carry their own gradient calculations; crucially, Stan models via BridgeStan (see the sketch after this list).
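As a concrete illustration of the BridgeStan case, a sketch of exposing a Stan model as a first-order LogDensityProblem. `StanProblem` is a hypothetical wrapper name, and this assumes `BridgeStan.log_density_gradient` returns a `(value, gradient)` pair:

```julia
using BridgeStan, LogDensityProblems

# Hypothetical wrapper; BridgeStan does not ship one itself.
struct StanProblem
    model::BridgeStan.StanModel
end

LogDensityProblems.dimension(p::StanProblem) = BridgeStan.param_unc_num(p.model)

# First-order capability: Stan supplies its own gradients, no Julia AD needed.
function LogDensityProblems.capabilities(::Type{<:StanProblem})
    return LogDensityProblems.LogDensityOrder{1}()
end

function LogDensityProblems.logdensity(p::StanProblem, x)
    return BridgeStan.log_density(p.model, x)
end

function LogDensityProblems.logdensity_and_gradient(p::StanProblem, x)
    return BridgeStan.log_density_gradient(p.model, x)
end
```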

Contributor

AdvancedVI.jl documentation for PR #187 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR187/

@Red-Portal (Member Author)

Hi @wsmoses, it seems the Enzyme tests on LTS started failing. As far as Enzyme is concerned, not much has changed in the code since last week. Is there anything that changed on the Enzyme side in the past week that could have caused the failure? Otherwise, I'll take a deeper look.

@wsmoses (Contributor)

wsmoses commented Jul 30, 2025

nope, nothing that comes to mind. if you can file an issue with an MWE I can take a look

@github-actions (bot) left a comment

Benchmark Results

| Benchmark suite | Current: 1f5b495 | Previous: 630d212 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 4026781873 ns | 4080370618.5 ns | 0.99 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 1115607139 ns | 1150682904 ns | 0.97 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 1205063583 ns | 1241363224 ns | 0.97 |
| normal/RepGradELBO + STL/fullrank/Zygote | 3970155819.5 ns | 4019642435 ns | 0.99 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1597103126.5 ns | 1636522620 ns | 0.98 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 1280192193 ns | 1274526418 ns | 1.00 |
| normal/RepGradELBO/meanfield/Zygote | 2851261688 ns | 2847378761.5 ns | 1.00 |
| normal/RepGradELBO/meanfield/ReverseDiff | 765590505 ns | 771309994 ns | 0.99 |
| normal/RepGradELBO/meanfield/Mooncake | 1109370223 ns | 1091051726 ns | 1.02 |
| normal/RepGradELBO/fullrank/Zygote | 2857329571.5 ns | 2838345735 ns | 1.01 |
| normal/RepGradELBO/fullrank/ReverseDiff | 949336325.5 ns | 962637306 ns | 0.99 |
| normal/RepGradELBO/fullrank/Mooncake | 1151867635 ns | 1143764360 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 5687899208 ns | 5614419426 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 2438116302 ns | 2389714745 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/Mooncake | 4029199467 ns | 3869611785 ns | 1.04 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 5660492832 ns | 5485521005 ns | 1.03 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 3066241035.5 ns | 3016952886.5 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/fullrank/Mooncake | 4140026557 ns | 4004564598.5 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 4365927528.5 ns | 4229780243.5 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 2053732863 ns | 1998304814 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/Mooncake | 3841033310.5 ns | 3735322298.5 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 4467055087.5 ns | 4345273868 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 2333907144 ns | 2273026717 ns | 1.03 |
| normal + bijector/RepGradELBO/fullrank/Mooncake | 4015987416.5 ns | 3896302067.5 ns | 1.03 |

This comment was automatically generated by workflow using github-action-benchmark.

@yebai (Member)

yebai commented Jul 31, 2025

I'm happy with the proposed changes, but I will wait for @penelopeysm or @sunxd3 to comment on the software side of things.

@Red-Portal (Member Author)

@sunxd3 @penelopeysm could you take a look at this PR when you have available bandwidth? Thanks in advance!

- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- The target `LogDensityProblem` must have a capability of at least `LogDensityProblems.LogDensityOrder{1}()`.
@penelopeysm (Member) commented Aug 4, 2025

Can't you still differentiate through it (albeit with a warning)? So technically prob only needs to have zeroth-order capabilities.

Also, I wonder if a warning might be a bit harsh. How important is this use-case / do you want to actively discourage people from using a single AD backend for the entire VI process? If the aim is just to let people know rather than to discourage, then I think an info message rather than a warning might be gentler.

@Red-Portal (Member Author)

> How important is this use-case / do you want to actively discourage people from using a single AD backend for the entire VI process?

I intend to make the Turing layer use mixed-AD backends, so it's the most important use case. In the future, I think we could introduce changes that don't allow using a single backend in an end-to-end manner. Thus, I originally thought to simply disallow using a single backend. But we don't currently have hard reasons to forbid it, so I settled on a warning.

```julia
q_stop = restructure(params)
capability = LogDensityProblems.capabilities(Prob)
@assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme}
```
Member

Would it be possible to make this a nicer error message?

Also is there a reason why ForwardDiff is forbidden at this stage -- is it because you don't import the chain rule for logdensity_and_gradient for it? But in that case, you only need to forbid it in the MixedADLogDensityProblem branch. It looks to me that the old branch, where the AD backend just directly differentiates through logdensity, should still work with ForwardDiff. Am I misunderstanding?

Member Author

> But in that case, you only need to forbid it in the MixedADLogDensityProblem branch. It looks to me that the old branch, where the AD backend just directly differentiates through logdensity, should still work with ForwardDiff.

That is true! Let me put a comment in the code clarifying that the backend doesn't matter if we don't rely on 1st order differentiability.
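For concreteness, a sketch of the branch being discussed (variable names are hypothetical; comparing `LogDensityOrder` values with `>=` follows LogDensityProblems' documented usage):

```julia
cap = LogDensityProblems.capabilities(typeof(prob))
if cap >= LogDensityProblems.LogDensityOrder{1}()
    # First-order problem: wrap it so the imported rrule routes gradients
    # through logdensity_and_gradient. Here adtype must be one of the
    # rrule-importing backends.
    prob_ad = MixedADLogDensityProblem(prob)
else
    # Zeroth-order problem: the backend differentiates straight through
    # logdensity, so any AD backend (ForwardDiff included) is fine.
    prob_ad = prob
end
```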

Member Author

> Would it be possible to make this a nicer error message?

I think the assert itself should be clear enough since all we need to say is that adtype must be one of those types?
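For what it's worth, a sketch of what the reviewer's suggestion could look like (not code from this PR):

```julia
using ADTypes: AutoReverseDiff, AutoZygote, AutoMooncake, AutoEnzyme

# Same check as the @assert, with an actionable message.
adtype isa Union{AutoReverseDiff,AutoZygote,AutoMooncake,AutoEnzyme} || throw(
    ArgumentError(
        "adtype must be one of AutoReverseDiff, AutoZygote, AutoMooncake, or " *
        "AutoEnzyme when the target provides its own gradient; got $(typeof(adtype))",
    ),
)
```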

Comment on lines 23 to 28
```julia
function LogDensityProblems.logdensity_and_gradient(
    mixedad_prob::MixedADLogDensityProblem, x::AbstractArray
)
    return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x)
end
```

@penelopeysm (Member) commented Aug 4, 2025

Maybe I'm not super familiar with the codebase, but where / how does this method get called?

My understanding is that the call stack looks like this: [call-stack snippet not preserved in this transcript]

But it's not clear to me that DI.value_and_gradient! calls logdensity_and_gradient. Does it? It seems to me that it would just differentiate through logdensity anyway.

In DynamicPPL it's the other way round: logdensity_and_gradient is implemented using DI.value_and_gradient.
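For contrast, a sketch of the DynamicPPL-style direction under assumed names (`ADGradientProblem` is hypothetical, not DynamicPPL's actual type):

```julia
using DifferentiationInterface: value_and_gradient
using LogDensityProblems
import ADTypes

# Hypothetical wrapper illustrating the opposite direction:
# logdensity_and_gradient implemented on top of DI, not the reverse.
struct ADGradientProblem{P,B<:ADTypes.AbstractADType}
    problem::P
    backend::B
end

function LogDensityProblems.logdensity(p::ADGradientProblem, x)
    return LogDensityProblems.logdensity(p.problem, x)
end

function LogDensityProblems.logdensity_and_gradient(p::ADGradientProblem, x)
    return value_and_gradient(
        Base.Fix1(LogDensityProblems.logdensity, p.problem), p.backend, x
    )
end
```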

@penelopeysm (Member) commented Aug 4, 2025

I tested the example given in docs/src/examples.md (https://github.com/TuringLang/AdvancedVI.jl/blob/366c9115edf5f9c675dca44f1a78e23854656215/docs/src/examples.md), changing the implementation of logdensity_and_gradient to

```diff
  function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
-     return (
-         LogDensityProblems.logdensity(model, θ),
-         ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
-     )
+     error("is this function called?")
  end
```

and VI still ran perfectly fine, so my suspicion is that this isn't doing what you'd like it to do.

@Red-Portal (Member Author) commented Aug 5, 2025

> Maybe I'm not super familiar with the codebase, but where / how does this method get called?

You are right, it isn't. Let me remove it.

> But it's not clear to me that DI.value_and_gradient! calls logdensity_and_gradient. Does it? It seems to me that it would just differentiate through logdensity anyway.

The intention of this PR is to hijack the rrule for logdensity to indeed call logdensity_and_gradient. This is done here. The extensions register the rrule for backends that don't automatically use ChainRulesCore.

> and VI still ran perfectly fine, so my suspicion is that this isn't doing what you'd like it to do.

Thanks for catching this. This is actually a bug in the rrule import for ReverseDiff; I checked that the other AD backends indeed error out. I'll look into fixing this.

I should also add a test case that catches this.
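A sketch of what such a test could look like, assuming the wrapper can be constructed directly as `MixedADLogDensityProblem(prob)` and that the rrule fires for `SubArray` arguments as the diff below suggests (both assumptions; the PR's internals may differ):

```julia
using Test, Zygote, LogDensityProblems, AdvancedVI

# Hypothetical problem that records whether its own gradient path was taken.
struct TracingProblem
    called::Base.RefValue{Bool}
end
LogDensityProblems.dimension(::TracingProblem) = 2
LogDensityProblems.capabilities(::Type{<:TracingProblem}) =
    LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(::TracingProblem, x) = -sum(abs2, x) / 2
function LogDensityProblems.logdensity_and_gradient(p::TracingProblem, x)
    p.called[] = true
    return -sum(abs2, x) / 2, -x
end

prob = TracingProblem(Ref(false))
wrapped = AdvancedVI.MixedADLogDensityProblem(prob)  # constructor assumed
xs = randn(2, 1)
x = view(xs, :, 1)  # SubArray, matching the rrule's signature in the diff
Zygote.gradient(Base.Fix1(LogDensityProblems.logdensity, wrapped), x)
@test prob.called[]  # fails if the backend traced through logdensity instead
```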

Red-Portal and others added 8 commits August 5, 2025 12:11
@Red-Portal (Member Author)

@yebai When trying to use Mooncake with an imported rrule and a SubArray type, we hit this error:

```
Got exception outside of a @test
MethodError: no method matching increment_and_get_rdata!(::Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}, ::Mooncake.NoRData, ::Vector{Float64})
The function `increment_and_get_rdata!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  increment_and_get_rdata!(::Any, ::Any, ::ChainRulesCore.NoTangent)
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:256
  increment_and_get_rdata!(::Any, ::Any, ::ChainRulesCore.Thunk)
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:257
  increment_and_get_rdata!(::Array{P}, ::Mooncake.NoRData, ::Array{P}) where P<:Union{Float16, Float32, Float64}
   @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:252
  ...

Stacktrace:
  [1] increment_and_get_rdata!(f::Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}, r::Mooncake.NoRData, t::ChainRulesCore.Thunk{AdvancedVI.var"#1#3"{Float64, Vector{Float64}}})
    @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:258
  [2] (::Mooncake.var"#289#291")(x::Mooncake.CoDual{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Mooncake.FData{@NamedTuple{parent::Matrix{Float64}, indices::Mooncake.NoFData, offset1::Mooncake.NoFData, stride1::Mooncake.NoFData}}}, l_rdata::Mooncake.LazyZeroRData{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Nothing}, cr_dx::ChainRulesCore.Thunk{AdvancedVI.var"#1#3"{Float64, Vector{Float64}}})
    @ Mooncake ~/.julia/packages/Mooncake/BvRb6/src/tools_for_rules.jl:303
  [3] map (repeats 3 times)
    @ ./tuple.jl:406 [inlined]
```
Is there a way to get around this?

```julia
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
    typeof(LogDensityProblems.logdensity),
    AdvancedVI.MixedADLogDensityProblem,
    SubArray{<:IEEEFloat,1},
}
```
Member

@Red-Portal, this looks suspicious to me. Two possible issues:

  • @from_rrule might not work well with SubArray
  • SubArray might require extra new rules.

cc @willtebbutt who knows this better.

Member

I think it's mostly that @from_rrule doesn't work well with SubArrays at this point. Would it be possible to produce an MWE and open an issue on Mooncake? Since ChainRules doesn't provide strong promises around what type is used to represent the result of an rrule, the translation between ChainRules and Mooncake is always going to be a bit flaky (there's simply no way around that), but we ought to be able to

  1. provide a better error message, and
  2. handle this particular case, which really doesn't seem like it ought to be too challenging.
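In case it helps, a guess at the shape of such an MWE (the `Quadratic` problem is hypothetical, and the `build_rrule`/`value_and_gradient!!` usage is from my reading of Mooncake's docs):

```julia
using Mooncake, ChainRulesCore, LogDensityProblems

struct Quadratic end  # stand-in for the wrapped problem
LogDensityProblems.logdensity(::Quadratic, x) = -sum(abs2, x) / 2

# An rrule returning a Thunk, as in the failing stacktrace above.
function ChainRulesCore.rrule(
    ::typeof(LogDensityProblems.logdensity), p::Quadratic, x::AbstractVector
)
    y = LogDensityProblems.logdensity(p, x)
    pullback(dy) = (NoTangent(), NoTangent(), @thunk(-dy .* x))
    return y, pullback
end

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
    typeof(LogDensityProblems.logdensity),Quadratic,SubArray{Float64,1}
}

xs = randn(3, 2)
x = view(xs, :, 1)  # SubArray argument, as in the reported failure
rule = Mooncake.build_rrule(LogDensityProblems.logdensity, Quadratic(), x)
Mooncake.value_and_gradient!!(rule, LogDensityProblems.logdensity, Quadratic(), x)
```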

Member Author

Thank you all for chiming in. Filed an issue here

@yebai (Member) commented Aug 9, 2025

@Red-Portal I suggest you directly implement this rule for Mooncake here. Mooncake’s support of SubArrays might be several weeks away, because we are a tiny team and forward mode has higher priority.

@yebai (Member)

yebai commented Aug 6, 2025

> @yebai When trying to use Mooncake with an imported rrule and a SubArray type, we hit this error:

Can you produce an MWE for it?
