From 334043545f0289f637878bec2c8a30587050c088 Mon Sep 17 00:00:00 2001 From: Alexander Shtuchkin Date: Thu, 22 Sep 2022 12:33:08 -0400 Subject: [PATCH] Fix process noise handling in UKF --- filterpy/kalman/UKF.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/filterpy/kalman/UKF.py b/filterpy/kalman/UKF.py index f8baf19..5a90dc2 100644 --- a/filterpy/kalman/UKF.py +++ b/filterpy/kalman/UKF.py @@ -361,7 +361,7 @@ def __init__(self, dim_x, dim_z, dt, hx, fx, points, self.x_post = self.x.copy() self.P_post = self.P.copy() - def predict(self, dt=None, UT=None, fx=None, **fx_args): + def predict(self, dt=None, Q=None, UT=None, fx=None, **fx_args): r""" Performs the predict step of the UKF. On return, self.x and self.P contain the predicted state (x) and covariance (P). ' @@ -376,6 +376,10 @@ def predict(self, dt=None, UT=None, fx=None, **fx_args): If specified, the time step to be used for this prediction. self._dt is used if this is not provided. + Q : numpy.array((dim_x, dim_x)), optional + Process noise. If provided, overrides self.Q for + this function call. + fx : callable f(x, dt, **fx_args), optional State transition function. If not provided, the default function passed in during construction will be used. @@ -393,6 +397,9 @@ def predict(self, dt=None, UT=None, fx=None, **fx_args): if dt is None: dt = self._dt + if Q is None: + Q = self.Q + if UT is None: UT = unscented_transform @@ -400,7 +407,7 @@ def predict(self, dt=None, UT=None, fx=None, **fx_args): self.compute_process_sigmas(dt, fx, **fx_args) #and pass sigmas through the unscented transform to compute prior - self.x, self.P = UT(self.sigmas_f, self.Wm, self.Wc, self.Q, + self.x, self.P = UT(self.sigmas_f, self.Wm, self.Wc, Q, self.x_mean, self.residual_x) # update sigma points to reflect the new variance of the points @@ -718,7 +725,7 @@ def rts_smoother(self, Xs, Ps, Qs=None, dts=None, UT=None): sigmas_f[i] = self.fx(sigmas[i], dts[k]) xb, Pb = UT( - sigmas_f, self.Wm, self.Wc, self.Q, + sigmas_f, self.Wm, self.Wc, Qs[k], self.x_mean, self.residual_x) # compute cross variance