GeometricKalman

Kalman filters on manifolds with an affine connection, unifying the Lie group and Riemannian approaches.

arXiv preprint: https://arxiv.org/abs/2506.01086.

Getting started

Basic setting (example: a car driving on a sphere)

using Manifolds
using RecursiveArrayTools
using LinearAlgebra

using Distributions

using GeometricKalman
using GeometricKalman: gen_car_sphere_data, car_sphere_f, car_sphere_h

# The car moves on the unit 2-sphere. The filter state is a point of the sphere's
# tangent bundle (a sphere point paired with a tangent vector at it), while each
# observation is a point on the sphere itself.
S2 = Manifolds.Sphere(2)
M = TangentBundle(S2) # state manifold
M_obs = S2            # observation manifold
# Retraction and inverse retraction used on the tangent bundle `M`.
retraction = Manifolds.FiberBundleProductRetraction()
inverse_retraction = Manifolds.FiberBundleInverseProductRetraction()
Manifolds.FiberBundleInverseProductRetraction()

Generating synthetic data with the `gen_car_sphere_data` function.

dt = 0.01 # time step used when discretizing the dynamics for the filters
vt = 5    # value forwarded to the car model through its `vt` keyword

# Zero-mean Gaussian noise: 4-dimensional for the tangent-bundle state and
# 2-dimensional for the sphere observations.
process_noise = MvNormal(zeros(4), 1e4 * diagm([1e-3, 1e-3, 1e-2, 1e-2]))
measurement_noise = MvNormal(zeros(2), diagm([0.01, 0.01]))

