@@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383    np .testing .assert_allclose (idata .fit ["covariance_matrix" ].values , bda_cov , rtol = 1e-3 , atol = 1e-3 )
8484
8585
86+ def  test_fit_laplace_outside_model_context ():
87+     with  pm .Model () as  m :
88+         mu  =  pm .Normal ("mu" , 0 , 1 )
89+         sigma  =  pm .Exponential ("sigma" , 1 )
90+         y_hat  =  pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = np .random .normal (size = 10 ))
91+ 
92+     idata  =  fit_laplace (
93+         model = m ,
94+         optimize_method = "L-BFGS-B" ,
95+         use_grad = True ,
96+         progressbar = False ,
97+         chains = 1 ,
98+         draws = 100 ,
99+     )
100+ 
101+     assert  hasattr (idata , "posterior" )
102+     assert  hasattr (idata , "fit" )
103+     assert  hasattr (idata , "optimizer_result" )
104+ 
105+     assert  idata .posterior ["mu" ].shape  ==  (1 , 100 )
106+ 
107+ 
86108@pytest .mark .parametrize ( 
87109    "include_transformed" , [True , False ], ids = ["include_transformed" , "no_transformed" ] 
88110) 
@@ -208,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
208230    assert  "class"  in  list (idata .unconstrained_posterior .sigma_log__ .coords .keys ())
209231
210232
233+ def  test_laplace_nonstandard_dims_2d ():
234+     true_P  =  np .array ([[0.5 , 0.3 , 0.2 ], [0.1 , 0.6 , 0.3 ], [0.2 , 0.4 , 0.4 ]])
235+     y_obs  =  pm .draw (
236+         pmx .DiscreteMarkovChain .dist (
237+             P = true_P ,
238+             init_dist = pm .Categorical .dist (
239+                 logit_p = np .ones (
240+                     3 ,
241+                 )
242+             ),
243+             shape = (100 , 5 ),
244+         )
245+     )
246+ 
247+     with  pm .Model (
248+         coords = {
249+             "time" : range (y_obs .shape [0 ]),
250+             "state" : list ("ABC" ),
251+             "next_state" : list ("ABC" ),
252+             "unit" : [1 , 2 , 3 , 4 , 5 ],
253+         }
254+     ) as  model :
255+         y  =  pm .Data ("y" , y_obs , dims = ["time" , "unit" ])
256+         init_dist  =  pm .Categorical .dist (
257+             logit_p = np .ones (
258+                 3 ,
259+             )
260+         )
261+         P  =  pm .Dirichlet ("P" , a = np .eye (3 ) *  2  +  1 , dims = ["state" , "next_state" ])
262+         y_hat  =  pmx .DiscreteMarkovChain (
263+             "y_hat" , P = P , init_dist = init_dist , dims = ["time" , "unit" ], observed = y_obs 
264+         )
265+ 
266+         idata  =  pmx .fit_laplace (progressbar = True )
267+ 
268+         # The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified 
269+         assert  "state"  in  list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
270+ 
271+         # The mutated dimension should be unknown coords 
272+         assert  "P_simplex___dim_1"  in  list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
273+ 
274+         assert  idata .unconstrained_posterior .P_simplex__ .shape [- 2 :] ==  (3 , 2 )
275+ 
276+ 
211277def  test_laplace_nonscalar_rv_without_dims ():
212278    with  pm .Model (coords = {"test" : ["A" , "B" , "C" ]}) as  model :
213279        x_loc  =  pm .Normal ("x_loc" , mu = 0 , sigma = 1 , dims = ["test" ])
0 commit comments