Skip to content

Commit e37e6bb

Browse files
committed
Test linked varinfos
Closes #891
1 parent 51a26b2 commit e37e6bb

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/transforming.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ function tilde_assume(
1919
lp = Bijectors.logpdf_with_trans(right, r, !isinverse)
2020

2121
if istrans(vi, vn)
22-
@assert isinverse "Trying to link already transformed variables"
22+
isinverse || @warn "Trying to link an already transformed variable ($vn)"
2323
else
24-
@assert !isinverse "Trying to invlink non-transformed variables"
24+
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
2525
end
2626

2727
# Only transform if `!isinverse` since `vi[vn, right]`

test/ad.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,24 @@ using DynamicPPL: LogDensityFunction
2222
vns = DynamicPPL.TestUtils.varnames(m)
2323
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
2424

25-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
26-
# TODO: This runs unlinked. Should we test linked as well?
27-
f = LogDensityFunction(m, varinfo)
25+
@testset "$(short_varinfo_name(varinfo))" for linked_varinfo in varinfos
26+
linked_varinfo = DynamicPPL.link(varinfo, m)
27+
f = LogDensityFunction(m, linked_varinfo)
2828
x = DynamicPPL.getparams(f)
2929
# Calculate reference logp + gradient of logp using ForwardDiff
30-
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
30+
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
3131
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
3232

3333
@testset "$adtype" for adtype in test_adtypes
34-
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
34+
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
3535

3636
# Put predicates here to avoid long lines
3737
is_mooncake = adtype isa AutoMooncake
3838
is_1_10 = v"1.10" <= VERSION < v"1.11"
3939
is_1_11 = v"1.11" <= VERSION < v"1.12"
40-
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
41-
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
40+
is_svi_vnv =
41+
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
42+
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
4243

4344
# Mooncake doesn't work with several combinations of SimpleVarInfo.
4445
if is_mooncake && is_1_11 && is_svi_vnv
@@ -57,11 +58,10 @@ using DynamicPPL: LogDensityFunction
5758
ref_ldf, adtype
5859
)
5960
else
60-
# TODO: Should we test linked as well?
6161
@test DynamicPPL.TestUtils.AD.run_ad(
6262
m,
6363
adtype;
64-
varinfo=varinfo,
64+
varinfo=linked_varinfo,
6565
expected_value_and_grad=(ref_logp, ref_grad),
6666
) isa Any
6767
end

0 commit comments

Comments
 (0)