times, samples, controls, measurements = gen_car_sphere_data(;
    vt,
    N = 200,
    noise_f_distr = process_noise,
    noise_h_distr = measurement_noise,
    retraction,
)
([0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09  …  1.91, 1.92, 1.93, 1.94, 1.95, 1.96, 1.97, 1.98, 1.99, 2.0], RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}[RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9981974497181397, 0.05861455627509118, -0.012891282669935543], [-0.058573730050185756, 0.9979738804299262, 0.00214472210909708])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9942123288633348, 0.10713788174104043, -0.007957350827210356], [-0.10710975530961783, 0.9942745290821743, 0.00435165488232123])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9860361388658084, 0.16618849200543634, 0.010682601535554373], [-0.16692674078032943, 0.9898083112600988, 0.009459149944891226])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9768687262250207, 0.21355008861568292, 0.01112885329850786], [-0.21535260908441523, 0.9844649279981341, 0.012459173355502331])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9628596937661658, 0.26987901398753245, 0.008156465511674862], [-0.2731708204065776, 0.974290211323244, 0.01038328584687189])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9529977643376695, 0.30292754272605604, -0.005492269599868843], [-0.307191481911822, 0.9666310222909218, 0.012083314304894352])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9362414306777989, 0.35125635084371887, 0.008423744674601913], [-0.3590202084795983, 0.9565266040782067, 0.01704101931864676])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.9135569789286017, 0.40666105253947965, 
0.0063588205140028605], [-0.41409675352227265, 0.9299422249818743, 0.020396096985136187])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([0.898100388424205, 0.4397904795969305, 0.00016238286288790434], [-0.4480307403436383, 0.9149197559083359, 0.022069354538263544]))  …  RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.9989576960475754, -0.0126483430202399, 0.04385819109541056], [-0.08789451984895424, -2.0613077005766183, -2.5964370865854165])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.9888175220775817, -0.11773416846014276, -0.09153454871986562], [0.48326826358262454, -2.0193284057477787, -2.6232737207315555])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.9488642573312042, -0.21109166432487678, -0.23472735335247008], [1.064029096526177, -1.9262536389745006, -2.5689510970016642])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.8826349212566388, -0.30901523384427765, -0.3542106449987816], [1.5988919721635568, -1.7712141895867064, -2.4389603620410267])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.804679955078721, -0.38809789372527304, -0.44929967146717736], [2.0395859141280566, -1.5950717880273735, -2.2750292630926183])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.7045164313764616, -0.453256085414381, -0.5460911269700381], [2.4628509918238453, -1.3717088462410905, -2.0388238424120884])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.579657891233468, -0.5113571539687959, -0.6344372232267758], [2.855308091555207, -1.091102806465895, -1.7293415348342678])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.41557728451887854, -0.5584397792924205, -0.7179418733405563], 
[3.2180799281901997, -0.7276792979251029, -1.2967565842916287])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.2509202198313356, -0.5776044152011688, -0.776796101187377], [3.4565328047770336, -0.368143515723935, -0.8427857067232267])), RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(([-0.06989423882408771, -0.5880761316601831, -0.8057799071401545], [3.5991041042501695, 0.020158828928833805, -0.3269026264074697]))], Any[0.0, 0.004999979166692708, 0.009999833334166664, 0.01499943750632809, 0.01999866669333308, 0.024997395914712332, 0.02999550020249566, 0.034992854604336196, 0.03998933418663416, 0.044984814037660234  …  0.8163137404456835, 0.8191915683009983, 0.8220489164097717, 0.8248857133384501, 0.8277018881672576, 0.8304973704919705, 0.8332720904256761, 0.8360259786005205, 0.838758966169443, 0.8414709848078965], [[0.9977812499036257, 0.028994499028791865, -0.05993243167791399], [0.946892513786505, 0.27350473008889664, -0.169084978499196], [0.9911800142095112, 0.10602974478905727, 0.07949762670420221], [0.9782128960655717, 0.20735266957847925, -0.010217650888796798], [0.9742684660959368, 0.16514550595652316, 0.15338812807200844], [0.9894695791003493, 0.14131780544684214, -0.03129264927529183], [0.9709897932757799, 0.23897789107314185, -0.008269760122998217], [0.842733852301092, 0.5302639708219569, -0.09284274572572962], [0.9068494652531333, 0.3633884908999152, -0.21347798961857617], [0.8900095101576119, 0.39755620842640427, -0.2232311200765466]  …  [-0.994428801515572, 0.10076011018641125, -0.030963832313278157], [-0.9941145781143043, 0.01208715624444591, -0.10765735569175228], [-0.9484772088979293, -0.15130628194772516, -0.2783835362307705], [-0.9101104757540861, -0.3187905048346369, -0.2647102868230621], [-0.8398386844759985, -0.28587448976927954, -0.4614615478637172], [-0.6317908078789913, -0.5977055182400391, -0.4935468453399507], [-0.5515371131991982, -0.5292013292007487, 
-0.6447889313070249], [-0.4220006078618824, -0.5915051845353233, -0.6870495641742556], [-0.147456957138132, -0.5836472202774269, -0.7985063356379806], [-0.046924519159592865, -0.5892820356227253, -0.8065635573183169]])

Setting initial conditions for the filters

