
Allow ADTypes AD backend selection in Hamiltonian #405

Merged · 16 commits into TuringLang:main · Mar 30, 2025

Conversation

@ErikQQY (Collaborator) commented Mar 16, 2025

Since LogDensityProblemsAD.jl already supports AD backends specified via ADTypes.jl, we can allow the same here for a unifying interface.
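For context, a rough usage sketch of what this enables; the toy target, dimension, and metric below are made up for illustration, and the `adtype` keyword is the one added by this PR:

```julia
using AdvancedHMC, ADTypes, LogDensityProblems
using ForwardDiff  # backend package for AutoForwardDiff, loaded in case an extension needs it

# A toy log-density target (standard normal) implementing the LogDensityProblems interface.
struct ToyTarget end
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyTarget) = 10
LogDensityProblems.capabilities(::Type{ToyTarget}) = LogDensityProblems.LogDensityOrder{0}()

metric = DiagEuclideanMetric(10)

# The AD backend is now selected via an ADTypes object instead of a Symbol/Val/Module.
h = Hamiltonian(metric, ToyTarget(); adtype = AutoForwardDiff())
```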

@yebai yebai requested review from yebai, sunxd3 and devmotion March 17, 2025 12:08
@ErikQQY ErikQQY requested a review from devmotion March 18, 2025 07:19
@ErikQQY ErikQQY requested a review from yebai March 22, 2025 07:30
```
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
_logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓ)
_logdensity_and_gradient = function (x)
    prep = DI.prepare_gradient(_logdensity, adtype, x)
```
Reviewer:
Preparing outside the closure would be more efficient, but users would have to be warned, I guess.

@ErikQQY (Collaborator, Author):
I agree we should prepare outside the closure, but since the Hamiltonian needs to store the log density and its gradient of the model as a function, how can we prepare in advance?

Reviewer:
Do you have access to a "typical input" somewhere?

@ErikQQY (Collaborator, Author):
AFAIK, the Hamiltonian is basically used to keep the log density and its gradient function; a "typical input" only appears in the subsequent sampling process, so it may not be possible to pass one in here.
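To make the trade-off in this thread concrete, a hedged sketch (not the PR's code) contrasting the two options, assuming a recent DifferentiationInterface where the preparation object is passed before the backend; `ToyTarget`, `adtype`, and `x_typical` are placeholders:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
using ForwardDiff, LogDensityProblems

struct ToyTarget end  # hypothetical target, for illustration only
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2
ℓ, adtype, x_typical = ToyTarget(), AutoForwardDiff(), zeros(5)

_logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# (a) As in the diff: prepare inside the closure. Correct, but the preparation
#     is rebuilt and discarded on every call.
logdensity_and_gradient_a = function (x)
    prep = DI.prepare_gradient(_logdensity, adtype, x)
    return DI.value_and_gradient(_logdensity, prep, adtype, x)
end

# (b) Prepare once from a "typical input" and reuse it across calls. This is
#     only possible if such an input is available when the Hamiltonian is built.
prep = DI.prepare_gradient(_logdensity, adtype, x_typical)
logdensity_and_gradient_b = x -> DI.value_and_gradient(_logdensity, prep, adtype, x)

logdensity_and_gradient_a(randn(5))  # prepares, computes, discards the preparation
logdensity_and_gradient_b(randn(5))  # reuses the preparation built once above
```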

```
@@ -146,7 +152,12 @@ function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
    # In this case ℓ does not support evaluation of the gradient of the log density function
    # We use ForwardDiff to compute the gradient
    LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
    _logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓ)
```
Reviewer:
One can also use DI.Constant for ℓ if it contains no differentiable storage.
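A hedged sketch of what that could look like with DifferentiationInterface's context mechanism (the names below are placeholders, not the PR's code):

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
using ForwardDiff, LogDensityProblems

struct ToyTarget end  # hypothetical target, for illustration only
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2

# Instead of closing over ℓ with Base.Fix1, pass it as a non-differentiated
# Constant context; the differentiated function then takes (x, ℓ) explicitly.
logdensity_of(x, ℓ) = LogDensityProblems.logdensity(ℓ, x)
grad = DI.gradient(logdensity_of, AutoForwardDiff(), zeros(5), DI.Constant(ToyTarget()))
```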

Comment on lines 162 to 164
```
function Hamiltonian(
    metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
)
```
Reviewer:
I guess removing this is a breaking change?

Project.toml Outdated
```
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
```
Member:
Unconditional reexports are a code smell IMO; why can't users just load ADTypes? By unconditionally reexporting everything, you completely lose control over what is defined in your package, and the whole API of the reexported package (and hence all of its breaking changes etc.) automatically becomes part of your package's API.

@ErikQQY (Collaborator, Author):
Yeah, users who want to try other AD backends should just load ADTypes themselves then.
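For instance, backend selection without any reexport would look roughly like this (a sketch; `metric` and `ℓ` stand for whatever the user already has, and the backend package itself must still be loaded for its AD extension to be available):

```julia
using AdvancedHMC
using ADTypes: AutoReverseDiff  # users load ADTypes (or just the constructor they need) themselves
using ReverseDiff               # the actual AD backend package is still required

# metric and ℓ as constructed elsewhere; only the backend selection changes.
h = Hamiltonian(metric, ℓ; adtype = AutoReverseDiff())
```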

```diff
 end
-function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
+function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
```
Member (suggested change):
```diff
-function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
+function Hamiltonian(metric::AbstractMetric, ℓ; adtype::ADType=AutoForwardDiff(), kwargs...)
```

Reviewer (suggested change):
```diff
-function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
+function Hamiltonian(metric::AbstractMetric, ℓ; adtype=ADTypes.AutoForwardDiff(), kwargs...)
```

```
function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...)
    return Hamiltonian(metric, ℓ.logdensity; kwargs...)
function Hamiltonian(
    metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
```
Member (suggested change):
```diff
-    metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
+    metric::AbstractMetric, ℓ::LogDensityModel; adtype::ADType=AutoForwardDiff(), kwargs...
```

Reviewer (suggested change):
```diff
-    metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
+    metric::AbstractMetric, ℓ::LogDensityModel; adtype=ADTypes.AutoForwardDiff(), kwargs...
```

```
@@ -146,7 +151,12 @@ function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
    # In this case ℓ does not support evaluation of the gradient of the log density function
    # We use ForwardDiff to compute the gradient
    LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
```
Member:
Why not just

```diff
-LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
+LogDensityProblemsAD.ADgradient(adtype, ℓ; kwargs...)
```

?

Comment on lines 155 to 156
```
_logdensity_and_gradient = function (x)
    prep = DI.prepare_gradient(_logdensity, adtype, x)
```
Member:
What's the point of the preparation step in this function? Every time you evaluate logdensity + gradient for an input x, it would run preparation for that x and then immediately discard it after the computation. If you just use LogDensityProblemsAD.ADgradient (see above), users could provide a typical input x as a keyword argument to optimize the AD computations of the Hamiltonian (otherwise they'd get the slower fallback).
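A hedged illustration of that suggestion, assuming a LogDensityProblemsAD version whose ADTypes-aware ADgradient accepts a typical input via the x keyword (as described above); the toy target is made up:

```julia
using LogDensityProblemsAD, LogDensityProblems
using ADTypes: AutoForwardDiff
using ForwardDiff

struct ToyTarget end  # hypothetical target, for illustration only
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyTarget) = 5

# Passing a typical input lets the wrapper run AD preparation once up front;
# omitting it falls back to a slower unprepared path.
∇ℓ = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ToyTarget(); x = zeros(5))
LogDensityProblems.logdensity_and_gradient(∇ℓ, randn(5))
```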

@gdalle commented Mar 23, 2025:
If preparation is not reused, then it is not beneficial. But as per my comment above, if it can be reused because we have access to a typical input, then this works basically the same as LogDensityProblemsAD.ADgradient (except that Turing would no longer depend on LogDensityProblemsAD, in order to mutualize backend code maintenance inside DI).

@ErikQQY (Collaborator, Author):
Changing to DifferentiationInterface is just a step forward for mutualized AD maintenance in Turing, and it seems that if we want to use DI to compute the gradient, we have to use the preparation mechanism. As for the performance aspects, I will test the difference between using DifferentiationInterface and LogDensityProblemsAD; if there is no difference between the two options, we can safely change to DI.

Reviewer:
You don't actually have to use preparation with DI. DI.gradient(f, backend, x) works too: it usually defaults to preparing first and executing second, but for some backends there are faster shortcuts which we try to take.
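In other words, the closure in the diff could also call DI without an explicit preparation object; a minimal sketch, with `_logdensity` and `adtype` as placeholders from the diff above:

```julia
import DifferentiationInterface as DI

# DI prepares internally on each call (or takes a backend-specific shortcut when
# one exists), so no preparation object has to be threaded through.
_logdensity_and_gradient = x -> DI.value_and_gradient(_logdensity, adtype, x)
```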

Comment on lines 162 to 164
```
function Hamiltonian(
    metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
)
```
Member:
No need to remove these?

@yebai (Member) commented Mar 24, 2025

@ErikQQY, let's only introduce ADTypes support in this PR and switch to DifferentiationInterface via a separate PR.

@ErikQQY (Collaborator, Author) commented Mar 29, 2025

As per @yebai's suggestion, let's only focus on ADTypes support in this PR and integrate DifferentiationInterface in another PR

@yebai yebai merged commit 7ea27f3 into TuringLang:main Mar 30, 2025
17 checks passed

4 participants