1+ import psycopg2
12from mesa_frames .concrete .datacollector import DataCollector
23from mesa_frames import Model , AgentSet , AgentSetRegistry
34import pytest
@@ -15,6 +16,7 @@ def custom_trigger(model):
1516class ExampleAgentSet1 (AgentSet ):
1617 def __init__ (self , model : Model ):
1718 super ().__init__ (model )
19+ self ["unique_id" ] = pl .Series ("unique_id" , [101 , 102 , 103 , 104 ], dtype = pl .Int64 )
1820 self ["wealth" ] = pl .Series ("wealth" , [1 , 2 , 3 , 4 ])
1921 self ["age" ] = pl .Series ("age" , [10 , 20 , 30 , 40 ])
2022
@@ -28,6 +30,7 @@ def step(self) -> None:
2830class ExampleAgentSet2 (AgentSet ):
2931 def __init__ (self , model : Model ):
3032 super ().__init__ (model )
33+ self ["unique_id" ] = pl .Series ("unique_id" , [201 , 202 , 203 , 204 ], dtype = pl .Int64 )
3134 self ["wealth" ] = pl .Series ("wealth" , [10 , 20 , 30 , 40 ])
3235 self ["age" ] = pl .Series ("age" , [11 , 22 , 33 , 44 ])
3336
@@ -41,6 +44,7 @@ def step(self) -> None:
4144class ExampleAgentSet3 (AgentSet ):
4245 def __init__ (self , model : Model ):
4346 super ().__init__ (model )
47+ self ["unique_id" ] = pl .Series ("unique_id" , [301 , 302 , 303 , 304 ], dtype = pl .Int64 )
4448 self ["age" ] = pl .Series ("age" , [1 , 2 , 3 , 4 ])
4549 self ["wealth" ] = pl .Series ("wealth" , [1 , 2 , 3 , 4 ])
4650
@@ -147,11 +151,16 @@ def test__init__(self, fix1_model, postgres_uri):
147151 ):
148152 model .test_dc = DataCollector (model = model , storage = "S3-csv" )
149153
154+ try :
155+ psycopg2 .connect (postgres_uri )
156+ except psycopg2 .OperationalError :
157+ pass
158+
150159 with pytest .raises (
151160 ValueError ,
152161 match = "Please define a storage_uri to if to be stored not in memory" ,
153162 ):
154- model .test_dc = DataCollector (model = model , storage = "postgresql" )
163+ model .test_dc = DataCollector (model = model , storage = "postgresql" , storage_uri = None )
155164
156165 def test_collect (self , fix1_model ):
157166 model = fix1_model
@@ -185,12 +194,15 @@ def test_collect(self, fix1_model):
185194 with pytest .raises (pl .exceptions .ColumnNotFoundError , match = "max_wealth" ):
186195 collected_data ["model" ]["max_wealth" ]
187196
188- assert collected_data ["agent" ].shape == (4 , 7 )
197+ assert collected_data ["agent" ].shape == (4 , 10 )
189198 assert set (collected_data ["agent" ].columns ) == {
190199 "wealth" ,
191200 "age_ExampleAgentSet1" ,
192201 "age_ExampleAgentSet2" ,
193202 "age_ExampleAgentSet3" ,
203+ "unique_id_ExampleAgentSet1" ,
204+ "unique_id_ExampleAgentSet2" ,
205+ "unique_id_ExampleAgentSet3" ,
194206 "step" ,
195207 "seed" ,
196208 "batch" ,
@@ -242,12 +254,15 @@ def test_collect_step(self, fix1_model):
242254 assert collected_data ["model" ]["step" ].to_list () == [5 ]
243255 assert collected_data ["model" ]["total_agents" ].to_list () == [12 ]
244256
245- assert collected_data ["agent" ].shape == (4 , 7 )
257+ assert collected_data ["agent" ].shape == (4 , 10 )
246258 assert set (collected_data ["agent" ].columns ) == {
247259 "wealth" ,
248260 "age_ExampleAgentSet1" ,
249261 "age_ExampleAgentSet2" ,
250262 "age_ExampleAgentSet3" ,
263+ "unique_id_ExampleAgentSet1" ,
264+ "unique_id_ExampleAgentSet2" ,
265+ "unique_id_ExampleAgentSet3" ,
251266 "step" ,
252267 "seed" ,
253268 "batch" ,
@@ -297,25 +312,20 @@ def test_conditional_collect(self, fix1_model):
297312 assert collected_data ["model" ]["step" ].to_list () == [2 , 4 ]
298313 assert collected_data ["model" ]["total_agents" ].to_list () == [12 , 12 ]
299314
300- assert collected_data ["agent" ].shape == (8 , 7 )
301- assert set (collected_data ["agent" ].columns ) == {
302- "wealth" ,
303- "age_ExampleAgentSet1" ,
304- "age_ExampleAgentSet2" ,
305- "age_ExampleAgentSet3" ,
306- "step" ,
307- "seed" ,
308- "batch" ,
309- }
315+ assert collected_data ["agent" ].shape == (8 , 10 )
310316 assert set (collected_data ["agent" ].columns ) == {
311317 "wealth" ,
312318 "age_ExampleAgentSet1" ,
313319 "age_ExampleAgentSet2" ,
314320 "age_ExampleAgentSet3" ,
321+ "unique_id_ExampleAgentSet1" ,
322+ "unique_id_ExampleAgentSet2" ,
323+ "unique_id_ExampleAgentSet3" ,
315324 "step" ,
316325 "seed" ,
317326 "batch" ,
318327 }
328+
319329 assert collected_data ["agent" ]["wealth" ].to_list () == [3 , 4 , 5 , 6 , 5 , 6 , 7 , 8 ]
320330 assert collected_data ["agent" ]["age_ExampleAgentSet1" ].to_list () == [
321331 10 ,
@@ -394,15 +404,25 @@ def test_flush_local_csv(self, fix1_model):
394404 assert model_df ["step" ].to_list () == [2 ]
395405 assert model_df ["total_agents" ].to_list () == [12 ]
396406
407+ agent_overrides = {
408+ "seed" : pl .Utf8 ,
409+ "unique_id_ExampleAgentSet1" : pl .Utf8 ,
410+ "unique_id_ExampleAgentSet2" : pl .Utf8 ,
411+ "unique_id_ExampleAgentSet3" : pl .Utf8 ,
412+ }
413+
397414 agent_df = pl .read_csv (
398415 os .path .join (tmpdir , "agent_step2_batch0.csv" ),
399- schema_overrides = { "seed" : pl . Utf8 } ,
416+ schema_overrides = agent_overrides ,
400417 )
401418 assert set (agent_df .columns ) == {
402419 "wealth" ,
403420 "age_ExampleAgentSet1" ,
404421 "age_ExampleAgentSet2" ,
405422 "age_ExampleAgentSet3" ,
423+ "unique_id_ExampleAgentSet1" ,
424+ "unique_id_ExampleAgentSet2" ,
425+ "unique_id_ExampleAgentSet3" ,
406426 "step" ,
407427 "seed" ,
408428 "batch" ,
@@ -420,7 +440,7 @@ def test_flush_local_csv(self, fix1_model):
420440
421441 agent_df = pl .read_csv (
422442 os .path .join (tmpdir , "agent_step4_batch0.csv" ),
423- schema_overrides = { "seed" : pl . Utf8 } ,
443+ schema_overrides = agent_overrides ,
424444 )
425445 assert agent_df ["step" ].to_list () == [4 , 4 , 4 , 4 ]
426446 assert agent_df ["wealth" ].to_list () == [5 , 6 , 7 , 8 ]
@@ -474,10 +494,13 @@ def test_flush_local_parquet(self, fix1_model):
474494 reason = "PostgreSQL tests are skipped on Windows runners" ,
475495 )
476496 def test_postgress (self , fix1_model , postgres_uri ):
477- model = fix1_model
497+ try :
498+ conn = psycopg2 .connect (postgres_uri )
499+ conn .close ()
500+ except psycopg2 .OperationalError :
501+ pytest .skip ("PostgreSQL not available" )
478502
479- # Connect directly and validate data
480- import psycopg2
503+ model = fix1_model
481504
482505 conn = psycopg2 .connect (postgres_uri )
483506 cur = conn .cursor ()
@@ -496,6 +519,9 @@ def test_postgress(self, fix1_model, postgres_uri):
496519 step INTEGER,
497520 seed VARCHAR,
498521 batch INTEGER,
522+ "unique_id_ExampleAgentSet1" INTEGER,
523+ "unique_id_ExampleAgentSet2" INTEGER,
524+ "unique_id_ExampleAgentSet3" INTEGER,
499525 age_ExampleAgentSet1 INTEGER,
500526 age_ExampleAgentSet2 INTEGER,
501527 age_ExampleAgentSet3 INTEGER,
@@ -580,12 +606,15 @@ def test_batch_memory(self, fix2_model):
580606 assert collected_data ["model" ]["batch" ].to_list () == [0 , 1 , 0 , 1 ]
581607 assert collected_data ["model" ]["total_agents" ].to_list () == [12 , 12 , 12 , 12 ]
582608
583- assert collected_data ["agent" ].shape == (16 , 7 )
609+ assert collected_data ["agent" ].shape == (16 , 10 )
584610 assert set (collected_data ["agent" ].columns ) == {
585611 "wealth" ,
586612 "age_ExampleAgentSet1" ,
587613 "age_ExampleAgentSet2" ,
588614 "age_ExampleAgentSet3" ,
615+ "unique_id_ExampleAgentSet1" ,
616+ "unique_id_ExampleAgentSet2" ,
617+ "unique_id_ExampleAgentSet3" ,
589618 "step" ,
590619 "seed" ,
591620 "batch" ,
@@ -596,6 +625,9 @@ def test_batch_memory(self, fix2_model):
596625 "age_ExampleAgentSet1" ,
597626 "age_ExampleAgentSet2" ,
598627 "age_ExampleAgentSet3" ,
628+ "unique_id_ExampleAgentSet1" ,
629+ "unique_id_ExampleAgentSet2" ,
630+ "unique_id_ExampleAgentSet3" ,
599631 "step" ,
600632 "seed" ,
601633 "batch" ,
@@ -773,16 +805,26 @@ def test_batch_save(self, fix2_model):
773805 assert model_df_step4_batch0 ["step" ].to_list () == [4 ]
774806 assert model_df_step4_batch0 ["total_agents" ].to_list () == [12 ]
775807
808+ agent_overrides = {
809+ "seed" : pl .Utf8 ,
810+ "unique_id_ExampleAgentSet1" : pl .Utf8 ,
811+ "unique_id_ExampleAgentSet2" : pl .Utf8 ,
812+ "unique_id_ExampleAgentSet3" : pl .Utf8 ,
813+ }
814+
776815 # test agent batch reset
777816 agent_df_step2_batch0 = pl .read_csv (
778817 os .path .join (tmpdir , "agent_step2_batch0.csv" ),
779- schema_overrides = { "seed" : pl . Utf8 } ,
818+ schema_overrides = agent_overrides ,
780819 )
781820 assert set (agent_df_step2_batch0 .columns ) == {
782821 "wealth" ,
783822 "age_ExampleAgentSet1" ,
784823 "age_ExampleAgentSet2" ,
785824 "age_ExampleAgentSet3" ,
825+ "unique_id_ExampleAgentSet1" ,
826+ "unique_id_ExampleAgentSet2" ,
827+ "unique_id_ExampleAgentSet3" ,
786828 "step" ,
787829 "seed" ,
788830 "batch" ,
@@ -810,13 +852,16 @@ def test_batch_save(self, fix2_model):
810852
811853 agent_df_step2_batch1 = pl .read_csv (
812854 os .path .join (tmpdir , "agent_step2_batch1.csv" ),
813- schema_overrides = { "seed" : pl . Utf8 } ,
855+ schema_overrides = agent_overrides ,
814856 )
815857 assert set (agent_df_step2_batch1 .columns ) == {
816858 "wealth" ,
817859 "age_ExampleAgentSet1" ,
818860 "age_ExampleAgentSet2" ,
819861 "age_ExampleAgentSet3" ,
862+ "unique_id_ExampleAgentSet1" ,
863+ "unique_id_ExampleAgentSet2" ,
864+ "unique_id_ExampleAgentSet3" ,
820865 "step" ,
821866 "seed" ,
822867 "batch" ,
@@ -844,13 +889,16 @@ def test_batch_save(self, fix2_model):
844889
845890 agent_df_step4_batch0 = pl .read_csv (
846891 os .path .join (tmpdir , "agent_step4_batch0.csv" ),
847- schema_overrides = { "seed" : pl . Utf8 } ,
892+ schema_overrides = agent_overrides ,
848893 )
849894 assert set (agent_df_step4_batch0 .columns ) == {
850895 "wealth" ,
851896 "age_ExampleAgentSet1" ,
852897 "age_ExampleAgentSet2" ,
853898 "age_ExampleAgentSet3" ,
899+ "unique_id_ExampleAgentSet1" ,
900+ "unique_id_ExampleAgentSet2" ,
901+ "unique_id_ExampleAgentSet3" ,
854902 "step" ,
855903 "seed" ,
856904 "batch" ,
0 commit comments