Skip to content

Commit

Permalink
Merge pull request #71 from medema-group/feature/classify-query
Browse files Browse the repository at this point in the history
Feature/classify query
  • Loading branch information
adraismawur authored Nov 1, 2023
2 parents d88c864 + 6dc49b2 commit 4645f96
Show file tree
Hide file tree
Showing 19 changed files with 648 additions and 112 deletions.
1 change: 1 addition & 0 deletions big_scape/cli/benchmark_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def benchmark(ctx, *args, **kwargs):
"""
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["mode"] = "Benchmark"

# workflow validations
validate_output_paths(ctx)
Expand Down
16 changes: 0 additions & 16 deletions big_scape/cli/cli_common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,22 +252,6 @@ def common_cluster_query(fn):
"the listed accessions will be analysed."
),
),
# comparison parameters
click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others."
"Will also use BiG-SCAPEv1 legacy_weights for distance calculations."
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use"
" at your own risk, as BGC classes may have changed. All antiSMASH"
"classes that this legacy mode does not recognize will be grouped in"
" 'others'."
),
),
click.option(
"--legacy_weights",
is_flag=True,
Expand Down
17 changes: 17 additions & 0 deletions big_scape/cli/cluster_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@
"all input BGCs to the query in a one-vs-all mode."
),
)
# comparison parameters
@click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others."
"Will also use BiG-SCAPEv1 legacy_weights for distance calculations."
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use"
" at your own risk, as BGC classes may have changed. All antiSMASH"
"classes that this legacy mode does not recognize will be grouped in"
" 'others'."
),
)
# binning parameters
@click.option("--no_mix", is_flag=True, help=("Dont run the all-vs-all analysis"))
# networking parameters
Expand All @@ -52,6 +68,7 @@ def cluster(ctx, *args, **kwargs):
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["query_bgc_path"] = None
ctx.obj["mode"] = "Cluster"

# workflow validations
validate_binning_cluster_workflow(ctx)
Expand Down
2 changes: 2 additions & 0 deletions big_scape/cli/query_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def query(ctx, *args, **kwarg):
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["no_mix"] = None
ctx.obj["legacy_classify"] = False
ctx.obj["mode"] = "Query"

# workflow validations
validate_skip_hmmscan(ctx)
Expand Down
2 changes: 2 additions & 0 deletions big_scape/comparison/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
QueryToRefRecordPairGenerator,
RefToRefRecordPairGenerator,
MissingRecordPairGenerator,
ConnectedComponenetPairGenerator,
generate_mix,
legacy_bin_generator,
legacy_get_class,
Expand All @@ -24,6 +25,7 @@
"QueryToRefRecordPairGenerator",
"RefToRefRecordPairGenerator",
"MissingRecordPairGenerator",
"ConnectedComponenetPairGenerator",
"generate_mix",
"ComparableRegion",
"generate_edges",
Expand Down
127 changes: 118 additions & 9 deletions big_scape/comparison/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,40 @@ def add_records(self, record_list: list[BGCRecord]):
if None in self.record_ids:
raise ValueError("Region in bin has no db id!")

def cull_singletons(self, cutoff: float):
    """Remove singleton records from this bin for the given cutoff.

    A record is kept only if the distance table contains at least one edge
    that (a) has a distance strictly below ``cutoff``, (b) was stored with
    this bin's weights and (c) connects it to another record of this bin.
    All other records are dropped from ``self.record_ids`` and
    ``self.source_records``.

    Args:
        cutoff: distance threshold; only edges with distance < cutoff count
            as connections

    Raises:
        RuntimeError: if DB.metadata is None
    """

    if not DB.metadata:
        raise RuntimeError("DB.metadata is None")

    distance_table = DB.metadata.tables["distance"]

    # Both endpoints must be members of this bin. Matching on either endpoint
    # would pull ids of records outside the bin into self.record_ids, and would
    # keep records whose only neighbours are outside the bin.
    select_statement = (
        select(distance_table.c.region_a_id, distance_table.c.region_b_id)
        .where(distance_table.c.region_a_id.in_(self.record_ids))
        .where(distance_table.c.region_b_id.in_(self.record_ids))
        .where(distance_table.c.distance < cutoff)
        .where(distance_table.c.weights == self.weights)
    )

    edges = DB.execute(select_statement).fetchall()

    # every id that appears in at least one qualifying edge is connected
    filtered_record_ids: set[int] = set()
    for edge in edges:
        filtered_record_ids.update(edge)

    self.record_ids = filtered_record_ids
    self.source_records = [
        record
        for record in self.source_records
        if record._db_id in filtered_record_ids
    ]

def __repr__(self) -> str:
return (
f"Bin '{self.label}': {self.num_pairs()} pairs from "
Expand All @@ -141,8 +175,8 @@ class QueryToRefRecordPairGenerator(RecordPairGenerator):
ref <-> ref pairs
"""

def __init__(self, label: str):
super().__init__(label)
def __init__(self, label: str, weights: Optional[str] = None):
super().__init__(label, weights)
self.reference_records: list[BGCRecord] = []
self.query_records: list[BGCRecord] = []

Expand Down Expand Up @@ -225,10 +259,11 @@ class RefToRefRecordPairGenerator(RecordPairGenerator):
source_records (list[BGCRecord]): List of BGC records to generate pairs from
"""

def __init__(self, label: str):
def __init__(self, label: str, weights: Optional[str] = None):
self.record_id_to_obj: dict[int, BGCRecord] = {}
self.reference_record_ids: set[int] = set()
self.done_record_ids: set[int] = set()
super().__init__(label)
super().__init__(label, weights)

def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, None]:
"""Returns an Generator for Region pairs in this bin, pairs are only generated between
Expand Down Expand Up @@ -288,6 +323,9 @@ def add_records(self, record_list: list[BGCRecord]):
raise ValueError("Region in bin has no db id!")

