Merged
24 commits
e332d8c
remove the type `ParamSpaceSGD`
Red-Portal Sep 15, 2025
1f35cc9
run formatter
Red-Portal Sep 15, 2025
c8404b6
run formatter
Red-Portal Sep 15, 2025
0cc7538
run formatter
Red-Portal Sep 15, 2025
ede91c6
fix rename file paramspacesgd.jl to interface.jl
Red-Portal Oct 13, 2025
625f429
Merge branch 'remove_paramspacesgd' of github.com:TuringLang/Advanced…
Red-Portal Oct 13, 2025
e3c2761
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into remov…
Red-Portal Oct 13, 2025
683a09d
throw invalid state for unknown paramspacesgd type
Red-Portal Oct 13, 2025
570fe11
add docstring for union type of paramspacesgd algorithms
Red-Portal Oct 13, 2025
2d5f373
fix remove custom state types for paramspacesgd algorithms
Red-Portal Oct 13, 2025
e0221eb
fix remove custom state types for paramspacesgd
Red-Portal Oct 13, 2025
e51ab3c
fix file path
Red-Portal Oct 13, 2025
e49c680
fix bug in BijectorsExt
Red-Portal Oct 13, 2025
3c5b56f
fix include `SubSampleObjective` as part of `ParamSpaceSGD`
Red-Portal Oct 13, 2025
30f5160
fix formatting
Red-Portal Oct 13, 2025
008c4ea
fix revert adding SubsampledObjective into ParamSpaceSGD
Red-Portal Oct 13, 2025
8a18902
refactor flatten algorithms
Red-Portal Oct 13, 2025
b002e1e
fix error update paths in main file
Red-Portal Oct 13, 2025
1ba361f
refactor flatten the tests to reflect new structure
Red-Portal Oct 13, 2025
86baa07
fix file include path in tests
Red-Portal Oct 13, 2025
67e9375
fix missing operator in subsampledobj tests
Red-Portal Oct 13, 2025
9b2eabb
fix formatting
Red-Portal Oct 13, 2025
922e5d7
update docs
Red-Portal Oct 20, 2025
7d3ed86
Merge branch 'remove_paramspacesgd' of github.com:TuringLang/Advanced…
Red-Portal Oct 20, 2025
10 changes: 9 additions & 1 deletion ext/AdvancedVIBijectorsExt.jl
@@ -25,7 +25,15 @@ function AdvancedVI.init(
obj_st = AdvancedVI.init(rng, objective, adtype, q_init, prob, params, re)
avg_st = AdvancedVI.init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return AdvancedVI.ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
return (
prob=prob,
q=q_init,
iteration=0,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)
end

function AdvancedVI.apply(
18 changes: 8 additions & 10 deletions src/AdvancedVI.jl
@@ -277,13 +277,10 @@ export optimize
include("utils.jl")
include("optimize.jl")

## Parameter Space SGD
include("algorithms/paramspacesgd/abstractobjective.jl")
include("algorithms/paramspacesgd/paramspacesgd.jl")
## Parameter Space SGD Implementations

export ParamSpaceSGD
include("algorithms/abstractobjective.jl")

## Parameter Space SGD Implementations
### ELBO Maximization

abstract type AbstractEntropyEstimator end
@@ -304,10 +301,10 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

include("algorithms/paramspacesgd/subsampledobjective.jl")
include("algorithms/paramspacesgd/repgradelbo.jl")
include("algorithms/paramspacesgd/scoregradelbo.jl")
include("algorithms/paramspacesgd/entropy.jl")
include("algorithms/subsampledobjective.jl")
include("algorithms/repgradelbo.jl")
include("algorithms/scoregradelbo.jl")
include("algorithms/entropy.jl")

export RepGradELBO,
ScoreGradELBO,
@@ -318,7 +315,8 @@ export RepGradELBO,
StickingTheLandingEntropyZeroGradient,
SubsampledObjective

include("algorithms/paramspacesgd/constructors.jl")
include("algorithms/constructors.jl")
include("algorithms/interface.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

src/algorithms/{paramspacesgd → }/constructors.jl
@@ -18,13 +18,43 @@ KL divergence minimization by running stochastic gradient descent with the repar
- `operator::AbstractOperator`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy`.
"""
struct KLMinRepGradDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinRepGradDescent(
adtype::ADTypes.AbstractADType;
entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(),
@@ -39,7 +69,11 @@ function KLMinRepGradDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const ADVI = KLMinRepGradDescent
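
For illustration, a minimal usage sketch of the constructor and the documented callback signature. The toy target type, the commented-out `MeanFieldGaussian` call, and the `optimize` argument order are assumptions, not part of this PR:

using AdvancedVI, ADTypes, LogDensityProblems

# Hypothetical target implementing the LogDensityProblems interface; the log density
# must be differentiable by the chosen AD backend (load ForwardDiff for AutoForwardDiff).
struct ToyTarget end
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyTarget) = 2
LogDensityProblems.capabilities(::Type{ToyTarget}) = LogDensityProblems.LogDensityOrder{0}()

# Construct the algorithm; keyword defaults follow the docstring above.
alg = KLMinRepGradDescent(AutoForwardDiff(); n_samples=8, averager=PolynomialAveraging())

# Callback matching the documented signature. `step` also forwards extra keywords
# (e.g. `state`), so accepting trailing `kwargs...` keeps the callback robust.
# Returning a NamedTuple merges it into the per-iteration `info`.
cb(; iteration, averaged_params, kwargs...) = (avg_param_norm=sum(abs2, averaged_params),)

# q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))                      # assumed family constructor
# q_avg, info, state = optimize(alg, 1_000, ToyTarget(), q0; callback=cb)  # argument order assumed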
@@ -63,12 +97,42 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
- `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The variational family is `MvLocationScale`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy_zerograd`.
"""
struct KLMinRepGradProxDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:ProximalLocationScaleEntropy,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinRepGradProxDescent(
adtype::ADTypes.AbstractADType;
entropy_zerograd::Union{
@@ -85,7 +149,11 @@ function KLMinRepGradProxDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy_zerograd), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradProxDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end
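
A brief construction sketch for the proximal variant. It is restricted to location-scale families, since the proximal step acts on the scale matrix directly and no `ClipScale` operator is needed; the commented-out `FullRankGaussian` constructor and the `optimize` call are assumptions:

using AdvancedVI, ADTypes

# Defaults use the zero-gradient entropy estimator and the proximal entropy operator.
alg = KLMinRepGradProxDescent(AutoForwardDiff(); n_samples=8)

# q0 = FullRankGaussian(zeros(2), LowerTriangular(Matrix(1.0I, 2, 2)))  # assumed constructor, needs LinearAlgebra
# q_avg, info, state = optimize(alg, 1_000, prob, q0)                   # `prob` as in the previous sketch; argument order assumed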

"""
@@ -106,15 +174,45 @@ KL divergence minimization by running stochastic gradient descent with the score
- `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The variational approximation ``q_{\\lambda}`` implements `logpdf(q, x)`, which should also be differentiable with respect to the variational parameters ``\\lambda``.
- The target distribution and the variational approximation have the same support.
"""
struct KLMinScoreGradDescent{
Obj<:Union{<:ScoreGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinScoreGradDescent(
adtype::ADTypes.AbstractADType;
optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(),
optimizer::Optimisers.AbstractRule=DoWG(),
n_samples::Int=1,
averager::AbstractAverager=PolynomialAveraging(),
operator::AbstractOperator=IdentityOperator(),
@@ -125,7 +223,11 @@ function KLMinScoreGradDescent(
else
SubsampledObjective(ScoreGradELBO(n_samples), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinScoreGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const BBVI = KLMinScoreGradDescent
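
A contrasting sketch for the score-gradient variant: the requirements above do not include differentiability of the target, only evaluation of `LogDensityProblems.logdensity`, so `BBVI` can be pointed at black-box log densities. The target type here is hypothetical and the commented-out `optimize` call is an assumption:

using AdvancedVI, ADTypes, LogDensityProblems

# Hypothetical target whose log density is only ever *evaluated*, never differentiated
# by the AD backend (the backend only differentiates logpdf of `q`).
struct BlackBoxTarget end
LogDensityProblems.logdensity(::BlackBoxTarget, x) = -sum(abs, x)  # stands in for an opaque evaluation
LogDensityProblems.dimension(::BlackBoxTarget) = 2
LogDensityProblems.capabilities(::Type{BlackBoxTarget}) = LogDensityProblems.LogDensityOrder{0}()

# DoWG is the documented default optimizer; more samples help with the higher variance.
alg = KLMinScoreGradDescent(AutoForwardDiff(); n_samples=32)

# q_avg, info, state = optimize(alg, 2_000, BlackBoxTarget(), q0; callback=cb)  # q0, cb as in the earlier sketches; argument order assumed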
File renamed without changes.
83 changes: 83 additions & 0 deletions src/algorithms/interface.jl
@@ -0,0 +1,83 @@

"""
This family of algorithms (`<:KLMinRepGradDescent`,`<:KLMinRepGradProxDescent`,`<:KLMinScoreGradDescent`) applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters.
The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
This requires the variational approximation to be marked as a functor through `Functors.@functor`.
"""
const ParamSpaceSGD = Union{
<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent
}

function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
(; adtype, optimizer, averager, objective, operator) = alg
if q_init isa AdvancedVI.MvLocationScale && operator isa AdvancedVI.IdentityOperator
@warn(
"IdentityOperator is used with a variational family <:MvLocationScale. Optimization can easily fail under this combination due to singular scale matrices. Consider using the operator `ClipScale` in the algorithm instead.",
)
end
params, re = Optimisers.destructure(q_init)
opt_st = Optimisers.setup(optimizer, params)
obj_st = init(rng, objective, adtype, q_init, prob, params, re)
avg_st = init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return (
prob=prob,
q=q_init,
iteration=0,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)
Comment on lines +23 to +31
Member:
I only have one comment, which is that it's probably better to keep this as a struct (although I recognise you may have some difficulty in choosing a name for it). NamedTuples are too flexible, if you return one you don't know what fields are in it.

Member (Author):
I propose we stick with this for now and improve this later once we can make things more organized. (The new structure in the PR is making a whole lot of things implicit anyway, so I think the fields being implicit doesn't make things much worse)

Member:
Sure, I'm not that fussed.

end

function output(alg::ParamSpaceSGD, state)
params_avg = value(alg.averager, state.avg_st)
_, re = Optimisers.destructure(state.q)
return re(params_avg)
end

function step(
rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs...
)
(; adtype, objective, operator, averager) = alg
(; prob, q, iteration, grad_buf, opt_st, obj_st, avg_st) = state

iteration += 1

params, re = Optimisers.destructure(q)

grad_buf, obj_st, info = estimate_gradient!(
rng, objective, adtype, grad_buf, obj_st, params, re, objargs...
)

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
params = apply(operator, typeof(q), opt_st, params, re)
avg_st = apply(averager, avg_st, params)

state = (
prob=prob,
q=re(params),
iteration=iteration,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)

if !isnothing(callback)
averaged_params = value(averager, avg_st)
info′ = callback(;
rng,
iteration,
restructure=re,
params=params,
averaged_params=averaged_params,
gradient=grad,
state=state,
)
info = !isnothing(info′) ? merge(info′, info) : info
end
state, false, info
end
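
To make the `Optimisers.destructure` requirement in the docstring above concrete, a hedged sketch with a hypothetical variational family; in practice the location-scale families shipped with AdvancedVI (e.g. `MvLocationScale`) already satisfy this:

using Functors, Optimisers, Random

# Hypothetical diagonal-Gaussian family; @functor makes its fields visible to
# Optimisers.destructure, which is what `init` and `step` above rely on.
struct DiagGaussian{V<:AbstractVector}
    location::V
    log_scale::V
end
Functors.@functor DiagGaussian

# The docstrings also require `rand` (and, for the score gradient, `logpdf`).
Base.rand(rng::Random.AbstractRNG, q::DiagGaussian) =
    q.location .+ exp.(q.log_scale) .* randn(rng, length(q.location))

q = DiagGaussian(zeros(2), zeros(2))
params, re = Optimisers.destructure(q)  # params == zeros(4); re(params) rebuilds q, as in `output`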