Skip to content

Commit f5be02a

Browse files
committed
feat: allow collecting agents reporters by unique id
1 parent 41cee3c commit f5be02a

File tree

2 files changed

+83
-25
lines changed

2 files changed

+83
-25
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ def _is_str_collection(x: Any) -> bool:
190190

191191
agent_data_dict: dict[str, pl.Series] = {}
192192

193+
for set_name, aset in self._model.sets.items():
194+
agent_data_dict[f"unique_id_{set_name}"] = aset["unique_id"]
195+
193196
for col_name, reporter in self._agent_reporters.items():
194197
# 1) String or collection[str]: shorthand to fetch columns
195198
if isinstance(reporter, str) or _is_str_collection(reporter):
@@ -449,18 +452,20 @@ def _validate_inputs(self):
449452
- Ensures a `storage_uri` is provided if needed.
450453
- For PostgreSQL, validates that required tables and columns exist.
451454
"""
452-
if self._storage != "memory" and self._storage_uri == None:
455+
if self._storage != "memory" and self._storage_uri is None:
453456
raise ValueError(
454457
"Please define a storage_uri to if to be stored not in memory"
455458
)
456459

457460
if self._storage == "postgresql":
458-
conn = self._get_db_connection(self._storage_uri)
461+
conn = None
459462
try:
463+
conn = self._get_db_connection(self._storage_uri)
460464
self._validate_postgress_table_exists(conn)
461465
self._validate_postgress_columns_exists(conn)
462466
finally:
463-
conn.close()
467+
if conn:
468+
conn.close()
464469

465470
def _validate_postgress_table_exists(self, conn: connection):
466471
"""
@@ -556,6 +561,11 @@ def _is_str_collection(x: Any) -> bool:
556561
return False
557562

558563
expected_columns: set[str] = set()
564+
565+
if table_name == "agent_data":
566+
for set_name, _ in self._model.sets.items():
567+
expected_columns.add(f"unique_id_{set_name}".lower())
568+
559569
for col_name, req in reporter.items():
560570
# Strings → one column per set with suffix
561571
if isinstance(req, str):

tests/test_datacollector.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import psycopg2
12
from mesa_frames.concrete.datacollector import DataCollector
23
from mesa_frames import Model, AgentSet, AgentSetRegistry
34
import pytest
@@ -15,6 +16,7 @@ def custom_trigger(model):
1516
class ExampleAgentSet1(AgentSet):
1617
def __init__(self, model: Model):
1718
super().__init__(model)
19+
self["unique_id"] = pl.Series("unique_id", [101, 102, 103, 104], dtype=pl.Int64)
1820
self["wealth"] = pl.Series("wealth", [1, 2, 3, 4])
1921
self["age"] = pl.Series("age", [10, 20, 30, 40])
2022

@@ -28,6 +30,7 @@ def step(self) -> None:
2830
class ExampleAgentSet2(AgentSet):
2931
def __init__(self, model: Model):
3032
super().__init__(model)
33+
self["unique_id"] = pl.Series("unique_id", [201, 202, 203, 204], dtype=pl.Int64)
3134
self["wealth"] = pl.Series("wealth", [10, 20, 30, 40])
3235
self["age"] = pl.Series("age", [11, 22, 33, 44])
3336

@@ -41,6 +44,7 @@ def step(self) -> None:
4144
class ExampleAgentSet3(AgentSet):
4245
def __init__(self, model: Model):
4346
super().__init__(model)
47+
self["unique_id"] = pl.Series("unique_id", [301, 302, 303, 304], dtype=pl.Int64)
4448
self["age"] = pl.Series("age", [1, 2, 3, 4])
4549
self["wealth"] = pl.Series("wealth", [1, 2, 3, 4])
4650

@@ -147,11 +151,16 @@ def test__init__(self, fix1_model, postgres_uri):
147151
):
148152
model.test_dc = DataCollector(model=model, storage="S3-csv")
149153

154+
try:
155+
psycopg2.connect(postgres_uri)
156+
except psycopg2.OperationalError:
157+
pass
158+
150159
with pytest.raises(
151160
ValueError,
152161
match="Please define a storage_uri to if to be stored not in memory",
153162
):
154-
model.test_dc = DataCollector(model=model, storage="postgresql")
163+
model.test_dc = DataCollector(model=model, storage="postgresql", storage_uri=None)
155164

156165
def test_collect(self, fix1_model):
157166
model = fix1_model
@@ -185,12 +194,15 @@ def test_collect(self, fix1_model):
185194
with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"):
186195
collected_data["model"]["max_wealth"]
187196