# Initial estimate: the sphere point (1, 0, 0) together with the tangent vector (0, 1, 0).
p0 = ArrayPartition([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
# Covariance matrices, sized for the 4-dimensional state manifold `M` and the
# 2-dimensional observation manifold `M_obs`.
P0 = diagm(fill(0.1, 4))          # initial state covariance
Q = diagm([0.1, 0.1, 0.01, 0.01]) # process noise covariance
R = diagm(fill(0.01, 2))          # measurement noise covariance
2×2 Matrix{Float64}:
 0.01  0.0
 0.0   0.01

Adapting the system dynamics to the interface expected by the Kalman filters.

# Wrap the car model in the `(p, q, noise, t)` signature used by the filters, fixing
# its `vt` keyword to the global `vt` defined earlier. `q` is presumably the control
# input (the filters are fed `controls[i]`) — confirm against `car_sphere_f`.
car_f_adapted(p, q, noise, t::Real) = car_sphere_f(p, q, noise, t; vt = vt)
# Discretize `car_f_adapted` on the state manifold `M` with time step `dt`,
# stepping along the manifold via `retraction`.
f_tilde = GeometricKalman.default_discretization(
    M,
    car_f_adapted;
    dt = dt,
    retraction = retraction,
)
(::GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}) (generic function with 1 method)

Filter-specific settings

# Sigma-point scheme shared by the unscented propagator and updater.
sp = WanMerweSigmaPoints(; α = 1.0)

# EKF building blocks. Each call constructs a new object, so every filter below gets
# its own propagator/updater instance.
ekf_propagator() = EKFPropagator(M, f_tilde; B_M = DefaultOrthonormalBasis())
ekf_updater() = EKFUpdater(
    M,
    M_obs,
    car_sphere_h;
    B_M = DefaultOrthonormalBasis(),
    B_M_obs = DefaultOrthonormalBasis(),
)

# Each entry is `(label, keyword arguments passed to discrete_kalman_filter_manifold)`.
filter_params = [
    # extended Kalman filter
    ("EKF", (; propagator = ekf_propagator(), updater = ekf_updater())),
    # unscented Kalman filter
    (
        "UKF",
        (;
            propagator = UnscentedPropagator(
                M;
                sigma_points = sp,
                inverse_retraction_method = inverse_retraction,
            ),
            updater = UnscentedUpdater(; sigma_points = sp),
        ),
    ),
    # extended Kalman filter with covariance-matching adaptation (parameter 0.99)
    (
        "EKF adaptive M α=0.99",
        (;
            propagator = ekf_propagator(),
            updater = ekf_updater(),
            measurement_covariance_adapter = CovarianceMatchingMeasurementCovarianceAdapter(
                0.99,
            ),
        ),
    ),
]
3-element Vector{Tuple{String, NamedTuple}}:
 ("EKF", (propagator = EKFPropagator{GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, Manifolds.FiberBundleInverseProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}}}(GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, Manifolds.FiberBundleInverseProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}}(DefaultOrthonormalBasis(ℝ), DefaultOrthonormalBasis(ℝ), Manifolds.FiberBundleProductRetraction(), Manifolds.FiberBundleInverseProductRetraction(), TangentBundle(Sphere(2, ℝ)), TangentBundle(Sphere(2, ℝ)), GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, 
TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}(0.01, Manifolds.FiberBundleProductRetraction(), TangentBundle(Sphere(2, ℝ)), Main.car_f_adapted))), updater = EKFUpdater{GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, LogarithmicInverseRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, typeof(GeometricKalman.car_sphere_h)}}(GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, LogarithmicInverseRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, typeof(GeometricKalman.car_sphere_h)}(DefaultOrthonormalBasis(ℝ), DefaultOrthonormalBasis(ℝ), Manifolds.FiberBundleProductRetraction(), LogarithmicInverseRetraction(), TangentBundle(Sphere(2, ℝ)), Sphere(2, ℝ), GeometricKalman.car_sphere_h))))
 ("UKF", (propagator = UnscentedPropagator{WanMerweSigmaPoints{Float64}, Manifolds.FiberBundleInverseProductRetraction, GradientDescentEstimation}(WanMerweSigmaPoints{Float64}(1.0, 2.0, 0.0), Manifolds.FiberBundleInverseProductRetraction(), GradientDescentEstimation()), updater = UnscentedUpdater{WanMerweSigmaPoints{Float64}, LogarithmicInverseRetraction}(WanMerweSigmaPoints{Float64}(1.0, 2.0, 0.0), LogarithmicInverseRetraction())))
 ("EKF adaptive M α=0.99", (propagator = EKFPropagator{GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, Manifolds.FiberBundleInverseProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}}}(GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, Manifolds.FiberBundleInverseProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}}(DefaultOrthonormalBasis(ℝ), DefaultOrthonormalBasis(ℝ), Manifolds.FiberBundleProductRetraction(), Manifolds.FiberBundleInverseProductRetraction(), TangentBundle(Sphere(2, ℝ)), TangentBundle(Sphere(2, ℝ)), GeometricKalman.var"#tilde_f#27"{Float64, Manifolds.FiberBundleProductRetraction, 
FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, typeof(Main.car_f_adapted)}(0.01, Manifolds.FiberBundleProductRetraction(), TangentBundle(Sphere(2, ℝ)), Main.car_f_adapted))), updater = EKFUpdater{GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, LogarithmicInverseRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, typeof(GeometricKalman.car_sphere_h)}}(GeometricKalman.var"#jacobian_p#3"{DefaultOrthonormalBasis{ℝ, TangentSpaceType}, DefaultOrthonormalBasis{ℝ, TangentSpaceType}, Manifolds.FiberBundleProductRetraction, LogarithmicInverseRetraction, FiberBundle{ℝ, TangentSpaceType, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, Manifolds.FiberBundleProductVectorTransport{ParallelTransport, ParallelTransport}}, Sphere{ManifoldsBase.TypeParameter{Tuple{2}}, ℝ}, typeof(GeometricKalman.car_sphere_h)}(DefaultOrthonormalBasis(ℝ), DefaultOrthonormalBasis(ℝ), Manifolds.FiberBundleProductRetraction(), LogarithmicInverseRetraction(), TangentBundle(Sphere(2, ℝ)), Sphere(2, ℝ), GeometricKalman.car_sphere_h)), measurement_covariance_adapter = CovarianceMatchingMeasurementCovarianceAdapter{Float64}(0.99)))

