99import pytest
1010
1111from numpy .testing import assert_allclose
12+ from pymc .testing import mock_sample_setup_and_teardown
13+ from pytensor .compile import SharedVariable
14+ from pytensor .graph .basic import graph_inputs
1215
1316from pymc_extras .statespace .core .statespace import FILTER_FACTORY , PyMCStateSpace
1417from pymc_extras .statespace .models import structural as st
3033floatX = pytensor .config .floatX
3134nile = load_nile_test_data ()
3235ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
36+ mock_pymc_sample = pytest .fixture (scope = "session" )(mock_sample_setup_and_teardown )
3337
3438
3539def make_statespace_mod (k_endog , k_states , k_posdef , filter_type , verbose = False , data_info = None ):
@@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
170174 )
171175 beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["exog_state" ])
172176
173- exog_ss_mod .build_statespace_graph (exog_data ["y" ])
177+ exog_ss_mod .build_statespace_graph (exog_data ["y" ], save_kalman_filter_outputs_in_idata = True )
174178
175179 return struct_model
176180
@@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
212216
213217
214218@pytest .fixture (scope = "session" )
215- def idata (pymc_mod , rng ):
219+ def idata (pymc_mod , rng , mock_pymc_sample ):
216220 with pymc_mod :
217221 idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
218222 idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -222,7 +226,7 @@ def idata(pymc_mod, rng):
222226
223227
224228@pytest .fixture (scope = "session" )
225- def idata_exog (exog_pymc_mod , rng ):
229+ def idata_exog (exog_pymc_mod , rng , mock_pymc_sample ):
226230 with exog_pymc_mod :
227231 idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
228232 idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):
231235
232236
233237@pytest .fixture (scope = "session" )
234- def idata_no_exog (pymc_mod_no_exog , rng ):
238+ def idata_no_exog (pymc_mod_no_exog , rng , mock_pymc_sample ):
235239 with pymc_mod_no_exog :
236240 idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
237241 idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):
240244
241245
242246@pytest .fixture (scope = "session" )
243- def idata_no_exog_dt (pymc_mod_no_exog_dt , rng ):
247+ def idata_no_exog_dt (pymc_mod_no_exog_dt , rng , mock_pymc_sample ):
244248 with pymc_mod_no_exog_dt :
245249 idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
246250 idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
895899 assert_allclose (regression_effect , regression_effect_expected )
896900
897901
902+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
903+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
904+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
905+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
906+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
907+ def test_build_forecast_model (rng , exog_ss_mod , exog_pymc_mod , exog_data , idata_exog ):
908+ data_before_build_forecast_model = {d .name : d .get_value () for d in exog_pymc_mod .data_vars }
909+
910+ scenario = pd .DataFrame (
911+ {
912+ "date" : pd .date_range (start = "2023-05-11" , end = "2023-05-20" , freq = "D" ),
913+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
914+ }
915+ )
916+ scenario .set_index ("date" , inplace = True )
917+
918+ time_index = exog_ss_mod ._get_fit_time_index ()
919+ t0 , forecast_index = exog_ss_mod ._build_forecast_index (
920+ time_index = time_index ,
921+ start = exog_data .index [- 1 ],
922+ end = scenario .index [- 1 ],
923+ scenario = scenario ,
924+ )
925+
926+ test_forecast_model = exog_ss_mod ._build_forecast_model (
927+ time_index = time_index ,
928+ t0 = t0 ,
929+ forecast_index = forecast_index ,
930+ scenario = scenario ,
931+ filter_output = "predicted" ,
932+ mvn_method = "svd" ,
933+ )
934+
935+ frozen_shared_inputs = [
936+ inpt
937+ for inpt in graph_inputs ([test_forecast_model .x0_slice , test_forecast_model .P0_slice ])
938+ if isinstance (inpt , SharedVariable )
939+ and not isinstance (inpt .get_value (), np .random .Generator )
940+ ]
941+
942+ assert (
943+ len (frozen_shared_inputs ) == 0
944+ ) # check there are no non-random generator SharedVariables in the frozen inputs
945+
946+ unfrozen_shared_inputs = [
947+ inpt
948+ for inpt in graph_inputs ([test_forecast_model .forecast_combined ])
949+ if isinstance (inpt , SharedVariable )
950+ and not isinstance (inpt .get_value (), np .random .Generator )
951+ ]
952+
953+ # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data
954+ assert len (unfrozen_shared_inputs ) == 1
955+ assert unfrozen_shared_inputs [0 ].name == "data_exog"
956+
957+ data_after_build_forecast_model = {d .name : d .get_value () for d in test_forecast_model .data_vars }
958+
959+ with test_forecast_model :
960+ dummy_obs_data = np .zeros ((len (forecast_index ), exog_ss_mod .k_endog ))
961+ pm .set_data (
962+ {"data_exog" : scenario } | {"data" : dummy_obs_data },
963+ coords = {"data_time" : np .arange (len (forecast_index ))},
964+ )
965+ idata_forecast = pm .sample_posterior_predictive (
966+ idata_exog , var_names = ["x0_slice" , "P0_slice" ]
967+ )
968+
969+ np .testing .assert_allclose (
970+ unfrozen_shared_inputs [0 ].get_value (), scenario ["x1" ].values .reshape ((- 1 , 1 ))
971+ ) # ensure the replaced data matches the exogenous data
972+
973+ for k in data_before_build_forecast_model .keys ():
974+ assert ( # check that the data needed to init the forecasts doesn't change
975+ data_before_build_forecast_model [k ].mean () == data_after_build_forecast_model [k ].mean ()
976+ )
977+
978+ # Check that the frozen states and covariances correctly match the sliced index
979+ np .testing .assert_allclose (
980+ idata_exog .posterior ["predicted_covariance" ].sel (time = t0 ).mean (("chain" , "draw" )).values ,
981+ idata_forecast .posterior_predictive ["P0_slice" ].mean (("chain" , "draw" )).values ,
982+ )
983+ np .testing .assert_allclose (
984+ idata_exog .posterior ["predicted_state" ].sel (time = t0 ).mean (("chain" , "draw" )).values ,
985+ idata_forecast .posterior_predictive ["x0_slice" ].mean (("chain" , "draw" )).values ,
986+ )
987+
988+
898989@pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
899990@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
900991@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
0 commit comments