188-
assert collected_data["agent"].shape == (4, 7)
197+
assert collected_data["agent"].shape == (4, 10)
189198
assert set(collected_data["agent"].columns) == {
190199
"wealth",
191200
"age_ExampleAgentSet1",
192201
"age_ExampleAgentSet2",
193202
"age_ExampleAgentSet3",
203+
"unique_id_ExampleAgentSet1",
204+
"unique_id_ExampleAgentSet2",
205+
"unique_id_ExampleAgentSet3",
194206
"step",
195207
"seed",
196208
"batch",
@@ -242,12 +254,15 @@ def test_collect_step(self, fix1_model):
242254
assert collected_data["model"]["step"].to_list() == [5]
243255
assert collected_data["model"]["total_agents"].to_list() == [12]
244256

245-
assert collected_data["agent"].shape == (4, 7)
257+
assert collected_data["agent"].shape == (4, 10)
246258
assert set(collected_data["agent"].columns) == {
247259
"wealth",
248260
"age_ExampleAgentSet1",
249261
"age_ExampleAgentSet2",
250262
"age_ExampleAgentSet3",
263+
"unique_id_ExampleAgentSet1",
264+
"unique_id_ExampleAgentSet2",
265+
"unique_id_ExampleAgentSet3",
251266
"step",
252267
"seed",
253268
"batch",
@@ -297,25 +312,20 @@ def test_conditional_collect(self, fix1_model):
297312
assert collected_data["model"]["step"].to_list() == [2, 4]
298313
assert collected_data["model"]["total_agents"].to_list() == [12, 12]
299314

300-
assert collected_data["agent"].shape == (8, 7)
301-
assert set(collected_data["agent"].columns) == {
302-
"wealth",
303-
"age_ExampleAgentSet1",
304-
"age_ExampleAgentSet2",
305-
"age_ExampleAgentSet3",
306-
"step",
307-
"seed",
308-
"batch",
309-
}
315+
assert collected_data["agent"].shape == (8, 10)
310316
assert set(collected_data["agent"].columns) == {
311317
"wealth",
312318
"age_ExampleAgentSet1",
313319
"age_ExampleAgentSet2",
314320
"age_ExampleAgentSet3",
321+
"unique_id_ExampleAgentSet1",
322+
"unique_id_ExampleAgentSet2",
323+
"unique_id_ExampleAgentSet3",
315324
"step",
316325
"seed",
317326
"batch",
318327
}
328+
319329
assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8]
320330
assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [
321331
10,
@@ -394,15 +404,25 @@ def test_flush_local_csv(self, fix1_model):
394404
assert model_df["step"].to_list() == [2]
395405
assert model_df["total_agents"].to_list() == [12]
396406

407+
agent_overrides = {
408+
"seed": pl.Utf8,
409+
"unique_id_ExampleAgentSet1": pl.Utf8,
410+
"unique_id_ExampleAgentSet2": pl.Utf8,
411+
"unique_id_ExampleAgentSet3": pl.Utf8,
412+
}
413+
397414
agent_df = pl.read_csv(
398415
os.path.join(tmpdir, "agent_step2_batch0.csv"),
399-
schema_overrides={"seed": pl.Utf8},
416+
schema_overrides=agent_overrides,
400417
)
401418
assert set(agent_df.columns) == {
402419
"wealth",
403420
"age_ExampleAgentSet1",
404421
"age_ExampleAgentSet2",
405422
"age_ExampleAgentSet3",
423+
"unique_id_ExampleAgentSet1",
424+
"unique_id_ExampleAgentSet2",
425+
"unique_id_ExampleAgentSet3",
406426
"step",
407427
"seed",
408428
"batch",
@@ -420,7 +440,7 @@ def test_flush_local_csv(self, fix1_model):
420440

421441
agent_df = pl.read_csv(
422442
os.path.join(tmpdir, "agent_step4_batch0.csv"),
423-
schema_overrides={"seed": pl.Utf8},
443+
schema_overrides=agent_overrides,
424444
)
425445
assert agent_df["step"].to_list() == [4, 4, 4, 4]
426446
assert agent_df["wealth"].to_list() == [5, 6, 7, 8]
@@ -474,10 +494,13 @@ def test_flush_local_parquet(self, fix1_model):
474494
reason="PostgreSQL tests are skipped on Windows runners",
475495
)
476496
def test_postgress(self, fix1_model, postgres_uri):
477-
model = fix1_model
497+
try:
498+
conn = psycopg2.connect(postgres_uri)
499+
conn.close()
500+
except psycopg2.OperationalError:
501+
pytest.skip("PostgreSQL not available")
478502

479-
# Connect directly and validate data
480-
import psycopg2
503+
model = fix1_model
481504

482505
conn = psycopg2.connect(postgres_uri)
483506
cur = conn.cursor()
@@ -496,6 +519,9 @@ def test_postgress(self, fix1_model, postgres_uri):
496519
step INTEGER,
497520
seed VARCHAR,
498521
batch INTEGER,
522+
"unique_id_ExampleAgentSet1" INTEGER,
523+
"unique_id_ExampleAgentSet2" INTEGER,
524+
"unique_id_ExampleAgentSet3" INTEGER,
499525
age_ExampleAgentSet1 INTEGER,
500526
age_ExampleAgentSet2 INTEGER,
501527
age_ExampleAgentSet3 INTEGER,
@@ -580,12 +606,15 @@ def test_batch_memory(self, fix2_model):
580606
assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1]
581607
assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12]
582608

