From bc4eb47ec0bdd140b17d44814361b6fe0d5661b8 Mon Sep 17 00:00:00 2001 From: Ben Date: Sun, 30 Mar 2025 02:54:51 +0530 Subject: [PATCH 1/2] fixed and added test --- mesa_frames/concrete/polars/agentset.py | 11 ++++++++++- tests/polars/test_agentset_polars.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mesa_frames/concrete/polars/agentset.py b/mesa_frames/concrete/polars/agentset.py index 4bec1ea5..0dedd32f 100644 --- a/mesa_frames/concrete/polars/agentset.py +++ b/mesa_frames/concrete/polars/agentset.py @@ -136,6 +136,7 @@ def add( if new_agents["unique_id"].dtype != pl.Int64: raise TypeError("unique_id column must be of type int64.") + # If self._mask is pl.Expr, then new mask is the same. # If self._mask is pl.Series[bool], then new mask has to be updated. @@ -143,7 +144,15 @@ def add( if isinstance(obj._mask, pl.Series): original_active_indices = obj._agents.filter(obj._mask)["unique_id"] - obj._agents = pl.concat([obj._agents, new_agents], how="diagonal_relaxed") + + + combined_agents= pl.concat([obj._agents, new_agents], how="diagonal_relaxed") + if combined_agents["unique_id"].is_duplicated().any(): + raise ValueError( + "Some ids are duplicated in the AgentSet that are trying to be added." + ) + + obj._agents = combined_agents if isinstance(obj._mask, pl.Series): obj._update_mask(original_active_indices, new_agents["unique_id"]) diff --git a/tests/polars/test_agentset_polars.py b/tests/polars/test_agentset_polars.py index 9c311727..9835cb1d 100644 --- a/tests/polars/test_agentset_polars.py +++ b/tests/polars/test_agentset_polars.py @@ -32,6 +32,17 @@ def fix1_AgentSetPolars() -> ExampleAgentSetPolars: return agents +@pytest.fixture +def fix4_AgentSetPolars() -> ExampleAgentSetPolars: + model = ModelDF() + agents = ExampleAgentSetPolars(model) + agents.add({"unique_id": [0, 1, 2, 3]}) + agents["wealth"] = agents.starting_wealth + agents["age"] = [10, 20, 30, 40] + model.agents.add(agents) + return agents + + @pytest.fixture def fix2_AgentSetPolars() -> ExampleAgentSetPolars: model = ModelDF() @@ -73,9 +84,17 @@ def test_add( self, fix1_AgentSetPolars: ExampleAgentSetPolars, fix2_AgentSetPolars: ExampleAgentSetPolars, + fix4_AgentSetPolars: ExampleAgentSetPolars, ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars + agents4 = fix4_AgentSetPolars + + with pytest.raises( + ValueError, + match="Some ids are duplicated in the AgentSet that are trying to be added.", + ): + result = agents.add(agents4.agents, inplace=False) # Test with a DataFrame result = agents.add(agents2.agents, inplace=False) From cfcbd5c87e7af826bf4789946c60d957d802befb Mon Sep 17 00:00:00 2001 From: Ben Date: Sun, 30 Mar 2025 03:01:29 +0530 Subject: [PATCH 2/2] clearer error --- mesa_frames/concrete/polars/agentset.py | 7 ++----- tests/polars/test_agentset_polars.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mesa_frames/concrete/polars/agentset.py b/mesa_frames/concrete/polars/agentset.py index 0dedd32f..85b2d498 100644 --- a/mesa_frames/concrete/polars/agentset.py +++ b/mesa_frames/concrete/polars/agentset.py @@ -136,7 +136,6 @@ def add( if new_agents["unique_id"].dtype != pl.Int64: raise TypeError("unique_id column must be of type int64.") - # If self._mask is pl.Expr, then new mask is the same. # If self._mask is pl.Series[bool], then new mask has to be updated. @@ -144,12 +143,10 @@ def add( if isinstance(obj._mask, pl.Series): original_active_indices = obj._agents.filter(obj._mask)["unique_id"] - - - combined_agents= pl.concat([obj._agents, new_agents], how="diagonal_relaxed") + combined_agents = pl.concat([obj._agents, new_agents], how="diagonal_relaxed") if combined_agents["unique_id"].is_duplicated().any(): raise ValueError( - "Some ids are duplicated in the AgentSet that are trying to be added." + "Some ids are duplicated in the AgentSet that are trying to be added together." ) obj._agents = combined_agents diff --git a/tests/polars/test_agentset_polars.py b/tests/polars/test_agentset_polars.py index 9835cb1d..180a5b9d 100644 --- a/tests/polars/test_agentset_polars.py +++ b/tests/polars/test_agentset_polars.py @@ -92,7 +92,7 @@ def test_add( with pytest.raises( ValueError, - match="Some ids are duplicated in the AgentSet that are trying to be added.", + match="Some ids are duplicated in the AgentSet that are trying to be added together.", ): result = agents.add(agents4.agents, inplace=False)