Skip to content

Commit

Permalink
filter on bin label when removing ref only ccs
Browse files Browse the repository at this point in the history
  • Loading branch information
nlouwen committed Jan 20, 2025
1 parent 30c9890 commit c768fb1
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 14 deletions.
5 changes: 4 additions & 1 deletion big_scape/data/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Compiled,
Select,
Insert,
Delete,
CursorResult,
create_engine,
func,
Expand Down Expand Up @@ -283,7 +284,9 @@ def execute_raw_query(query: str) -> CursorResult:
return DB.connection.execute(text(query))

@staticmethod
def execute(query: Compiled | Select[Any] | Insert, commit=True) -> CursorResult:
def execute(
query: Compiled | Select[Any] | Insert | Delete, commit=True
) -> CursorResult:
"""Wrapper for SQLAlchemy.connection.execute expecting a Compiled query
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion big_scape/distances/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def calculate_distances_query(
query_nodes = bs_network.get_nodes_from_cc(query_connected_component, query_records)

bs_network.remove_connected_component(
query_connected_component, max_cutoff, run["run_id"]
query_connected_component, query_bin.label, max_cutoff, run["run_id"]
)

query_bin_connected = bs_comparison.RecordPairGenerator(
Expand Down
2 changes: 1 addition & 1 deletion big_scape/network/families.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def run_family_assignments(
connected_component, bin.source_records
):
bs_network.remove_connected_component(
connected_component, cutoff, run["run_id"]
connected_component, bin.label, cutoff, run["run_id"]
)
continue

Expand Down
18 changes: 11 additions & 7 deletions big_scape/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def reference_only_connected_component(connected_component, bgc_records) -> bool


def get_connected_component_id(
connected_component: list, cutoff: float, run_id: int
connected_component: list, bin_label: str, cutoff: float, run_id: int
) -> int:
"""Get the connected component id for the given connected component
expects all edges to be in one connected component, if thats not the
Expand All @@ -575,33 +575,37 @@ def get_connected_component_id(
.distinct()
.where(
and_(
cc_table.c.cutoff == cutoff,
cc_table.c.record_id == record_id,
cc_table.c.bin_label == bin_label,
cc_table.c.cutoff == cutoff,
cc_table.c.run_id == run_id,
)
)
.limit(1)
)

cc_ids = DB.execute(select_statement).fetchone()
cc_id = DB.execute(select_statement).scalar_one()

return cc_ids[0]
return cc_id


def remove_connected_component(
connected_component: list, cutoff: float, run_id: int
connected_component: list, bin_label: str, cutoff: float, run_id: int
) -> None:
"""Removes a connected component from the cc table in the database"""

if DB.metadata is None:
raise RuntimeError("DB.metadata is None")

cc_id = get_connected_component_id(connected_component, cutoff, run_id)
cc_id = get_connected_component_id(connected_component, bin_label, cutoff, run_id)

cc_table = DB.metadata.tables["connected_component"]

delete_statement = delete(cc_table).where(
cc_table.c.id == cc_id, cc_table.c.cutoff == cutoff, cc_table.c.run_id == run_id
cc_table.c.id == cc_id,
cc_table.c.bin_label == bin_label,
cc_table.c.cutoff == cutoff,
cc_table.c.run_id == run_id,
)

DB.execute(delete_statement)
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def test_get_connected_components_no_ref_to_ref_ccs(self):
cc, include_records
)
if is_ref_only:
bs_network.remove_connected_component(cc, 0.8, 1)
bs_network.remove_connected_component(cc, mix_bin.label, 0.8, 1)

cc_table = bs_data.DB.metadata.tables["connected_component"]

Expand All @@ -488,7 +488,7 @@ def test_get_connected_components_no_ref_to_ref_ccs(self):
cc, include_records
)
if is_ref_only:
bs_network.remove_connected_component(cc, 0.5, 1)
bs_network.remove_connected_component(cc, mix_bin.label, 0.5, 1)

cc_table = bs_data.DB.metadata.tables["connected_component"]

Expand Down
4 changes: 2 additions & 2 deletions test/network/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_get_connected_component_id(self):

cc = next(bs_network.get_connected_components(0.5, 1, mix_bin, 1))

cc_id = bs_network.get_connected_component_id(cc, 0.5, 1)
cc_id = bs_network.get_connected_component_id(cc, mix_bin.label, 0.5, 1)

expected_data = 1

Expand Down Expand Up @@ -528,7 +528,7 @@ def test_remove_connected_component(self):
cc, include_records
)
if is_ref_only:
bs_network.remove_connected_component(cc, 0.5, 1)
bs_network.remove_connected_component(cc, mix_bin.label, 0.5, 1)

select_statement = select(cc_table.c.id).distinct()

Expand Down

0 comments on commit c768fb1

Please sign in to comment.