self.record_id_to_obj[record._db_id] = record
if record.parent_gbk is not None:
if record.parent_gbk.source_type == SOURCE_TYPE.REFERENCE:
self.reference_record_ids.add(record._db_id)

return super().add_records(record_list)

Expand Down Expand Up @@ -315,16 +353,18 @@ def get_connected_reference_nodes(self) -> set[BGCRecord]:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
bgc_record_table.c.id.in_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
)
)
.where(bgc_record_table.c.id.notin_(self.done_record_ids))
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -364,16 +404,18 @@ def get_connected_reference_node_count(self) -> int:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
bgc_record_table.c.id.in_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
)
)
.where(bgc_record_table.c.id.notin_(self.done_record_ids))
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -406,16 +448,18 @@ def get_singleton_reference_nodes(self) -> set[BGCRecord]:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
)
)
.where(
bgc_record_table.c.id.notin_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
)
)
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -463,7 +507,7 @@ def get_singleton_reference_node_count(self) -> int:
.where(distance_table.c.distance < 1.0)
)
)
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand All @@ -472,6 +516,64 @@ def get_singleton_reference_node_count(self) -> int:
return singleton_reference_node_count


class ConnectedComponenetPairGenerator(RecordPairGenerator):
    """Pair generator driven by the edges of a single connected component.

    The connected component is an iterable of edges. Each edge is unpacked
    as seven values: record_a_id, record_b_id, dist, jacc, adj, dss, weights.
    One pair is produced per edge.
    """

    def __init__(self, connected_component, label: str):
        super().__init__(label)
        self.connected_component = connected_component
        self.record_id_to_obj: dict[int, BGCRecord] = {}

    def add_records(self, record_list: list[BGCRecord]):
        """Register the records from record_list that occur in the component.

        Records whose db id is not an endpoint of any edge are ignored. The
        bin's weights are overwritten with the weights found on the edges, and
        an id -> record lookup is filled for the retained records.

        Raises:
            ValueError: if a record in record_list has no db id
        """
        member_ids = set()

        for (
            id_a,
            id_b,
            _dist,
            _jacc,
            _adj,
            _dss,
            edge_weights,
        ) in self.connected_component:
            # the weights were decided during binning; adopt them for this bin
            self.weights = edge_weights
            member_ids.update((id_a, id_b))

        members = []
        for candidate in record_list:
            if candidate._db_id is None:
                raise ValueError("Region in bin has no db id!")
            if candidate._db_id in member_ids:
                self.record_id_to_obj[candidate._db_id] = candidate
                members.append(candidate)

        return super().add_records(members)

    def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, None]:
        """Yield one RecordPair per edge of the connected component.

        With legacy_sorting enabled, the two records of each pair are ordered
        using sort_name_key before the pair is built. An edge whose weights
        differ from the bin's weights is logged as an error but still yielded.
        """
        for (
            id_a,
            id_b,
            _dist,
            _jacc,
            _adj,
            _dss,
            edge_weights,
        ) in self.connected_component:
            if self.weights != edge_weights:
                logging.error(
                    "Edge in connected component does not have the same weight as the bin!"
                )

            first = self.record_id_to_obj[id_a]
            second = self.record_id_to_obj[id_b]

            if legacy_sorting:
                first, second = sorted((first, second), key=sort_name_key)

            yield RecordPair(first, second)


class MissingRecordPairGenerator(RecordPairGenerator):
"""Generator that wraps around another RecordPairGenerator to exclude any distances
already in the database
Expand Down Expand Up @@ -513,14 +615,21 @@ def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, No
select(distance_table.c.region_a_id, distance_table.c.region_b_id)
.where(distance_table.c.region_a_id.in_(self.bin.record_ids))
.where(distance_table.c.region_b_id.in_(self.bin.record_ids))
.where(distance_table.c.weights == self.bin.weights)
)

# generate a set of tuples of region id pairs
existing_distances = set(DB.execute(select_statement).fetchall())

for pair in self.bin.generate_pairs(legacy_sorting):
# if the pair is not in the set of existing distances, yield it
if (pair.region_a._db_id, pair.region_b._db_id) not in existing_distances:
if (
pair.region_a._db_id,
pair.region_b._db_id,
) not in existing_distances and (
pair.region_a._db_id,
pair.region_b._db_id,
) not in existing_distances:
yield pair

def add_records(self, _: list[BGCRecord]):
Expand Down
3 changes: 1 addition & 2 deletions big_scape/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
CREATE TABLE IF NOT EXISTS gbk (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT,
source_type TEXT,
nt_seq TEXT,
UNIQUE(path)
);
Expand Down Expand Up @@ -76,7 +75,7 @@ CREATE TABLE IF NOT EXISTS distance (
adjacency REAL NOT NULL,
dss REAL NOT NULL,
weights TEXT NOT NULL,
UNIQUE(region_a_id, region_b_id)
UNIQUE(region_a_id, region_b_id, weights)
FOREIGN KEY(region_a_id) REFERENCES bgc_record(id)
FOREIGN KEY(region_b_id) REFERENCES bgc_record(id)
);
Loading

0 comments on commit 4645f96

Please sign in to comment.