Running the filters. The estimated trajectories are stored in `reconstructions`.

reconstructions = NamedTuple[]

for (name, filter_kwargs) in filter_params
    # Copy every input the filter might modify in place (initial state and all
    # covariances, not just Q and R). This way each filter in the loop starts from
    # the same initial conditions, and `p0` stays intact for plotting.
    kf = discrete_kalman_filter_manifold(
        M,
        M_obs,
        copy(p0),
        f_tilde,
        car_sphere_h,
        copy(P0),
        copy(Q),
        copy(R);
        filter_kwargs...,
    )

    # At each step: update with the current control and measurement, record the
    # filtered state, then predict. `eachindex(controls, measurements)` iterates the
    # arrays actually indexed and throws if their indices do not match.
    samples_kalman = map(eachindex(controls, measurements)) do i
        GeometricKalman.update!(kf, controls[i], measurements[i])
        # Store a copy: if the filter updates `kf.p_n` in place, storing the object
        # itself would make every recorded estimate alias the final state.
        p_filtered = copy(kf.p_n)
        predict!(kf, controls[i])
        return p_filtered
    end
    push!(reconstructions, (; data = samples_kalman, label = name))
end
┌ Warning: SpecialOrthogonal will move to LieGroups.jl and be renamed to SpecialOrthogonalGroup.
│   caller = ip:0x0
└ @ Core :-1

Plotting the true trajectory, the estimated trajectories, and the measurements.

using Plots

"""
    trajectory_plot3d(p0, samples, reconstructions, measurements)

Plot in 3D the true trajectory `samples`, the starting point `p0`, every filter
reconstruction in `reconstructions`, and the raw `measurements`; return the figure.

`p0`, the elements of `samples`, and the elements of each `rec.data` are
`ArrayPartition`s whose first component (`.x[1]`) is a point on the sphere in ℝ³.
Each entry of `reconstructions` is a named tuple with fields `data` and `label`.
`measurements` is a collection of 3-element vectors.
"""
function trajectory_plot3d(
    p0,
    samples::AbstractVector,
    reconstructions::AbstractVector{<:NamedTuple},
    measurements::AbstractVector,
)
    # Split a collection of 3-vectors into its x, y and z coordinate vectors.
    coords(points) = ntuple(k -> [pt[k] for pt in points], 3)
    # Coordinates of the sphere point (first component) of each stored state.
    positions(states) = coords(s.x[1] for s in states)

    fig = plot(positions(samples)...; label = "original", linewidth = 5.0)
    scatter3d!(coords((p0.x[1],))...; markersize = 15, label = "Starting point")

    for rec in reconstructions
        plot!(positions(rec.data)...; label = rec.label)
    end

    scatter3d!(coords(measurements)...; label = "measurements")
    return fig
end

trajectory_plot3d(p0, samples, reconstructions, measurements)
Example block output