|
64 | 64 | benchmark=false, |
65 | 65 | value_atol=1e-6, |
66 | 66 | grad_atol=1e-6, |
67 | | - varinfo::AbstractVarInfo=link(VarInfo(model), model), |
68 | | - params::Vector{<:Real}=varinfo[:], |
| 67 | + linked::Bool=true, |
| 68 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 69 | + params::Union{Nothing,Vector{<:Real}}=nothing, |
69 | 70 | reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, |
70 | 71 | expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, |
71 | 72 | verbose=true, |
@@ -96,10 +97,12 @@ Everything else is optional, and can be categorised into several groups: |
96 | 97 | DynamicPPL contains several different types of VarInfo objects which change |
97 | 98 | the way model evaluation occurs. If you want to use a specific type of |
98 | 99 | VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to |
99 | | - using a `TypedVarInfo` generated from the model. It will also perform |
100 | | - _linking_, that is, the parameters in the VarInfo will be transformed to |
101 | | - unconstrained Euclidean space if they aren't already in that space. Note |
102 | | - that the act of linking may change the length of the parameters. |
| 100 | + using a `TypedVarInfo` generated from the model. |
| 101 | +
|
| 102 | + It will also perform _linking_, that is, the parameters in the VarInfo will |
| 103 | + be transformed to unconstrained Euclidean space if they aren't already in |
| 104 | + that space. Note that the act of linking may change the length of the |
| 105 | + parameters. To disable linking, set `linked=false`. |
103 | 106 |
|
104 | 107 | 2. _How to specify the parameters._ |
105 | 108 |
|
@@ -151,14 +154,22 @@ function run_ad( |
151 | 154 | benchmark=false, |
152 | 155 | value_atol=1e-6, |
153 | 156 | grad_atol=1e-6, |
154 | | - varinfo::AbstractVarInfo=link(VarInfo(model), model), |
155 | | - params::Vector{<:Real}=varinfo[:], |
| 157 | + linked::Bool=true, |
| 158 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 159 | + params::Union{Nothing,Vector{<:Real}}=nothing, |
156 | 160 | reference_adtype::AbstractADType=REFERENCE_ADTYPE, |
157 | 161 | expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, |
158 | 162 | verbose=true, |
159 | 163 | )::ADResult |
160 | | - verbose && @info "Running AD on $(model.f) with $(adtype)\n" |
| 164 | + if linked |
| 165 | + varinfo = link(varinfo, model) |
| 166 | + end |
| 167 | + if isnothing(params) |
| 168 | + params = varinfo[:] |
| 169 | + end |
161 | 170 | params = map(identity, params) |
| 171 | + |
| 172 | + verbose && @info "Running AD on $(model.f) with $(adtype)\n" |
162 | 173 | verbose && println(" params : $(params)") |
163 | 174 | ldf = LogDensityFunction(model, varinfo; adtype=adtype) |
164 | 175 |
|
|
0 commit comments