Skip to content

Commit b9c368b

Browse files
authored
Unify {untyped,typed}_{vector_,}varinfo constructor functions (#879)
* Unify {Untyped,Typed}{Vector,}VarInfo constructors * Update invocations * NTVarInfo * Fix tests * More fixes * Fixes * Fixes * Fixes * Use lowercase functions, don't deprecate VarInfo * Rewrite VarInfo docstring * Fix methods * Fix methods (really)
1 parent cc5e581 commit b9c368b

16 files changed

+436
-255
lines changed

HISTORY.md

+30-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,25 @@
44

55
**Breaking changes**
66

7-
### VarInfo constructor
7+
### VarInfo constructors
88

99
`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
1010

11+
The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
12+
If you were not using this argument (most likely), then there is no change needed.
13+
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).
14+
15+
The `UntypedVarInfo` constructor and type is no longer exported.
16+
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.
17+
18+
The `TypedVarInfo` constructor and type is no longer exported.
19+
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
20+
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.
21+
22+
Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
23+
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
24+
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.
25+
1126
### VarName prefixing behaviour
1227

1328
The way in which VarNames in submodels are prefixed has been changed.
@@ -53,6 +68,20 @@ outer() | (a.x=1.0,)
5368
If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
5469
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)
5570

71+
**Other changes**
72+
73+
While these are technically breaking, they are only internal changes and do not affect the public API.
74+
The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata:
75+
76+
1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])`
77+
2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])`
78+
3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])`
79+
4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])`
80+
81+
The reason for this change is that there were several flavours of VarInfo.
82+
Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult.
83+
This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing.
84+
5685
## 0.35.5
5786

5887
Several internal methods have been removed:

benchmarks/src/DynamicPPLBenchmarks.jl

+4-6
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ end
5252
5353
Create a benchmark suite for `model` using the selected varinfo type and AD backend.
5454
Available varinfo choices:
55-
• `:untyped` → uses `VarInfo()`
56-
• `:typed` → uses `VarInfo(model)`
55+
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
56+
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
5757
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
5858
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
5959
@@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
6767
suite = BenchmarkGroup()
6868

6969
vi = if varinfo_choice == :untyped
70-
vi = VarInfo()
71-
model(rng, vi)
72-
vi
70+
DynamicPPL.untyped_varinfo(rng, model)
7371
elseif varinfo_choice == :typed
74-
VarInfo(rng, model)
72+
DynamicPPL.typed_varinfo(rng, model)
7573
elseif varinfo_choice == :simple_namedtuple
7674
SimpleVarInfo{Float64}(model(rng))
7775
elseif varinfo_choice == :simple_dict

docs/src/api.md

+6-7
Original file line numberDiff line numberDiff line change
@@ -291,18 +291,17 @@ AbstractVarInfo
291291

292292
But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.
293293

294-
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:
294+
#### `VarInfo`
295295

296296
```@docs
297-
DynamicPPL.untyped_varinfo
298-
DynamicPPL.typed_varinfo
297+
VarInfo
299298
```
300299

301-
#### `VarInfo`
302-
303300
```@docs
304-
VarInfo
305-
TypedVarInfo
301+
DynamicPPL.untyped_varinfo
302+
DynamicPPL.typed_varinfo
303+
DynamicPPL.untyped_vector_varinfo
304+
DynamicPPL.typed_vector_varinfo
306305
```
307306

308307
One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/).

docs/src/internals/varinfo.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi
227227

228228
```@example varinfo-design
229229
# Type-unstable
230-
varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped)
230+
varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped)
231231
varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)]
232232
```
233233

234234
```@example varinfo-design
235235
# Type-stable
236-
varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed)
236+
varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed)
237237
varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)]
238238
```
239239

src/DynamicPPL.jl

-2
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ import Base:
4545
# VarInfo
4646
export AbstractVarInfo,
4747
VarInfo,
48-
UntypedVarInfo,
49-
TypedVarInfo,
5048
SimpleVarInfo,
5149
push!!,
5250
empty!!,

src/abstract_varinfo.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector)
247247
2.0
248248
```
249249
250-
`TypedVarInfo`:
250+
`VarInfo` with `NamedTuple` of `Metadata`:
251251
252252
```jldoctest
253253
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
254-
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
254+
vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
255255
256256
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
257257
@@ -273,11 +273,11 @@ julia> values_as(vi, Vector)
273273
2.0
274274
```
275275
276-
`UntypedVarInfo`:
276+
`VarInfo` with `Metadata`:
277277
278278
```jldoctest
279279
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
280-
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
280+
vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
281281
282282
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
283283

src/sampler.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function default_varinfo(
8686
context::AbstractContext,
8787
)
8888
init_sampler = initialsampler(sampler)
89-
return VarInfo(rng, model, init_sampler, context)
89+
return typed_varinfo(rng, model, init_sampler, context)
9090
end
9191

9292
function AbstractMCMC.sample(

src/simple_varinfo.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`.
1010
$(FIELDS)
1111
1212
# Notes
13-
The major differences between this and `TypedVarInfo` are:
13+
The major differences between this and `NTVarInfo` are:
1414
1. `SimpleVarInfo` does not require linearization.
1515
2. `SimpleVarInfo` can use more efficient bijectors.
1616
3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
@@ -244,7 +244,7 @@ function SimpleVarInfo{T}(
244244
end
245245

246246
# Constructor from `VarInfo`.
247-
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
247+
function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
248248
return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
249249
end
250250
function SimpleVarInfo{T}(

src/test_utils/contexts.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
9494
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
9595
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
9696
# Typed varinfo.
97-
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
97+
varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
9898
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
9999
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
100100
end

src/test_utils/varinfo.jl

+4-6
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@ function setup_varinfos(
2727
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
2828
)
2929
# VarInfo
30-
vi_untyped_metadata = VarInfo(DynamicPPL.Metadata())
31-
vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector())
32-
model(vi_untyped_metadata)
33-
model(vi_untyped_vnv)
34-
vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata)
35-
vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv)
30+
vi_untyped_metadata = DynamicPPL.untyped_varinfo(model)
31+
vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model)
32+
vi_typed_metadata = DynamicPPL.typed_varinfo(model)
33+
vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model)
3634

3735
# SimpleVarInfo
3836
svi_typed = SimpleVarInfo(example_values)

0 commit comments

Comments
 (0)