Skip to content

Commit

Permalink
Merge pull request #70 from medema-group/feature/classify-protocluster
Browse files Browse the repository at this point in the history
Feature/classify protocluster
  • Loading branch information
nlouwen authored Oct 26, 2023
2 parents afc1ad8 + 5a77df7 commit d88c864
Show file tree
Hide file tree
Showing 30 changed files with 433 additions and 139 deletions.
8 changes: 8 additions & 0 deletions big_scape/cli/cli_common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
validate_includelist,
validate_gcf_cutoffs,
validate_filter_gbk,
validate_record_type,
validate_classify,
validate_output_dir,
)
Expand Down Expand Up @@ -330,6 +331,13 @@ def common_cluster_query(fn):
type=click.Path(path_type=Path, dir_okay=False),
help="Path to sqlite db output file. Default: output_dir/data_sqlite.db.",
),
click.option(
"--record_type",
type=click.Choice(["region", "proto_cluster", "proto_core"]),
default="region",
callback=validate_record_type,
help="Use a specific type of record for comparison. Default: region",
),
]
for opt in options[::-1]:
fn = opt(fn)
Expand Down
10 changes: 10 additions & 0 deletions big_scape/cli/cli_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,13 @@ def validate_pfam_path(ctx) -> None:
"BiG-SCAPE database not provided, a pfam file is "
"required in order to detect domains."
)


def validate_record_type(ctx, _, record_type) -> Optional[bs_enums.genbank.RECORD_TYPE]:
    """Convert the --record_type option value into a RECORD_TYPE enum member

    Used as a click option callback, hence the (context, parameter, value)
    signature; the context and parameter arguments are not used.

    Args:
        ctx: click context (unused)
        _: click parameter (unused)
        record_type: the raw option value, e.g. "region"

    Returns:
        The RECORD_TYPE member whose value equals record_type, or None if no
        member matches
    """
    # compare against member values and return the member itself, so that the
    # lookup does not rely on member names being the upper-cased values
    for mode in bs_enums.genbank.RECORD_TYPE:
        if mode.value == record_type:
            return mode
    return None
4 changes: 2 additions & 2 deletions big_scape/comparison/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
legacy_get_class,
as_class_bin_generator,
get_weight_category,
get_region_category,
get_record_category,
)
from .comparable_region import ComparableRegion
from .legacy_workflow_alt import generate_edges
Expand All @@ -31,7 +31,7 @@
"legacy_get_class",
"as_class_bin_generator",
"get_weight_category",
"get_region_category",
"get_record_category",
"save_edge_to_db",
"lcs",
"save_edges_to_db",
Expand Down
132 changes: 72 additions & 60 deletions big_scape/comparison/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# from other modules
from big_scape.data import DB
from big_scape.genbank import BGCRecord, GBK, Region
from big_scape.genbank import BGCRecord, Region, ProtoCluster, ProtoCore
from big_scape.enums import SOURCE_TYPE, CLASSIFY_MODE

# from this module
Expand Down Expand Up @@ -493,6 +493,7 @@ def num_pairs(self) -> int:
select(func.count(distance_table.c.region_a_id))
.where(distance_table.c.region_a_id.in_(self.bin.record_ids))
.where(distance_table.c.region_b_id.in_(self.bin.record_ids))
.where(distance_table.c.weights == self.bin.weights)
)

# get count
Expand Down Expand Up @@ -541,8 +542,8 @@ def __init__(self, region_a: BGCRecord, region_b: BGCRecord):
raise ValueError("Region in pair has no parent GBK!")

# comparable regions start at the full ranges
a_len = len(region_a.parent_gbk.genes)
b_len = len(region_b.parent_gbk.genes)
a_len = len(region_a.get_cds())
b_len = len(region_b.get_cds())

