Skip to content

Commit 89bc2e1

Browse files
committed
renamed _getmodel to getmodel, _setmodel to setmodel, and
`_varinfo` to `varinfo_from_logdensityfn`
1 parent 414a077 commit 89bc2e1

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,41 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
1717
return transition_to_turing(parent(f), transition)
1818
end
1919

20-
_getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = _getmodel(parent(f))
21-
_getmodel(f::DynamicPPL.LogDensityFunction) = f.model
20+
"""
21+
getmodel(f)
22+
23+
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
24+
"""
25+
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
26+
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
2227

2328
# FIXME: We'll have to overload this for every AD backend since some of the AD backends
2429
# will cache certain parts of a given model, e.g. the tape, which results in a discrepancy
2530
# between the primal (forward) and dual (backward).
26-
function _setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
27-
return Accessors.@set f.= _setmodel(f.ℓ, model)
31+
"""
32+
setmodel(f, model)
33+
34+
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
35+
36+
!!! warning
37+
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
38+
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
39+
might require recompilation of the gradient tape, depending on the AD backend.
40+
"""
41+
function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
42+
return Accessors.@set f.= setmodel(f.ℓ, model)
2843
end
29-
function _setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
44+
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
3045
return Accessors.@set f.model = model
3146
end
3247

33-
_varinfo(f::LogDensityProblemsAD.ADGradientWrapper) = _varinfo(parent(f))
48+
function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
49+
return varinfo_from_logdensityfn(parent(f))
50+
end
3451
_varinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
3552

3653
function varinfo(state::TuringState)
37-
θ = getparams(_getmodel(state.logdensity), state.state)
54+
θ = getparams(getmodel(state.logdensity), state.state)
3855
# TODO: Do we need to link here first?
3956
return DynamicPPL.unflatten(_varinfo(state.logdensity), θ)
4057
end
@@ -67,17 +84,14 @@ function recompute_logprob!!(
6784
rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
6885
model::DynamicPPL.Model,
6986
sampler::DynamicPPL.Sampler{<:ExternalSampler},
70-
state
87+
state,
7188
)
7289
# Re-using the log-density function from the `state` and updating only the `model` field,
7390
# since the `model` might now contain different conditioning values.
74-
f = _setmodel(state.logdensity, model)
91+
f = setmodel(state.logdensity, model)
7592
# Recompute the log-probability with the new `model`.
7693
state_inner = recompute_logprob!!(
77-
rng,
78-
AbstractMCMC.LogDensityModel(f),
79-
sampler.alg.sampler,
80-
state.state
94+
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
8195
)
8296
return state_to_turing(f, state_inner)
8397
end
@@ -86,15 +100,13 @@ function recompute_logprob!!(
86100
rng::Random.AbstractRNG,
87101
model::AbstractMCMC.LogDensityModel,
88102
sampler::AdvancedHMC.AbstractHMCSampler,
89-
state::AdvancedHMC.HMCState
103+
state::AdvancedHMC.HMCState,
90104
)
91105
# Construct hamiltionian.
92106
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
93107
# Re-compute the log-probability and gradient.
94108
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
95-
hamiltonian,
96-
state.transition.z.θ,
97-
state.transition.z.r,
109+
hamiltonian, state.transition.z.θ, state.transition.z.r
98110
)
99111
end
100112

@@ -115,7 +127,7 @@ function AbstractMCMC.step(
115127
sampler_wrapper::Sampler{<:ExternalSampler};
116128
initial_state=nothing,
117129
initial_params=nothing,
118-
kwargs...
130+
kwargs...,
119131
)
120132
alg = sampler_wrapper.alg
121133
sampler = alg.sampler
@@ -145,7 +157,12 @@ function AbstractMCMC.step(
145157
)
146158
else
147159
transition_inner, state_inner = AbstractMCMC.step(
148-
rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
160+
rng,
161+
AbstractMCMC.LogDensityModel(f),
162+
sampler,
163+
initial_state;
164+
initial_params,
165+
kwargs...,
149166
)
150167
end
151168
# Update the `state`
@@ -157,7 +174,7 @@ function AbstractMCMC.step(
157174
model::DynamicPPL.Model,
158175
sampler_wrapper::Sampler{<:ExternalSampler},
159176
state::TuringState;
160-
kwargs...
177+
kwargs...,
161178
)
162179
sampler = sampler_wrapper.alg.sampler
163180
f = state.logdensity

0 commit comments

Comments
 (0)