@@ -4,19 +4,13 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: Random, Xoshiro
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad
-
-# This function needed to work around the fact that different backends can
-# return different AbstractArrays for the gradient. See
-# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
-# context.
-_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
+export ADResult, run_ad, ADIncorrectException
 
 """
     REFERENCE_ADTYPE
@@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl.
 const REFERENCE_ADTYPE = AutoForwardDiff()
 
 """
-    ADResult
+    ADIncorrectException{T<:AbstractFloat}
+
+Exception thrown when an AD backend returns an incorrect value or gradient.
+
+The type parameter `T` is the numeric type of the value and gradient.
+"""
+struct ADIncorrectException{T<:AbstractFloat} <: Exception
+    value_expected::T
+    value_actual::T
+    grad_expected::Vector{T}
+    grad_actual::Vector{T}
+end
+
+"""
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
+
+The type parameter `Tparams` is the numeric type of the parameters passed in;
+`Tresult` is the type of the value and the gradient.
 """
-struct ADResult
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
     "The values at which the model was evaluated"
-    params::Vector{<:Real}
+    params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
     "The absolute tolerance for the value of logp"
-    value_atol::Real
+    value_atol::Tresult
     "The absolute tolerance for the gradient of logp"
-    grad_atol::Real
+    grad_atol::Tresult
     "The expected value of logp"
-    value_expected::Union{Nothing,Float64}
+    value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
-    grad_expected::Union{Nothing,Vector{Float64}}
+    grad_expected::Union{Nothing,Vector{Tresult}}
    "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Real}
+    value_actual::Union{Nothing,Tresult}
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Float64}}
+    grad_actual::Union{Nothing,Vector{Tresult}}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Float64}
+    time_vs_primal::Union{Nothing,Tresult}
 end
 
 """
@@ -64,26 +75,27 @@
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
-        varinfo::AbstractVarInfo=VarInfo(model),
-        params::Vector{<:Real}=varinfo[:],
+        varinfo::AbstractVarInfo=link(VarInfo(model), model),
+        params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
         reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
 
+### Description
+
 Test the correctness and/or benchmark the AD backend `adtype` for the model
 `model`.
 
 Whether to test and benchmark is controlled by the `test` and `benchmark`
 keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
 
-Returns an [`ADResult`](@ref) object, which contains the results of the
-test and/or benchmark.
-
 Note that to run AD successfully you will need to import the AD backend itself.
 For example, to test with `AutoReverseDiff()` you will need to run `import
 ReverseDiff`.
 
+### Arguments
+
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
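As a concrete illustration of the calling convention documented above, a sketch with a toy model and ReverseDiff as the backend under test (model and backend are illustrative, and `run_ad` is assumed to be in scope as in the previous sketch; `time_vs_primal` is only populated when `benchmark=true`):

```julia
using DynamicPPL, Distributions, ADTypes
import ReverseDiff  # the backend being tested must be imported

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

result = run_ad(gdemo(1.5), AutoReverseDiff(); benchmark=true)
result.time_vs_primal  # gradient time relative to one logp evaluation
```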
@@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups:
    DynamicPPL contains several different types of VarInfo objects which change
    the way model evaluation occurs. If you want to use a specific type of
    VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
-   using a `TypedVarInfo` generated from the model.
+   using a linked `TypedVarInfo` generated from the model. Here, _linked_
+   means that the parameters in the VarInfo have been transformed to
+   unconstrained Euclidean space if they aren't already in that space.
 
 2. _How to specify the parameters._
 
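Returning to the `varinfo` point above, a sketch of building the linked VarInfo by hand and passing it explicitly, mirroring the new default (reusing the illustrative `gdemo` model from the previous sketch):

```julia
using DynamicPPL: VarInfo, link

model = gdemo(1.5)
vi = link(VarInfo(model), model)  # same transformation as the new default
run_ad(model, AutoReverseDiff(); varinfo=vi)
```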
@@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups:
 
 By default, this function prints messages when it runs. To silence it, set
 `verbose=false`.
+
+### Returns / Throws
+
+Returns an [`ADResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+If `test` is `true` and the AD backend returns an incorrect value or gradient, an
+`ADIncorrectException` is thrown. If a different error occurs, it will be
+thrown as-is.
 """
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test=true,
-    benchmark=false,
-    value_atol=1e-6,
-    grad_atol=1e-6,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    params::Vector{<:Real}=varinfo[:],
+    test::Bool=true,
+    benchmark::Bool=false,
+    value_atol::AbstractFloat=1e-6,
+    grad_atol::AbstractFloat=1e-6,
+    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    if isnothing(params)
+        params = varinfo[:]
+    end
+    params = map(identity, params) # Concretise
+
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
-    params = map(identity, params)
     verbose && println("params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
     value, grad = logdensity_and_gradient(ldf, params)
-    grad = _to_vec_f64(grad)
+    grad = collect(grad)
     verbose && println("actual : $((value, grad))")
 
     if test
@@ -172,10 +199,11 @@ function run_ad(
             expected_value_and_grad
         end
         verbose && println("expected : $((value_true, grad_true))")
-        grad_true = _to_vec_f64(grad_true)
-        # Then compare
-        @test isapprox(value, value_true; atol=value_atol)
-        @test isapprox(grad, grad_true; atol=grad_atol)
+        grad_true = collect(grad_true)
+
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=value_atol) || exc()
+        isapprox(grad, grad_true; atol=grad_atol) || exc()
     else
         value_true = nothing
         grad_true = nothing
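Finally, a sketch of supplying `params` and `expected_value_and_grad` explicitly so that the reference gradient is computed once and reused across several backends, instead of being recomputed inside each `run_ad` call (again assuming the illustrative `model` and `vi` from the sketches above):

```julia
using DynamicPPL: LogDensityFunction
using LogDensityProblems: logdensity_and_gradient

params = vi[:]  # parameters in linked (unconstrained) space
ref_ldf = LogDensityFunction(model, vi; adtype=AutoForwardDiff())
value_ref, grad_ref = logdensity_and_gradient(ref_ldf, params)

for adtype in (AutoForwardDiff(), AutoReverseDiff())
    run_ad(
        model,
        adtype;
        params=params,
        expected_value_and_grad=(value_ref, collect(grad_ref)),
    )
end
```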