Allow ADTypes AD backend selection in Hamiltonian #405
Conversation
src/AdvancedHMC.jl
Outdated
-       LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
+       _logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓ)
+       _logdensity_and_gradient = function (x)
+           prep = DI.prepare_gradient(_logdensity, adtype, x)
Preparing outside the closure would be more efficient, but users have to be warned I guess
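For concreteness, a minimal sketch of the prepare-once pattern with DifferentiationInterface, under the assumption that a typical input x0 is available up front (the toy log density f and x0 are invented for the example):

    import DifferentiationInterface as DI
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    f(x) = -sum(abs2, x) / 2                    # toy stand-in for the log density
    adtype = AutoForwardDiff()
    x0 = randn(3)                               # hypothetical "typical input"

    prep = DI.prepare_gradient(f, adtype, x0)   # runs once, outside the closure
    # each call now reuses `prep`; it stays valid for inputs of the same type and size
    logdensity_and_gradient = x -> DI.value_and_gradient(f, prep, adtype, x)

The warning alluded to above: `prep` is only guaranteed valid for inputs similar to `x0`, which is exactly why a typical input has to come from somewhere.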
I agree we should prepare outside the closure, but since the Hamiltonian needs the model's log density and its gradient as functions, how can we prepare in advance?
Do you have access to a "typical input" somewhere?
AFAIK, the Hamiltonian is basically used to keep the log density and its gradient function, and the "typical input" only appears in the subsequent sampling process, so it may not be possible to pass it in here.
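To make the ordering concrete, here is a condensed version of the README-style AdvancedHMC workflow (the toy target and constants are illustrative; the exact kernel choices are incidental). The Hamiltonian is constructed before any position exists, and the first concrete input only appears once sampling starts:

    using AdvancedHMC, ForwardDiff, LogDensityProblems

    # toy target implementing the LogDensityProblems interface
    struct LogTargetDensity
        dim::Int
    end
    LogDensityProblems.logdensity(::LogTargetDensity, θ) = -sum(abs2, θ) / 2
    LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
    LogDensityProblems.capabilities(::Type{LogTargetDensity}) =
        LogDensityProblems.LogDensityOrder{0}()

    D = 10
    ℓπ = LogTargetDensity(D)
    metric = DiagEuclideanMetric(D)
    hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)   # no input point available here

    # the first concrete position only exists from here on, at sampling time
    initial_θ = rand(D)
    initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
    integrator = Leapfrog(initial_ϵ)
    kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
    samples, stats = sample(hamiltonian, kernel, initial_θ, 100)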
src/AdvancedHMC.jl
Outdated
@@ -146,7 +152,12 @@ function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
    ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
        # In this case ℓ does not support evaluation of the gradient of the log density function
        # We use ForwardDiff to compute the gradient
-       LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
+       _logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓ)
One can also use DI.Constant for ℓ if it contains no differentiable storage.
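A small sketch of that idea with DI's context arguments, using a toy stand-in instead of a real LogDensityProblems target (names invented for the example):

    import DifferentiationInterface as DI
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    # ℓ is passed as a Constant context, so the backend treats it as
    # non-differentiable data instead of capturing it in a closure
    _logdensity(x, ℓ) = -ℓ.scale * sum(abs2, x)   # stand-in for LogDensityProblems.logdensity
    ℓ = (scale = 0.5,)                            # "no differentiable storage"

    g = DI.gradient(_logdensity, AutoForwardDiff(), randn(3), DI.Constant(ℓ))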
src/AdvancedHMC.jl
Outdated
function Hamiltonian(
    metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
)
I guess removing this is a breaking change?
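One way to avoid the breakage would be a deprecation shim along these lines (entirely hypothetical, not part of the PR; `_to_adtype` is an invented helper):

    using ADTypes: AutoForwardDiff, AutoZygote

    _to_adtype(kind::Module) = _to_adtype(Val(Symbol(kind)))
    _to_adtype(kind::Symbol) = _to_adtype(Val(kind))
    _to_adtype(::Val{:ForwardDiff}) = AutoForwardDiff()
    _to_adtype(::Val{:Zygote}) = AutoZygote()      # extend per supported backend

    # hypothetical: keep the old positional method, forwarding to the new keyword
    function Hamiltonian(
        metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
    )
        Base.depwarn("pass `adtype=...` instead of a positional AD backend", :Hamiltonian)
        return Hamiltonian(metric, ℓπ; adtype=_to_adtype(kind), kwargs...)
    end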
Project.toml
Outdated
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Unconditional reexports are a code smell IMO - why can't users just load ADTypes? By unconditionally reexporting everything, you completely lose control over what's defined in your package, and the whole API of the reexported package (and hence all its breaking changes etc.) automatically becomes part of your package's API.
Yeah, users who want to try other AD backends should just load ADTypes by themselves then
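Without the reexport, user code would look roughly like this (AutoZygote chosen arbitrarily; `metric` and `ℓ` assumed to exist as in the diffs above):

    using AdvancedHMC
    using ADTypes: AutoZygote   # users load ADTypes explicitly
    import Zygote               # the backend package itself must be loaded too

    h = Hamiltonian(metric, ℓ; adtype=AutoZygote())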
src/AdvancedHMC.jl
Outdated
  end
- function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
+ function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
Suggested change:
- function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
+ function Hamiltonian(metric::AbstractMetric, ℓ; adtype::ADType=AutoForwardDiff(), kwargs...)
Suggested change:
- function Hamiltonian(metric::AbstractMetric, ℓ; adtype=AutoForwardDiff(), kwargs...)
+ function Hamiltonian(metric::AbstractMetric, ℓ; adtype=ADTypes.AutoForwardDiff(), kwargs...)
src/AdvancedHMC.jl
Outdated
- function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...)
-     return Hamiltonian(metric, ℓ.logdensity; kwargs...)
+ function Hamiltonian(
+     metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
Suggested change:
-     metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
+     metric::AbstractMetric, ℓ::LogDensityModel; adtype::ADType=AutoForwardDiff(), kwargs...
Suggested change:
-     metric::AbstractMetric, ℓ::LogDensityModel; adtype=AutoForwardDiff(), kwargs...
+     metric::AbstractMetric, ℓ::LogDensityModel; adtype=ADTypes.AutoForwardDiff(), kwargs...
src/AdvancedHMC.jl
Outdated
@@ -146,7 +151,12 @@ function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
    ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
        # In this case ℓ does not support evaluation of the gradient of the log density function
        # We use ForwardDiff to compute the gradient
-       LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
Why not just

Suggested change:
- LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
+ LogDensityProblemsAD.ADgradient(adtype, ℓ; kwargs...)

?
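For reference, a self-contained sketch of that suggested path, relying on LogDensityProblemsAD accepting ADTypes backends (the ToyTarget type is invented for the example):

    using LogDensityProblems, LogDensityProblemsAD
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    # toy target implementing the LogDensityProblems interface
    struct ToyTarget end
    LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2
    LogDensityProblems.dimension(::ToyTarget) = 3
    LogDensityProblems.capabilities(::Type{ToyTarget}) = LogDensityProblems.LogDensityOrder{0}()

    ∇ℓ = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ToyTarget())
    LogDensityProblems.logdensity_and_gradient(∇ℓ, randn(3))   # returns (value, gradient)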
src/AdvancedHMC.jl
Outdated
_logdensity_and_gradient = function (x)
    prep = DI.prepare_gradient(_logdensity, adtype, x)
What's the point of the preparation step in this function? Every time you evaluate logdensity + gradient for an input x, it would run preparation for x and then immediately discard it after the computation. If you just use LogDensityProblemsAD.ADgradient (see above), then users could provide a typical input x as a keyword argument to optimize the AD computations of the Hamiltonian (otherwise they'd get the slower fallback).
If preparation is not reused, then it is not beneficial. But as per my comment above, if it can be reused because we have access to a typical input, then it works basically the same as a LogDensityProblemsAD.ADgradient (except that Turing no longer depends on LogDensityProblemsAD, in order to mutualize backend code maintenance inside DI).
Changing to DifferentiationInterface is a step forward for the mutualized AD maintenance in Turing, and it seems that if we want to use DI to compute the gradient, we have to use the preparation mechanism. As for performance, I will benchmark DifferentiationInterface against LogDensityProblemsAD; if there is no difference between the two options, we can safely switch to DI.
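A rough sketch of how such a comparison could be set up with BenchmarkTools (toy target; this compares the prepared and unprepared DI paths rather than the exact PR code):

    using BenchmarkTools
    import DifferentiationInterface as DI
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    f(x) = -sum(abs2, x) / 2
    x = randn(10)
    backend = AutoForwardDiff()
    prep = DI.prepare_gradient(f, backend, x)

    @btime DI.value_and_gradient($f, $backend, $x)          # prepares internally on each call
    @btime DI.value_and_gradient($f, $prep, $backend, $x)   # reuses the preparation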
You don't actually have to use preparation with DI. DI.gradient(f, backend, x) will work too: it usually defaults to preparing first and executing second, but for some backends there are faster shortcuts which we try to take.
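Concretely, the preparation-free call is just (toy function for illustration):

    import DifferentiationInterface as DI
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    f(x) = sum(abs2, x)
    g = DI.gradient(f, AutoForwardDiff(), randn(3))   # no explicit preparation step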
src/AdvancedHMC.jl
Outdated
function Hamiltonian(
    metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
)
No need to remove these?
@ErikQQY, let's only introduce ADTypes support in this PR.
As per @yebai's suggestion, let's only focus on ADTypes support in this PR and integrate DifferentiationInterface in another PR.
While LogDensityProblemsAD.jl supports AD backends from ADTypes.jl, we can also allow this here for a unifying interface.
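Putting the pieces together, usage after this PR might look like the following sketch (the adtype keyword comes from the diffs above; the Gaussian target is invented for the example):

    using AdvancedHMC, LogDensityProblems
    using ADTypes: AutoForwardDiff
    import ForwardDiff

    # toy standard-normal target
    struct Gaussian end
    LogDensityProblems.logdensity(::Gaussian, θ) = -sum(abs2, θ) / 2
    LogDensityProblems.dimension(::Gaussian) = 5
    LogDensityProblems.capabilities(::Type{Gaussian}) = LogDensityProblems.LogDensityOrder{0}()

    metric = DiagEuclideanMetric(5)
    h = Hamiltonian(metric, Gaussian(); adtype=AutoForwardDiff())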