diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 47390ab..34ea995 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -10,10 +10,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" +KalmanFilters = "272a6111-cf0e-4c1b-a056-8d658cb314ee" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 878ea84..9a67641 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -1,9 +1,11 @@ export KalmanFilter, filter, BatchKalmanFilter using GaussianDistributions using CUDA: i32 +using PDMats +using KalmanFilters import LinearAlgebra: hermitianpart -export KalmanFilter, KF, KalmanSmoother, KS +export KalmanFilter, KF, KalmanSmoother, KS, SRKF struct KalmanFilter <: AbstractFilter end @@ -129,6 +131,58 @@ function update( return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1) end +## SQUARE-ROOT KALMAN FILTER ############################################################### + +""" + SRKF() + + A square-root Kalman filter. + + Implemented by wrapping KalmanFilters.jl. + """ +struct SRKF <: AbstractFilter end + +function initialise( + rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::SRKF; kwargs... +) + μ0, Σ0 = calc_initial(model.dyn; kwargs...) 
+ return Gaussian(μ0, Σ0) +end + +function predict( + rng::AbstractRNG, + model::LinearGaussianStateSpaceModel, + ::SRKF, + iter::Integer, + state::Gaussian, + observation=nothing; + kwargs..., +) + μ, Σ = GaussianDistributions.pair(state) + A, b, Q = calc_params(model.dyn, iter; kwargs...) + !all(b .== 0) && error("SRKF doesn't currently support non-zero b") + + tu = KalmanFilters.time_update(μ, cholesky(Σ), A, cholesky(Q)) + return Gaussian(tu.state, PDMat(tu.covariance)) +end + +function update( + model::LinearGaussianStateSpaceModel, + algo::SRKF, + iter::Integer, + state::Gaussian, + observation::AbstractVector; + kwargs..., +) + μ, Σ = GaussianDistributions.pair(state) + H, c, R = calc_params(model.obs, iter; kwargs...) + !all(c .== 0) && error("SRKF doesn't currently support non-zero c") + + mu = KalmanFilters.measurement_update(μ, cholesky(Σ), observation, H, cholesky(R)) + ll = logpdf(MvNormal(mu.innovation, PDMat(mu.innovation_covariance)), zero(observation)) + return Gaussian(mu.state, PDMat(mu.covariance)), ll +end + ## KALMAN SMOOTHER ######################################################################### struct KalmanSmoother <: AbstractSmoother end diff --git a/GeneralisedFilters/src/models/hierarchical.jl b/GeneralisedFilters/src/models/hierarchical.jl index ecdc8a6..d458489 100644 --- a/GeneralisedFilters/src/models/hierarchical.jl +++ b/GeneralisedFilters/src/models/hierarchical.jl @@ -29,24 +29,28 @@ function AbstractMCMC.sample( ) outer_dyn, inner_model = model.outer_dyn, model.inner_model - zs = Vector{eltype(inner_model.dyn)}(undef, T) - xs = Vector{eltype(outer_dyn)}(undef, T) + zs = OffsetVector(Vector{eltype(inner_model.dyn)}(undef, T + 1), -1) + xs = OffsetVector(Vector{eltype(outer_dyn)}(undef, T + 1), -1) ys = Vector{eltype(inner_model.obs)}(undef, T) # Simulate outer dynamics - x0 = simulate(rng, outer_dyn; kwargs...) - z0 = simulate(rng, inner_model.dyn; new_outer=x0, kwargs...) + xs[0] = simulate(rng, outer_dyn; kwargs...) 
+ zs[0] = simulate(rng, inner_model.dyn; new_outer=xs[0], kwargs...) for t in 1:T - prev_x = t == 1 ? x0 : xs[t - 1] - prev_z = t == 1 ? z0 : zs[t - 1] - xs[t] = simulate(rng, model.outer_dyn, t, prev_x; kwargs...) + xs[t] = simulate(rng, model.outer_dyn, t, xs[t - 1]; kwargs...) zs[t] = simulate( - rng, inner_model.dyn, t, prev_z; prev_outer=prev_x, new_outer=xs[t], kwargs... + rng, + inner_model.dyn, + t, + zs[t - 1]; + prev_outer=xs[t - 1], + new_outer=xs[t], + kwargs..., ) ys[t] = simulate(rng, inner_model.obs, t, zs[t]; new_outer=xs[t], kwargs...) end - return x0, z0, xs, zs, ys + return xs, zs, ys end ## Methods to make HierarchicalSSM compatible with the bootstrap filter diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl index da4a2ef..67dae91 100644 --- a/GeneralisedFilters/src/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/models/linear_gaussian.jl @@ -76,13 +76,18 @@ end ########################################### struct HomogeneousLinearGaussianLatentDynamics{ - T<:Real,ΣT<:AbstractMatrix{T},AT<:AbstractMatrix{T},QT<:AbstractMatrix{T} + T<:Real, + T_μ0<:AbstractVector{T}, + T_Σ0<:AbstractMatrix{T}, + T_A<:AbstractMatrix{T}, + T_b<:AbstractVector{T}, + T_Q<:AbstractMatrix{T}, } <: LinearGaussianLatentDynamics{T} - μ0::Vector{T} - Σ0::ΣT - A::AT - b::Vector{T} - Q::QT + μ0::T_μ0 + Σ0::T_Σ0 + A::T_A + b::T_b + Q::T_Q end calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.μ0 calc_Σ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.Σ0 @@ -91,11 +96,11 @@ calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) 
= dyn.Q struct HomogeneousLinearGaussianObservationProcess{ - T<:Real,HT<:AbstractMatrix{T},RT<:AbstractMatrix{T} + T<:Real,T_H<:AbstractMatrix{T},T_c<:AbstractVector{T},T_R<:AbstractMatrix{T} } <: LinearGaussianObservationProcess{T} - H::HT - c::Vector{T} - R::RT + H::T_H + c::T_c + R::T_R end calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.H calc_c(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.c diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index ea8dfa2..3cc9ec4 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -21,7 +21,7 @@ include("resamplers.jl") for Dy in Dys rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, 1) + _, ys = sample(rng, model, 1) filtered, ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) @@ -48,6 +48,62 @@ include("resamplers.jl") end end +@testitem "Square root Kalman filter test" begin + using GeneralisedFilters + using LinearAlgebra + using PDMats + using StableRNGs + using StaticArrays + + rng = StableRNG(1234) + μ0 = rand(rng, 2) + Σ0 = rand(rng, 2, 2) + Σ0 = Σ0 * Σ0' # make Σ0 positive definite + Σ0 = PDMat(Σ0) + A = rand(rng, 2, 2) + b = zeros(2) + Q = rand(rng, 2, 2) + Q = Q * Q' # make Q positive definite + Q = PDMat(Q) + H = rand(rng, 2, 2) + c = zeros(2) + R = rand(rng, 2, 2) + R = R * R' # make R positive definite + R = PDMat(R) + + model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) + + T = 3 + observations = [rand(rng, 2) for _ in 1:T] + + # kf_state, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), observations) + # srkf_state, srkf_ll = GeneralisedFilters.filter(rng, model, SRKF(), observations) + + # @test isa(srkf_state.Σ, PDMat) + # @test kf_state.μ ≈ srkf_state.μ + # @test kf_state.Σ ≈ srkf_state.Σ + # @test kf_ll ≈ srkf_ll + + # Test with 
StaticArrays.jl + μ0_static = SVector{2}(μ0) + Σ0_static = PDMat(SMatrix{2,2}(Σ0.mat)) + A_static = SMatrix{2,2}(A) + b_static = SVector{2}(b) + Q_static = PDMat(SMatrix{2,2}(Q.mat)) + H_static = SMatrix{2,2}(H) + c_static = SVector{2}(c) + R_static = PDMat(SMatrix{2,2}(R.mat)) + model_static = create_homogeneous_linear_gaussian_model( + μ0_static, Σ0_static, A_static, b_static, Q_static, H_static, c_static, R_static + ) + observations_static = [SVector{2}(rand(rng, 2)) for _ in 1:T] + srkf_state, _ = GeneralisedFilters.filter( + rng, model_static, SRKF(), observations_static + ) + # @test isa(srkf_state.μ, SVector) + # @test isa(srkf_state.Σ.mat, SMatrix) +end + @testitem "Kalman smoother test" begin using GeneralisedFilters using Distributions @@ -61,7 +117,7 @@ end for Dy in Dys rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, 2) + _, ys = sample(rng, model, 2) states, ll = GeneralisedFilters.smooth(rng, model, KalmanSmoother(), ys) @@ -88,7 +144,7 @@ end rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + _, ys = sample(rng, model, 10) bf = BF(2^12; threshold=0.8) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) @@ -128,7 +184,7 @@ end rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + _, ys = sample(rng, model, 10) algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) @@ -212,7 +268,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Ground truth Kalman filtering kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) @@ -255,7 +311,7 @@ end 
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, ET ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Ground truth Kalman filtering kf_state, kf_ll = GeneralisedFilters.filter(full_model, KalmanFilter(), ys) @@ -290,7 +346,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1 ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Manually create tree to force expansion on second step particle_type = GeneralisedFilters.RaoBlackwellisedParticle{ @@ -325,7 +381,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Ground truth Kalman filtering kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) @@ -373,7 +429,7 @@ end rng = StableRNG(SEED) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, K) + _, ys = sample(rng, model, K) ref_traj = OffsetVector([rand(rng, 1) for _ in 0:K], -1) @@ -413,7 +469,7 @@ end rng = StableRNG(SEED) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, K) + _, ys = sample(rng, model, K) # Kalman smoother state, ks_ll = GeneralisedFilters.smooth( @@ -477,7 +533,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, T; static_arrays=true ) - _, _, ys = sample(rng, full_model, K) + _, ys = sample(rng, full_model, K) # Kalman smoother state, _ = GeneralisedFilters.smooth( @@ -562,7 +618,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, T ) - _, _, ys = sample(rng, full_model, K) + _, ys = sample(rng, full_model, K) # Generate random 
reference trajectory ref_trajectory = [CuArray(rand(rng, T, D_outer, 1)) for _ in 0:K] @@ -595,7 +651,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, T ) - _, _, ys = sample(rng, full_model, K) + _, ys = sample(rng, full_model, K) # Manually create tree to force expansion on second step M = N_particles * 2 - 1 @@ -642,7 +698,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, T ) - _, _, ys = sample(rng, full_model, K) + _, ys = sample(rng, full_model, K) # Kalman smoother state, _ = GeneralisedFilters.smooth( diff --git a/SSMProblems/Project.toml b/SSMProblems/Project.toml index e445d32..9919615 100644 --- a/SSMProblems/Project.toml +++ b/SSMProblems/Project.toml @@ -1,14 +1,12 @@ name = "SSMProblems" uuid = "26aad666-b158-4e64-9d35-0e672562fa48" -authors = [ - "FredericWantiez ", - "THargreaves " -] -version = "0.5.2" +authors = ["FredericWantiez ", "THargreaves "] +version = "0.6.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [extras] diff --git a/SSMProblems/examples/kalman-filter/script.jl b/SSMProblems/examples/kalman-filter/script.jl index 4cf7114..91c8535 100644 --- a/SSMProblems/examples/kalman-filter/script.jl +++ b/SSMProblems/examples/kalman-filter/script.jl @@ -157,7 +157,7 @@ model = StateSpaceModel(dyn, obs); # functions we defined above. rng = MersenneTwister(SEED); -x0, xs, ys = sample(rng, model, T); +xs, ys = sample(rng, model, T); # We can then run the Kalman filter and plot the filtering results against the ground truth. 
@@ -165,8 +165,8 @@ x_filts, P_filts = AbstractMCMC.sample(model, KalmanFilter(), ys); # Plot trajectory for first dimension p = plot(; title="First Dimension Kalman Filter Estimates", xlabel="Step", ylabel="Value") -plot!(p, 1:T, first.(xs); label="Truth") -scatter!(p, 1:T, first.(ys); label="Observations") +plot!(p, first.(xs); label="Truth") +scatter!(p, first.(ys); label="Observations") plot!( p, 1:T, diff --git a/SSMProblems/src/utils/forward_simulation.jl b/SSMProblems/src/utils/forward_simulation.jl index cdafc10..d6dde17 100644 --- a/SSMProblems/src/utils/forward_simulation.jl +++ b/SSMProblems/src/utils/forward_simulation.jl @@ -1,6 +1,9 @@ """Forward simulation of state space models.""" +using OffsetArrays: OffsetVector + import AbstractMCMC: sample + export sample """ @@ -8,8 +11,9 @@ export sample Simulate a trajectory of length `T` from the state space model. -Returns a tuple `(x0, xs, ys)` where `x0` is the initial state, `xs` is a vector of latent states, -and `ys` is a vector of observations. +Returns a tuple `(xs, ys)` where `xs` is an (offset) vector of latent states, and `ys` is a +vector of observations. The latent states are indexed from `0` to `T`, where `xs[0]` is the +initial state. """ function sample( rng::AbstractRNG, model::StateSpaceModel{<:Real,LD,OP}, T::Integer; kwargs... @@ -17,16 +21,16 @@ function sample( T_dyn = eltype(LD) T_obs = eltype(OP) - xs = Vector{T_dyn}(undef, T) + xs = OffsetVector(Vector{T_dyn}(undef, T + 1), -1) ys = Vector{T_obs}(undef, T) - x0 = simulate(rng, model.dyn; kwargs...) + xs[0] = simulate(rng, model.dyn; kwargs...) for t in 1:T - xs[t] = simulate(rng, model.dyn, t, t == 1 ? x0 : xs[t - 1]; kwargs...) + xs[t] = simulate(rng, model.dyn, t, xs[t - 1]; kwargs...) ys[t] = simulate(rng, model.obs, t, xs[t]; kwargs...) end - return x0, xs, ys + return xs, ys end """ @@ -36,4 +40,4 @@ Simulate a trajectory using the default random number generator. 
""" function sample(model::AbstractStateSpaceModel, T::Integer; kwargs...) return sample(default_rng(), model, T; kwargs...) -end \ No newline at end of file +end diff --git a/research/maximum_likelihood/script.jl b/research/maximum_likelihood/script.jl index 41404a6..a76f420 100644 --- a/research/maximum_likelihood/script.jl +++ b/research/maximum_likelihood/script.jl @@ -25,7 +25,7 @@ end # data generation process rng = MersenneTwister(1234) true_model = toy_model(1.0) -_, _, ys = sample(rng, true_model, 1000) +_, ys = sample(rng, true_model, 1000) # evaluate and return the log evidence function logℓ(θ, data) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index b53c134..93ec50f 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -37,7 +37,7 @@ end rng = MersenneTwister(1234) true_model = toy_model(Float32, 10, 10) -_, _, ys = sample(rng, true_model, 100) +_, ys = sample(rng, true_model, 100) # ## Deep Gaussian Proposal