@@ -17,24 +17,41 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
17
17
return transition_to_turing (parent (f), transition)
18
18
end
19
19
20
- _getmodel (f:: LogDensityProblemsAD.ADGradientWrapper ) = _getmodel (parent (f))
21
- _getmodel (f:: DynamicPPL.LogDensityFunction ) = f. model
20
+ """
21
+ getmodel(f)
22
+
23
+ Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
24
+ """
25
+ getmodel (f:: LogDensityProblemsAD.ADGradientWrapper ) = getmodel (parent (f))
26
+ getmodel (f:: DynamicPPL.LogDensityFunction ) = f. model
22
27
23
28
# FIXME : We'll have to overload this for every AD backend since some of the AD backends
24
29
# will cache certain parts of a given model, e.g. the tape, which results in a discrepancy
25
30
# between the primal (forward) and dual (backward).
26
- function _setmodel (f:: LogDensityProblemsAD.ADGradientWrapper , model:: DynamicPPL.Model )
27
- return Accessors. @set f. ℓ = _setmodel (f. ℓ, model)
31
+ """
32
+ setmodel(f, model)
33
+
34
+ Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
35
+
36
+ !!! warning
37
+ Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
38
+ `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
39
+ might require recompilation of the gradient tape, depending on the AD backend.
40
+ """
41
+ function setmodel (f:: LogDensityProblemsAD.ADGradientWrapper , model:: DynamicPPL.Model )
42
+ return Accessors. @set f. ℓ = setmodel (f. ℓ, model)
28
43
end
29
- function _setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
44
+ function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
30
45
return Accessors. @set f. model = model
31
46
end
32
47
33
- _varinfo (f:: LogDensityProblemsAD.ADGradientWrapper ) = _varinfo (parent (f))
48
+ function varinfo_from_logdensityfn (f:: LogDensityProblemsAD.ADGradientWrapper )
49
+ return varinfo_from_logdensityfn (parent (f))
50
+ end
34
51
_varinfo (f:: DynamicPPL.LogDensityFunction ) = f. varinfo
35
52
36
53
function varinfo (state:: TuringState )
37
- θ = getparams (_getmodel (state. logdensity), state. state)
54
+ θ = getparams (getmodel (state. logdensity), state. state)
38
55
# TODO : Do we need to link here first?
39
56
return DynamicPPL. unflatten (_varinfo (state. logdensity), θ)
40
57
end
@@ -67,17 +84,14 @@ function recompute_logprob!!(
67
84
rng:: Random.AbstractRNG , # TODO : Do we need the `rng` here?
68
85
model:: DynamicPPL.Model ,
69
86
sampler:: DynamicPPL.Sampler{<:ExternalSampler} ,
70
- state
87
+ state,
71
88
)
72
89
# Re-using the log-density function from the `state` and updating only the `model` field,
73
90
# since the `model` might now contain different conditioning values.
74
- f = _setmodel (state. logdensity, model)
91
+ f = setmodel (state. logdensity, model)
75
92
# Recompute the log-probability with the new `model`.
76
93
state_inner = recompute_logprob!! (
77
- rng,
78
- AbstractMCMC. LogDensityModel (f),
79
- sampler. alg. sampler,
80
- state. state
94
+ rng, AbstractMCMC. LogDensityModel (f), sampler. alg. sampler, state. state
81
95
)
82
96
return state_to_turing (f, state_inner)
83
97
end
@@ -86,15 +100,13 @@ function recompute_logprob!!(
86
100
rng:: Random.AbstractRNG ,
87
101
model:: AbstractMCMC.LogDensityModel ,
88
102
sampler:: AdvancedHMC.AbstractHMCSampler ,
89
- state:: AdvancedHMC.HMCState
103
+ state:: AdvancedHMC.HMCState ,
90
104
)
91
105
# Construct hamiltionian.
92
106
hamiltonian = AdvancedHMC. Hamiltonian (state. metric, model)
93
107
# Re-compute the log-probability and gradient.
94
108
return Accessors. @set state. transition. z = AdvancedHMC. phasepoint (
95
- hamiltonian,
96
- state. transition. z. θ,
97
- state. transition. z. r,
109
+ hamiltonian, state. transition. z. θ, state. transition. z. r
98
110
)
99
111
end
100
112
@@ -115,7 +127,7 @@ function AbstractMCMC.step(
115
127
sampler_wrapper:: Sampler{<:ExternalSampler} ;
116
128
initial_state= nothing ,
117
129
initial_params= nothing ,
118
- kwargs...
130
+ kwargs... ,
119
131
)
120
132
alg = sampler_wrapper. alg
121
133
sampler = alg. sampler
@@ -145,7 +157,12 @@ function AbstractMCMC.step(
145
157
)
146
158
else
147
159
transition_inner, state_inner = AbstractMCMC. step (
148
- rng, AbstractMCMC. LogDensityModel (f), sampler, initial_state; initial_params, kwargs...
160
+ rng,
161
+ AbstractMCMC. LogDensityModel (f),
162
+ sampler,
163
+ initial_state;
164
+ initial_params,
165
+ kwargs... ,
149
166
)
150
167
end
151
168
# Update the `state`
@@ -157,7 +174,7 @@ function AbstractMCMC.step(
157
174
model:: DynamicPPL.Model ,
158
175
sampler_wrapper:: Sampler{<:ExternalSampler} ,
159
176
state:: TuringState ;
160
- kwargs...
177
+ kwargs... ,
161
178
)
162
179
sampler = sampler_wrapper. alg. sampler
163
180
f = state. logdensity
0 commit comments