
Commit be27636

Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890)
* Link VarInfo by default
* Tweak interface
* Fix tests
* Fix interface so that callers can inspect results
* Document
* Fix tests
* Fix changelog
* Test linked varinfos (Closes #891)
* Fix docstring + use AbstractFloat
1 parent 5ba3530 commit be27636

File tree (6 files changed: +91 -54 lines)

HISTORY.md (+12)
docs/src/api.md (+1)
src/test_utils/ad.jl (+66 -38)
src/transforming.jl (+2 -2)
test/ad.jl (+10 -8)
test/simple_varinfo.jl (-6)

HISTORY.md (+12)

@@ -4,6 +4,18 @@
 
 **Breaking changes**
 
+### AD testing utilities
+
+`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
+To disable this, pass the `linked=false` keyword argument.
+If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
+This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
+From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.
+
+### SimpleVarInfo linking / invlinking
+
+Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.
+
 ### VarInfo constructors
 
 `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
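For illustration, a minimal sketch of the test-suite pattern described in the changelog entry above. The `demo` model and the choice of ForwardDiff are assumptions made for this example, not part of the commit:

```julia
using DynamicPPL, Distributions, Test
using ADTypes: AutoForwardDiff
using DynamicPPL.TestUtils.AD: run_ad
import ForwardDiff  # the AD backend under test must be imported

# Hypothetical toy model used only for this sketch.
@model function demo()
    x ~ Normal()
end

@testset "AD correctness" begin
    # `run_ad` now throws `ADIncorrectException` on a wrong value/gradient
    # instead of recording a failed `@test`, so wrap the call: a successful
    # run passes, and a thrown exception surfaces as a test error.
    @test run_ad(demo(), AutoForwardDiff()) isa Any
end
```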

docs/src/api.md (+1)

@@ -212,6 +212,7 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
 ```@docs
 DynamicPPL.TestUtils.AD.run_ad
 DynamicPPL.TestUtils.AD.ADResult
+DynamicPPL.TestUtils.AD.ADIncorrectException
 ```
 
 ## Demo models

src/test_utils/ad.jl (+66 -38)

@@ -4,19 +4,13 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: Random, Xoshiro
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad
-
-# This function needed to work around the fact that different backends can
-# return different AbstractArrays for the gradient. See
-# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
-# context.
-_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
+export ADResult, run_ad, ADIncorrectException
 
 """
     REFERENCE_ADTYPE
@@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl.
 const REFERENCE_ADTYPE = AutoForwardDiff()
 
 """
-    ADResult
+    ADIncorrectException{T<:AbstractFloat}
+
+Exception thrown when an AD backend returns an incorrect value or gradient.
+
+The type parameter `T` is the numeric type of the value and gradient.
+"""
+struct ADIncorrectException{T<:AbstractFloat} <: Exception
+    value_expected::T
+    value_actual::T
+    grad_expected::Vector{T}
+    grad_actual::Vector{T}
+end
+
+"""
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
+
+The type parameter `Tparams` is the numeric type of the parameters passed in;
+`Tresult` is the type of the value and the gradient.
 """
-struct ADResult
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
     "The values at which the model was evaluated"
-    params::Vector{<:Real}
+    params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
     "The absolute tolerance for the value of logp"
-    value_atol::Real
+    value_atol::Tresult
     "The absolute tolerance for the gradient of logp"
-    grad_atol::Real
+    grad_atol::Tresult
     "The expected value of logp"
-    value_expected::Union{Nothing,Float64}
+    value_expected::Union{Nothing,Tresult}
    "The expected gradient of logp"
-    grad_expected::Union{Nothing,Vector{Float64}}
+    grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Real}
+    value_actual::Union{Nothing,Tresult}
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Float64}}
+    grad_actual::Union{Nothing,Vector{Tresult}}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Float64}
+    time_vs_primal::Union{Nothing,Tresult}
 end
 
 """
@@ -64,26 +75,27 @@ end
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
-        varinfo::AbstractVarInfo=VarInfo(model),
-        params::Vector{<:Real}=varinfo[:],
+        varinfo::AbstractVarInfo=link(VarInfo(model), model),
+        params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
 
+### Description
+
 Test the correctness and/or benchmark the AD backend `adtype` for the model
 `model`.
 
 Whether to test and benchmark is controlled by the `test` and `benchmark`
 keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
 
-Returns an [`ADResult`](@ref) object, which contains the results of the
-test and/or benchmark.
-
 Note that to run AD successfully you will need to import the AD backend itself.
 For example, to test with `AutoReverseDiff()` you will need to run `import
 ReverseDiff`.
 
+### Arguments
+
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
@@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups:
    DynamicPPL contains several different types of VarInfo objects which change
   the way model evaluation occurs. If you want to use a specific type of
   VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
-   using a `TypedVarInfo` generated from the model.
+   using a linked `TypedVarInfo` generated from the model. Here, _linked_
+   means that the parameters in the VarInfo have been transformed to
+   unconstrained Euclidean space if they aren't already in that space.
 
 2. _How to specify the parameters._
 
@@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups:
 
    By default, this function prints messages when it runs. To silence it, set
   `verbose=false`.
+
+### Returns / Throws
+
+Returns an [`ADResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+If `test` is `true` and the AD backend returns an incorrect value or gradient, an
+`ADIncorrectException` is thrown. If a different error occurs, it will be
+thrown as-is.
 """
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test=true,
-    benchmark=false,
-    value_atol=1e-6,
-    grad_atol=1e-6,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    params::Vector{<:Real}=varinfo[:],
+    test::Bool=true,
+    benchmark::Bool=false,
+    value_atol::AbstractFloat=1e-6,
+    grad_atol::AbstractFloat=1e-6,
+    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    if isnothing(params)
+        params = varinfo[:]
+    end
+    params = map(identity, params) # Concretise
+
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
-    params = map(identity, params)
     verbose && println(" params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
     value, grad = logdensity_and_gradient(ldf, params)
-    grad = _to_vec_f64(grad)
+    grad = collect(grad)
     verbose && println(" actual : $((value, grad))")
 
     if test
@@ -172,10 +199,11 @@ function run_ad(
             expected_value_and_grad
         end
         verbose && println(" expected : $((value_true, grad_true))")
-        grad_true = _to_vec_f64(grad_true)
-        # Then compare
-        @test isapprox(value, value_true; atol=value_atol)
-        @test isapprox(grad, grad_true; atol=grad_atol)
+        grad_true = collect(grad_true)
+
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=value_atol) || exc()
+        isapprox(grad, grad_true; atol=grad_atol) || exc()
     else
         value_true = nothing
         grad_true = nothing
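To show how the new exception can be inspected, here is a hedged sketch of catching `ADIncorrectException` and reading the fields defined above. The toy model and the choice of ReverseDiff are assumptions for the example, not part of this diff:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoReverseDiff
using DynamicPPL.TestUtils.AD: run_ad, ADIncorrectException
import ReverseDiff  # the backend under test must be imported

# Hypothetical toy model for the sketch.
@model function demo()
    x ~ Normal()
end
model = demo()

try
    # Uses the new default of a linked VarInfo; pass e.g. `varinfo=VarInfo(model)`
    # to run the test on an unlinked VarInfo instead.
    result = run_ad(model, AutoReverseDiff())
    @info "AD test passed" result.value_actual result.grad_actual
catch err
    err isa ADIncorrectException || rethrow()
    # The exception carries both sides of the failed comparison.
    @info "AD mismatch" err.value_expected err.value_actual err.grad_expected err.grad_actual
end
```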

src/transforming.jl (+2 -2)

@@ -19,9 +19,9 @@ function tilde_assume(
     lp = Bijectors.logpdf_with_trans(right, r, !isinverse)
 
     if istrans(vi, vn)
-        @assert isinverse "Trying to link already transformed variables"
+        isinverse || @warn "Trying to link an already transformed variable ($vn)"
     else
-        @assert !isinverse "Trying to invlink non-transformed variables"
+        isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
     end
 
     # Only transform if `!isinverse` since `vi[vn, right]`
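As a concrete illustration of the relaxed behaviour, a sketch assuming a toy model and a `SimpleVarInfo` built from a NamedTuple; none of the names below come from this diff:

```julia
using DynamicPPL, Distributions

# Hypothetical toy model for the sketch.
@model function demo()
    x ~ LogNormal()
end
model = demo()

svi = SimpleVarInfo((x = 1.0,))            # unlinked SimpleVarInfo
svi_linked = link!!(deepcopy(svi), model)  # first link: no warning
# Linking an already-linked SimpleVarInfo used to trip the `@assert` above and
# throw an AssertionError; with this change it only logs the warning.
link!!(deepcopy(svi_linked), model)
```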

test/ad.jl (+10 -8)

@@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction
         varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            f = LogDensityFunction(m, varinfo)
+            linked_varinfo = DynamicPPL.link(varinfo, m)
+            f = LogDensityFunction(m, linked_varinfo)
             x = DynamicPPL.getparams(f)
             # Calculate reference logp + gradient of logp using ForwardDiff
-            ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
+            ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
             ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
 
             @testset "$adtype" for adtype in test_adtypes
-                @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
+                @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
 
                 # Put predicates here to avoid long lines
                 is_mooncake = adtype isa AutoMooncake
                 is_1_10 = v"1.10" <= VERSION < v"1.11"
                 is_1_11 = v"1.11" <= VERSION < v"1.12"
-                is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
-                is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
+                is_svi_vnv =
+                    linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
+                is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
 
                 # Mooncake doesn't work with several combinations of SimpleVarInfo.
                 if is_mooncake && is_1_11 && is_svi_vnv
@@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction
                         ref_ldf, adtype
                     )
                 else
-                    DynamicPPL.TestUtils.AD.run_ad(
+                    @test DynamicPPL.TestUtils.AD.run_ad(
                         m,
                         adtype;
-                        varinfo=varinfo,
+                        varinfo=linked_varinfo,
                         expected_value_and_grad=(ref_logp, ref_grad),
-                    )
+                    ) isa Any
                 end
             end
         end

test/simple_varinfo.jl (-6)

@@ -111,12 +111,6 @@
         # Should be approx. the same as the "lazy" transformation.
         @test logjoint(model, vi_linked) ≈ lp_linked
 
-        # TODO: Should not `VarInfo` also error here? The current implementation
-        # only warns and acts as a no-op.
-        if vi isa SimpleVarInfo
-            @test_throws AssertionError link!!(vi_linked, model)
-        end
-
         # `invlink!!`
         vi_invlinked = invlink!!(deepcopy(vi_linked), model)
         lp_invlinked = getlogp(vi_invlinked)

Comments (0)