self.comparable_region: ComparableRegion = ComparableRegion(
self, 0, a_len, 0, b_len, False
Expand Down Expand Up @@ -603,58 +604,53 @@ def sort_name_key(record: BGCRecord) -> str:


def as_class_bin_generator(
gbks: list[GBK], weights: str, classify_mode: CLASSIFY_MODE
all_records: list[BGCRecord], weight_type: str, classify_mode: CLASSIFY_MODE
) -> Iterator[RecordPairGenerator]:
"""Generate bins for each antiSMASH class
Args:
gbks (list[GBK]): List of GBKs to generate bins for
weights (str): weights to use for each class
category_weights (str): weights to use for each class
Yields:
Iterator[RecordPairGenerator]: Generator that yields bins. Order is not guaranteed to be
consistent
"""

class_idx: dict[str, list[BGCRecord]] = {}
category_idx: dict[str, str] = {}

for gbk in gbks:
if gbk.region is None:
continue
if gbk.region.product is None:
continue
category_weights: dict[str, str] = {}

for record in all_records:
# get region class for bin label and index
if classify_mode == CLASSIFY_MODE.CLASS:
region_class = gbk.region.product
record_class = record.product

if classify_mode == CLASSIFY_MODE.CATEGORY:
region_class = get_region_category(gbk.region)
record_class = get_record_category(record)

try:
class_idx[region_class].append(gbk.region)
class_idx[record_class].append(record)
except KeyError:
class_idx[region_class] = [gbk.region]
class_idx[record_class] = [record]

if weights == "legacy_weights":
if weight_type == "legacy_weights":
# get region category for weights
region_weight_cat = get_weight_category(gbk.region)
region_weight_cat = get_weight_category(record)

if region_class not in category_idx.keys():
category_idx[region_class] = region_weight_cat
if record_class not in category_weights.keys():
category_weights[record_class] = region_weight_cat

if weights == "mix":
category_idx[region_class] = "mix"
if weight_type == "mix":
category_weights[record_class] = "mix"

for class_name, regions in class_idx.items():
weight_category = category_idx[class_name]
for class_name, records in class_idx.items():
weight_category = category_weights[class_name]
bin = RecordPairGenerator(class_name, weight_category)
bin.add_records(regions)
bin.add_records(records)
yield bin


def get_region_category(region: Region) -> str:
def get_record_category(record: BGCRecord) -> str:
"""Get the category of a BGC based on its antiSMASH product(s)
Args:
Expand All @@ -664,17 +660,22 @@ def get_region_category(region: Region) -> str:
str: BGC category
"""

categories = []
categories: list[str] = []

if isinstance(record, ProtoCluster) or isinstance(record, ProtoCore):
if record.category is not None:
categories.append(record.category)

# get categories from region object
for idx, cand_cluster in region.cand_clusters.items():
if cand_cluster is not None:
for idx, protocluster in cand_cluster.proto_clusters.items():
if protocluster is not None:
pc_category = protocluster.category
# avoid duplicates, hybrids of the same kind count as one category
if pc_category not in categories:
categories.append(pc_category)
if isinstance(record, Region):
# get categories from region object
for idx, cand_cluster in record.cand_clusters.items():
if cand_cluster is not None:
for idx, protocluster in cand_cluster.proto_clusters.items():
if protocluster is not None and protocluster.category is not None:
pc_category = protocluster.category
# avoid duplicates, hybrids of the same kind count as one category
if pc_category not in categories:
categories.append(pc_category)

if len(categories) == 0:
return "Categoryless"
Expand All @@ -685,7 +686,7 @@ def get_region_category(region: Region) -> str:
return ".".join(categories)


def get_weight_category(region: Region) -> str:
def get_weight_category(record: BGCRecord) -> str:
"""Get the category of a BGC based on its antiSMASH product(s)
and match it to the legacy weights classes
Expand All @@ -696,28 +697,40 @@ def get_weight_category(region: Region) -> str:
str: class category to be used in weight selection
"""

categories = []
categories: list[str] = []

# get categories from region object
for idx, cand_cluster in region.cand_clusters.items():
if cand_cluster is not None:
for idx, protocluster in cand_cluster.proto_clusters.items():
if protocluster is not None:
if protocluster.product == "T1PKS":
pc_category = protocluster.product
else:
pc_category = protocluster.category
# avoid duplicates, hybrids of the same kind use the same weight class
if pc_category not in categories:
categories.append(pc_category)
if isinstance(record, ProtoCluster) or isinstance(record, ProtoCore):
# T1PKS is the only case in which an antiSMASH category does not
# correspond to a legacy_weights class
if (
record.category is not None
): # for typing, we assume antismash 6 and up always have it
if record.product == "T1PKS":
categories.append(record.product)
else:
categories.append(record.category)

if isinstance(record, Region):
# get categories from region object
for idx, cand_cluster in record.cand_clusters.items():
if cand_cluster is not None:
for idx, protocluster in cand_cluster.proto_clusters.items():
if protocluster is not None and protocluster.category is not None:
if protocluster.product == "T1PKS":
pc_category = protocluster.product
else:
pc_category = protocluster.category
# avoid duplicates, hybrids of the same kind use the same weight class
if pc_category not in categories:
categories.append(pc_category)

# process into legacy_weights classes

# for versions that don't have category information
if len(categories) == 0:
logging.warning(
"No category found for %s",
region,
record,
"This should not happen as long as antiSMASH is run with"
"version 6 or up, consider whether there is something"
"special about this region",
Expand All @@ -741,7 +754,7 @@ def get_weight_category(region: Region) -> str:


def legacy_bin_generator(
gbks: list[GBK],
all_records: list[BGCRecord],
) -> Iterator[RecordPairGenerator]: # pragma no cover
"""Generate bins for each class as they existed in the BiG-SCAPE 1.0 implementation
Expand All @@ -758,23 +771,22 @@ def legacy_bin_generator(
class_name: [] for class_name in LEGACY_WEIGHTS.keys() if class_name != "mix"
}

for gbk in gbks:
if gbk.region is None:
for record in all_records:
if record is None:
continue

if gbk.region.product is None:
if record.product is None:
continue

# product hybrids of AS4 and under dealt with here and in legacy_output generate_run_data_js
product = ".".join(gbk.region.product.split("-"))
product = ".".join(record.product.split("-"))

region_class = legacy_get_class(product)

class_idx[region_class].append(gbk.region)
class_idx[region_class].append(record)

for class_name, regions in class_idx.items():
for class_name, records in class_idx.items():
bin = RecordPairGenerator(class_name)
bin.add_records(regions)
bin.add_records(records)
yield bin


Expand Down
2 changes: 1 addition & 1 deletion big_scape/comparison/comparable_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def cds_range_contains_biosynthetic(
if end_inclusive:
stop += 1

for cds in record.get_cds_with_domains(True, reverse=reverse)[cds_start:stop]:
for cds in record.get_cds_with_domains(reverse=reverse)[cds_start:stop]:
if cds.gene_kind is None:
continue

Expand Down
4 changes: 2 additions & 2 deletions big_scape/comparison/legacy_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def reset_expansion(
return

if a_stop is None:
a_stop = len(a_gbk.genes) + 1
a_stop = len(comparable_region.pair.region_a.get_cds()) + 1

if b_stop is None:
b_stop = len(b_gbk.genes) + 1
b_stop = len(comparable_region.pair.region_b.get_cds()) + 1

comparable_region.a_start = a_start
comparable_region.b_start = b_start
Expand Down
8 changes: 7 additions & 1 deletion big_scape/comparison/legacy_workflow_alt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def generate_edges(
continue

new_task = executor.submit(
calculate_scores_pair, (batch, alignment_mode, pair_generator.label)
calculate_scores_pair,
(batch, alignment_mode, pair_generator.weights),
)
running_tasks[new_task] = batch

Expand Down Expand Up @@ -170,6 +171,7 @@ def generate_edges(
jaccard,
adjacency,
dss,
pair_generator.weights,
)

done_pairs += len(results)
Expand Down Expand Up @@ -248,6 +250,10 @@ def calculate_scores_pair(
results = []

for pair in pairs:
if pair.region_a.parent_gbk == pair.region_b.parent_gbk:
results.append((0.0, 1.0, 1.0, 1.0))
continue

jaccard = calc_jaccard_pair(pair)

if jaccard == 0.0:
Expand Down
Loading

0 comments on commit d88c864

Please sign in to comment.