583-
assert collected_data["agent"].shape == (16, 7)
609+
assert collected_data["agent"].shape == (16, 10)
584610
assert set(collected_data["agent"].columns) == {
585611
"wealth",
586612
"age_ExampleAgentSet1",
587613
"age_ExampleAgentSet2",
588614
"age_ExampleAgentSet3",
615+
"unique_id_ExampleAgentSet1",
616+
"unique_id_ExampleAgentSet2",
617+
"unique_id_ExampleAgentSet3",
589618
"step",
590619
"seed",
591620
"batch",
@@ -596,6 +625,9 @@ def test_batch_memory(self, fix2_model):
596625
"age_ExampleAgentSet1",
597626
"age_ExampleAgentSet2",
598627
"age_ExampleAgentSet3",
628+
"unique_id_ExampleAgentSet1",
629+
"unique_id_ExampleAgentSet2",
630+
"unique_id_ExampleAgentSet3",
599631
"step",
600632
"seed",
601633
"batch",
@@ -773,16 +805,26 @@ def test_batch_save(self, fix2_model):
773805
assert model_df_step4_batch0["step"].to_list() == [4]
774806
assert model_df_step4_batch0["total_agents"].to_list() == [12]
775807

808+
agent_overrides = {
809+
"seed": pl.Utf8,
810+
"unique_id_ExampleAgentSet1": pl.Utf8,
811+
"unique_id_ExampleAgentSet2": pl.Utf8,
812+
"unique_id_ExampleAgentSet3": pl.Utf8,
813+
}
814+
776815
# test agent batch reset
777816
agent_df_step2_batch0 = pl.read_csv(
778817
os.path.join(tmpdir, "agent_step2_batch0.csv"),
779-
schema_overrides={"seed": pl.Utf8},
818+
schema_overrides=agent_overrides,
780819
)
781820
assert set(agent_df_step2_batch0.columns) == {
782821
"wealth",
783822
"age_ExampleAgentSet1",
784823
"age_ExampleAgentSet2",
785824
"age_ExampleAgentSet3",
825+
"unique_id_ExampleAgentSet1",
826+
"unique_id_ExampleAgentSet2",
827+
"unique_id_ExampleAgentSet3",
786828
"step",
787829
"seed",
788830
"batch",
@@ -810,13 +852,16 @@ def test_batch_save(self, fix2_model):
810852

811853
agent_df_step2_batch1 = pl.read_csv(
812854
os.path.join(tmpdir, "agent_step2_batch1.csv"),
813-
schema_overrides={"seed": pl.Utf8},
855+
schema_overrides=agent_overrides,
814856
)
815857
assert set(agent_df_step2_batch1.columns) == {
816858
"wealth",
817859
"age_ExampleAgentSet1",
818860
"age_ExampleAgentSet2",
819861
"age_ExampleAgentSet3",
862+
"unique_id_ExampleAgentSet1",
863+
"unique_id_ExampleAgentSet2",
864+
"unique_id_ExampleAgentSet3",
820865
"step",
821866
"seed",
822867
"batch",
@@ -844,13 +889,16 @@ def test_batch_save(self, fix2_model):
844889

845890
agent_df_step4_batch0 = pl.read_csv(
846891
os.path.join(tmpdir, "agent_step4_batch0.csv"),
847-
schema_overrides={"seed": pl.Utf8},
892+
schema_overrides=agent_overrides,
848893
)
849894
assert set(agent_df_step4_batch0.columns) == {
850895
"wealth",
851896
"age_ExampleAgentSet1",
852897
"age_ExampleAgentSet2",
853898
"age_ExampleAgentSet3",
899+
"unique_id_ExampleAgentSet1",
900+
"unique_id_ExampleAgentSet2",
901+
"unique_id_ExampleAgentSet3",
854902
"step",
855903
"seed",
856904
"batch",

0 commit comments

Comments
 (0)