Skip to content

Commit 77faa53

Browse files
committed
Test linked varinfos
Closes #891
1 parent 51a26b2 commit 77faa53

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

src/transforming.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ function tilde_assume(
1919
lp = Bijectors.logpdf_with_trans(right, r, !isinverse)
2020

2121
if istrans(vi, vn)
22-
@assert isinverse "Trying to link already transformed variables"
22+
isinverse || @warn "Trying to link an already transformed variable ($vn)"
2323
else
24-
@assert !isinverse "Trying to invlink non-transformed variables"
24+
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
2525
end
2626

2727
# Only transform if `!isinverse` since `vi[vn, right]`

test/ad.jl

+8-9
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,23 @@ using DynamicPPL: LogDensityFunction
2222
vns = DynamicPPL.TestUtils.varnames(m)
2323
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
2424

25-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
26-
# TODO: This runs unlinked. Should we test linked as well?
27-
f = LogDensityFunction(m, varinfo)
25+
@testset "$(short_varinfo_name(varinfo))" for linked_varinfo in varinfos
26+
linked_varinfo = DynamicPPL.link(varinfo, m)
27+
f = LogDensityFunction(m, linked_varinfo)
2828
x = DynamicPPL.getparams(f)
2929
# Calculate reference logp + gradient of logp using ForwardDiff
30-
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
30+
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
3131
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
3232

3333
@testset "$adtype" for adtype in test_adtypes
34-
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
34+
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
3535

3636
# Put predicates here to avoid long lines
3737
is_mooncake = adtype isa AutoMooncake
3838
is_1_10 = v"1.10" <= VERSION < v"1.11"
3939
is_1_11 = v"1.11" <= VERSION < v"1.12"
40-
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
41-
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
40+
is_svi_vnv = linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
41+
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
4242

4343
# Mooncake doesn't work with several combinations of SimpleVarInfo.
4444
if is_mooncake && is_1_11 && is_svi_vnv
@@ -57,11 +57,10 @@ using DynamicPPL: LogDensityFunction
5757
ref_ldf, adtype
5858
)
5959
else
60-
# TODO: Should we test linked as well?
6160
@test DynamicPPL.TestUtils.AD.run_ad(
6261
m,
6362
adtype;
64-
varinfo=varinfo,
63+
varinfo=linked_varinfo,
6564
expected_value_and_grad=(ref_logp, ref_grad),
6665
) isa Any
6766
end

0 commit comments

Comments
 (0)