diff --git a/README.md b/README.md index fb444c1..0e75733 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,22 @@ concept ID. All clinical events in OMOP, such as conditions, drug exposures, pro represented as concepts. You can get patient counts and prevalence associated with each concept by accessing the method `get_concept_stats()` with a code snippet example shown below. ```angular2html - cohort_concepts = baseline_cohort_data.get_concept_stats(concept_type='condition_occurrence') + cohort_concepts, cohort_concept_hierarchy = baseline_cohort_data.get_concept_stats(concept_type='condition_occurrence') print(pd.DataFrame(cohort_concepts["condition_occurrence"])) + print(f"returned cohort_concept_hierarchy object converted to dict: {cohort_concept_hierarchy.to_dict()}") ``` + The returned cohort_concept_hierarchy object stores concept hierarchical relationships with concept nodes indexed +to allow quick information retrieval of a concept node and provides hierarchy traversal methods for concept hierarchy +navigation. For more details, refer to the corresponding tutorial notebook [BiasAnalyzerCohortConceptTutorial.ipynb](https://github.com/VACLab/BiasAnalyzer/blob/main/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb). +- There is also an API method `get_cohorts_concept_stats(list_of_cohort_ids, concept_type='condition_occurrence', filter_count=0, vocab=None)` +that enables users to explore union of concept prevalences over multiple cohorts to facilitate potential cohort +selection bias exploration. An example code snippet is shown below to illustrate how to use this method. 
+ ```angular2html + cohort_list = [baseline_cohort_data.cohort_id, study_cohort_data.cohort_id] + aggregated_cohort_metrics_dict = bias.get_cohorts_concept_stats(cohort_list) + print('Aggregated concept prevalence metrics over the baseline and study cohorts are:') + print(aggregated_cohort_metrics_dict) + ``` - There is also an API method that enables users to compare distributions of two cohorts by calling `bias.compare_cohorts(cohort1_id, cohort2_id)` where cohort1_id and cohort2_id are integers and can be obtained from metadata of a cohort object. Currently, only hellinger distances between distributions of two cohorts are computed. diff --git a/biasanalyzer/api.py b/biasanalyzer/api.py index d09163f..df009ac 100644 --- a/biasanalyzer/api.py +++ b/biasanalyzer/api.py @@ -1,5 +1,6 @@ import time from pydantic import ValidationError +from typing import List from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase from biasanalyzer.cohort import CohortAction from biasanalyzer.config import load_config @@ -158,6 +159,31 @@ def create_cohort(self, cohort_name: str, cohort_desc: str, query_or_yaml_file: return None + def get_cohorts_concept_stats(self, cohorts: List[int], + concept_type: str='condition_occurrence', + filter_count: int=0, + vocab=None): + """ + compute concept statistics such as concept prevalence in a union of multiple cohorts + :param cohorts: list of cohort ids + :param concept_type: concept type to consider with default "condition_occurrence" + :param filter_count: filtering out those concepts with less than this count. Default is 0 meaning no filtering + :param vocab: vocabulary to consider with default None meaning using the default vocabulary corresponding to + the domain instead as defined in DOMAIN_MAPPING variable in models.py + :return: ConceptHierarchy object + """ + if not cohorts: + notify_users('The input cohorts list is empty. 
At least one cohort id must be provided.') + return None + c_action = self._set_cohort_action() + if c_action: + return c_action.get_cohorts_concept_stats(cohorts, concept_type=concept_type, filter_count=filter_count, + vocab=vocab) + else: + notify_users('failed to get concept prevalence stats for the union of cohorts') + return None + + def compare_cohorts(self, cohort_id1, cohort_id2): c_action = self._set_cohort_action() if c_action: diff --git a/biasanalyzer/cohort.py b/biasanalyzer/cohort.py index 83aa2e8..d6b9698 100644 --- a/biasanalyzer/cohort.py +++ b/biasanalyzer/cohort.py @@ -1,14 +1,17 @@ from sqlalchemy.exc import SQLAlchemyError +from functools import reduce import duckdb import pandas as pd from datetime import datetime from tqdm.auto import tqdm from pydantic import ValidationError +from typing import List from biasanalyzer.models import CohortDefinition from biasanalyzer.config import load_cohort_creation_config from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase from biasanalyzer.utils import hellinger_distance, clean_string, notify_users from biasanalyzer.cohort_query_builder import CohortQueryBuilder +from biasanalyzer.concept import ConceptHierarchy class CohortData: @@ -51,7 +54,7 @@ def get_distributions(self, variable): return self.bias_db.get_cohort_distributions(self.cohort_id, variable) def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0, - vocab=None, include_hierarchy=False): + vocab=None, print_concept_hierarchy=False): """ Get cohort concept statistics such as concept prevalence """ @@ -60,8 +63,9 @@ def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0, concept_type=concept_type, filter_count=filter_count, vocab=vocab, - include_hierarchy=include_hierarchy) - return cohort_stats + print_concept_hierarchy=print_concept_hierarchy) + return (cohort_stats, + ConceptHierarchy.build_concept_hierarchy_from_results(self.cohort_id, cohort_stats[concept_type])) def 
__del__(self): @@ -148,6 +152,19 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file: omop_session.close() return None + def get_cohorts_concept_stats(self, cohorts: List[int], + concept_type: str = 'condition_occurrence', + filter_count: int = 0, + vocab=None): + cohort_concept_stats = [self.bias_db.get_cohort_concept_stats(c, self._query_builder, + concept_type=concept_type, + filter_count=filter_count, + vocab=vocab) + for c in cohorts] + hierarchies = [ConceptHierarchy.build_concept_hierarchy_from_results(c, c_stats.get(concept_type, [])) + for c, c_stats in zip(cohorts, cohort_concept_stats)] + return reduce(lambda h1, h2: h1.union(h2), hierarchies).to_dict() + def compare_cohorts(self, cohort_id_1: int, cohort_id_2: int): """ Compare the distributions of two cohorts in BiasDatabase. diff --git a/biasanalyzer/cohort_query_builder.py b/biasanalyzer/cohort_query_builder.py index 79fda66..c0a4e80 100644 --- a/biasanalyzer/cohort_query_builder.py +++ b/biasanalyzer/cohort_query_builder.py @@ -19,7 +19,7 @@ def __init__(self, cohort_creation=True): except ModuleNotFoundError: # pragma: no cover template_path = os.path.join(os.path.dirname(__file__), "sql_templates") - print(f'template_path: {template_path}, cohort_creation: {cohort_creation}') + print(f'template_path: {template_path}') self.env = Environment(loader=FileSystemLoader(template_path), extensions=['jinja2.ext.do']) if cohort_creation: self.env.globals.update( @@ -71,15 +71,13 @@ def build_query_cohort_creation(self, cohort_config: dict) -> str: temporal_events=temporal_events ) - def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_count: int, vocab: str, - include_hierarchy: bool) -> str: + def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_count: int, vocab: str) -> str: """ Build a SQL query for concept prevalence statistics for a given domain and cohort. 
:param concept_type: Domain from DOMAIN_MAPPING (e.g., 'condition_occurrence'). :param cid: Cohort definition ID. :param filter_count: Minimum count threshold for concepts with 0 meaning no filtering :param vocab: Vocabulary ID. Defaults to domain-specific vocabulary as defined in DOMAIN_MAPPING if set to None - :param include_hierarchy: Include concept hierarchy in results or not :return: The rendered SQL query :raises ValueError if concept_type is not invalid """ @@ -100,8 +98,7 @@ def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_cou start_date_column=DOMAIN_MAPPING[concept_type]["start_date"], cid=cid, filter_count=filter_count, - vocab=effective_vocab, - include_hierarchy=include_hierarchy + vocab=effective_vocab ) @staticmethod diff --git a/biasanalyzer/concept.py b/biasanalyzer/concept.py new file mode 100644 index 0000000..90a179f --- /dev/null +++ b/biasanalyzer/concept.py @@ -0,0 +1,205 @@ +import networkx as nx +from typing import List, Optional, Union +from collections import deque + + +class ConceptNode: + def __init__(self, concept_id: int, ch: "ConceptHierarchy"): + self.id = concept_id + self._ch = ch # reference back to ConceptHierarchy + + @property + def name(self) -> str: + return self._ch.graph.nodes[self.id]["concept_name"] + + @property + def code(self) -> str: + return self._ch.graph.nodes[self.id]["concept_code"] + + @property + def parents(self) -> List["ConceptNode"]: + return [ConceptNode(p, self._ch) for p in self._ch.graph.predecessors(self.id)] + + @property + def children(self) -> List["ConceptNode"]: + return [ConceptNode(c, self._ch) for c in self._ch.graph.successors(self.id)] + + def get_metrics(self, cohort_id: Union[int, str]) -> dict: + metrics = self._ch.graph.nodes[self.id].get("metrics", {}) + return metrics.get(str(cohort_id), {}) + + def get_union_metrics(self) -> dict: + # simple aggregation example + metrics = self._ch.graph.nodes[self.id].get("metrics", {}) + counts = [m["count"] for m in 
metrics.values()] + prevalences = [m["prevalence"] for m in metrics.values()] + return { + "count": sum(counts), + "prevalence": sum(prevalences) / len(prevalences) if prevalences else 0.0, + } + + def to_dict(self, include_children: bool = True) -> dict: + """ + Serialize this node into a dict. Optionally include nested children. + """ + data = { + "concept_id": self.id, + "concept_name": self.name, + "concept_code": self.code, + "metrics": { + "union": self.get_union_metrics(), + "cohorts": self._ch.graph.nodes[self.id].get("metrics", {}), + }, + "parent_ids": list(self._ch.graph.predecessors(self.id)), + } + if include_children: + data["children"] = [c.to_dict(include_children=True) for c in self.children] + return data + + +class ConceptHierarchy: + _graph_cache = {} + + def __init__(self, input_g: nx.DiGraph, identifier: str): + self.graph = input_g + self.identifier = ConceptHierarchy._normalize_identifier(identifier) + + @staticmethod + def _normalize_identifier(identifier: str) -> str: + # Split on "+" to allow union identifiers + if "+" not in identifier: + return identifier.strip() + else: + parts = identifier.split("+") + parts = [p.strip() for p in parts if p and p.strip() != ""] + parts = sorted(set(parts)) # deduplicate + sort + return "+".join(parts) + + @classmethod + def build_concept_hierarchy_from_results(cls, cohort_id: int, results: List[dict]): + """ + build concept hierarchy tree managed by networkx from list of dicts returned from the concept prevalence SQL + with cache management + :param results: list of dicts from prevalence SQL + :param cohort_id: cohort id to get concept hierarchy for + :return: ConceptHierarchy object + """ + identifer = str(cohort_id) + if identifer in cls._graph_cache: + return cls._graph_cache[identifer] + + # node metrics + metrics_by_concept = {} + node_metadata = {} + + for row in results: + cid = row["descendant_concept_id"] + if cid not in node_metadata: + node_metadata[cid] = { + "concept_name": 
row["concept_name"], + "concept_code": row["concept_code"], + } + metrics_by_concept[cid] = { + "count": row["count_in_cohort"], + "prevalence": row["prevalence"], + } + + graph = nx.DiGraph() + # add nodes with metadata + metrics + for cid, meta in node_metadata.items(): + graph.add_node(cid, **meta, metrics={identifer: metrics_by_concept[cid]}) + + # add parent-child edges + for row in results: + anc = row["ancestor_concept_id"] + desc = row["descendant_concept_id"] + if anc and desc and anc != desc: + graph.add_edge(anc, desc) + + hierarchy = ConceptHierarchy(graph, identifer) + cls._graph_cache[identifer] = hierarchy + return hierarchy + + @classmethod + def clear_cache(cls): + cls._graph_cache.clear() + + def get_node(self, concept_id: int, serialization: bool = False): + concept_node = ConceptNode(concept_id, self) if concept_id in self.graph.nodes else None + return concept_node.to_dict(include_children=False) if serialization else concept_node + + def get_root_nodes(self, serialization: bool = False) -> List: + roots = [n for n in self.graph.nodes if self.graph.in_degree(n) == 0] + root_nodes = [ConceptNode(r, self) for r in roots] + if serialization: + return [rn.to_dict(include_children=False) for rn in root_nodes] + else: + return root_nodes + + def get_leaf_nodes(self, serialization: bool = False) -> List: + leaves = [n for n in self.graph.nodes if self.graph.out_degree(n) == 0] + leave_nodes = [ConceptNode(l, self) for l in leaves] + if serialization: + return [ln.to_dict(include_children=False) for ln in leave_nodes] + else: + return leave_nodes + + def iter_nodes(self, root_id: int, order: str = "bfs", + serialization: bool = False): + """Iterate nodes in BFS or DFS order from a given root.""" + if root_id not in self.graph.nodes: + raise ValueError(f"Root node {root_id} not found in graph.") + + if order == "bfs": + queue = deque([root_id]) + while queue: + node = queue.popleft() + if serialization: + yield ConceptNode(node, 
self).to_dict(include_children=False) + else: + yield ConceptNode(node, self) + queue.extend(self.graph.successors(node)) + elif order == "dfs": + stack = [root_id] + while stack: + node = stack.pop() + if serialization: + yield ConceptNode(node, self).to_dict(include_children=False) + else: + yield ConceptNode(node, self) + stack.extend(self.graph.successors(node)) + else: + raise ValueError("order must be 'bfs' or 'dfs'") + + def union(self, other: "ConceptHierarchy") -> "ConceptHierarchy": + new_ident = ConceptHierarchy._normalize_identifier( + f"{self.identifier}+{other.identifier}" + ) + if new_ident in ConceptHierarchy._graph_cache: + return ConceptHierarchy._graph_cache[new_ident] + + """Merge two hierarchies into a new one, aggregating metrics.""" + composed_graph = nx.compose(self.graph, other.graph) + # merge node metrics + for n in composed_graph.nodes: + metrics_self = self.graph.nodes.get(n, {}).get("metrics", {}) + metrics_other = other.graph.nodes.get(n, {}).get("metrics", {}) + composed_graph.nodes[n]["metrics"] = {**metrics_self, **metrics_other} + + new_hierarchy = ConceptHierarchy(composed_graph, new_ident) + ConceptHierarchy._graph_cache[new_ident] = new_hierarchy + return new_hierarchy + + def to_dict(self, root_id: Optional[int] = None) -> dict: + """ + Convert the concept hierarchy or a sub-hierarchy to a nested dict structure + :param root_id: if provided, return the sub-hierarchy rooted at this concept_id; + if None, return the whole hierarchy with all roots. 
+ :return: nested dict representation of the hierarchy or sub-hierarchy + """ + if root_id is not None: + if root_id not in self.graph: + raise ValueError(f"Input concept id {root_id} not found in the concept hierarchy graph") + return {"hierarchy": [ConceptNode(root_id, self).to_dict()]} + + return {"hierarchy": [r.to_dict() for r in self.get_root_nodes()]} diff --git a/biasanalyzer/database.py b/biasanalyzer/database.py index 03efa6e..e6e30e2 100644 --- a/biasanalyzer/database.py +++ b/biasanalyzer/database.py @@ -229,7 +229,7 @@ def get_cohort_distributions(self, cohort_definition_id: int, variable: str): def get_cohort_concept_stats(self, cohort_definition_id: int, qry_builder, concept_type='condition_occurrence', filter_count=0, vocab=None, - include_hierarchy=False): + print_concept_hierarchy=False): """ Get concept statistics for a cohort from the cohort table. """ @@ -243,34 +243,36 @@ def get_cohort_concept_stats(self, cohort_definition_id: int, qry_builder, valid_vocabs = self._execute_query("SELECT distinct vocabulary_id FROM concept") valid_vocab_ids = [row['vocabulary_id'] for row in valid_vocabs] if vocab not in valid_vocab_ids: - notify_users(f"input {vocab} is not a valid vocabulary in OMOP. " - f"Supported vocabulary ids are: {valid_vocab_ids}", - level='error') - return concept_stats + err_msg = (f"input {vocab} is not a valid vocabulary in OMOP. 
" + f"Supported vocabulary ids are: {valid_vocab_ids}") + notify_users(err_msg, level='error') + raise ValueError(err_msg) query = qry_builder.build_concept_prevalence_query(concept_type, cohort_definition_id, - filter_count, vocab, include_hierarchy) + filter_count, vocab) concept_stats[concept_type] = self._execute_query(query) cs_df = pd.DataFrame(concept_stats[concept_type]) # Combine concept_name and prevalence into a "details" column cs_df["details"] = cs_df.apply( lambda row: f"{row['concept_name']} (Code: {row['concept_code']}, " f"Count: {row['count_in_cohort']}, Prevalence: {row['prevalence']:.3%})", axis=1) - filtered_cs_df = cs_df[cs_df['ancestor_concept_id'] != cs_df['descendant_concept_id']] - roots = find_roots(filtered_cs_df) - hierarchy = build_concept_hierarchy(filtered_cs_df) - notify_users(f'cohort concept hierarchy for {concept_type} with root concept ids {roots}:') - for root in roots: - root_detail = cs_df[(cs_df['ancestor_concept_id'] == root) - & (cs_df['descendant_concept_id'] == root)]['details'].iloc[0] - print_hierarchy(hierarchy, parent=root, level=0, parent_details=root_detail) + + if print_concept_hierarchy: + filtered_cs_df = cs_df[cs_df['ancestor_concept_id'] != cs_df['descendant_concept_id']] + roots = find_roots(filtered_cs_df) + hierarchy = build_concept_hierarchy(filtered_cs_df) + notify_users(f'cohort concept hierarchy for {concept_type} with root concept ids {roots}:') + for root in roots: + root_detail = cs_df[(cs_df['ancestor_concept_id'] == root) + & (cs_df['descendant_concept_id'] == root)]['details'].iloc[0] + print_hierarchy(hierarchy, parent=root, level=0, parent_details=root_detail) return concept_stats else: - notify_users("Cannot connect to the OMOP database to query concept table") - return concept_stats + err_msg = "Cannot connect to the OMOP database to query concept table" + raise ValueError(err_msg) except Exception as e: - notify_users(f"Error computing cohort concept stats: {e}", level='error') - return 
concept_stats + err_msg = f"Error computing cohort concept stats: {e}" + raise ValueError(err_msg) def close(self): if self.conn: diff --git a/biasanalyzer/module_test.py b/biasanalyzer/module_test.py index 1c4de4e..9fad973 100644 --- a/biasanalyzer/module_test.py +++ b/biasanalyzer/module_test.py @@ -1,3 +1,4 @@ +import pprint from biasanalyzer.api import BIAS import time import os @@ -8,7 +9,10 @@ def cohort_creation_template_test(bias_obj): cohort_data = bias_obj.create_cohort('COVID-19 patients', 'COVID-19 patients', os.path.join(os.path.dirname(__file__), '..', 'tests', 'assets', 'cohort_creation', - 'test_cohort_creation_condition_occurrence_config.yaml'), + # 'extras', + # 'covid_example3', + # 'cohort_creation_config_baseline_example3.yaml'), + 'test_cohort_creation_condition_occurrence_config_study.yaml'), 'system') if cohort_data: md = cohort_data.metadata @@ -33,7 +37,7 @@ def condition_cohort_test(bias_obj): 'WHERE c.condition_concept_id = 37311061 ' 'AND p.gender_concept_id = 8532 AND p.year_of_birth > 2000') cohort_data = bias_obj.create_cohort('COVID-19 patients', 'COVID-19 patients', - baseline_cohort_query, 'system') + baseline_cohort_query, 'system') if cohort_data: md = cohort_data.metadata print(f'cohort_definition: {md}') @@ -45,12 +49,29 @@ def condition_cohort_test(bias_obj): print(f'the cohort ethnicity stats: {cohort_data.get_stats("ethnicity")}') print(f'the cohort age distributions: {cohort_data.get_distributions("age")}') t1 = time.time() - cohort_concepts = cohort_data.get_concept_stats(concept_type='condition_occurrence', filter_count=5000) - # print('the cohort concept condition occurrence stats:') - # print(pd.DataFrame(cohort_concepts["condition_occurrence"])) - cohort_de_concepts = cohort_data.get_concept_stats(concept_type='drug_exposure', filter_count=500) - # print(f'the cohort concept drug exposure stats: \n{pd.DataFrame(cohort_de_concepts["drug_exposure"])}') - print(f'the time taken to get cohort concept stats is {time.time() 
- t1}s') + _, cohort_concept_hierarchy = cohort_data.get_concept_stats(concept_type='condition_occurrence', + filter_count=5000) + concept_node = cohort_concept_hierarchy.get_node(concept_id=37311061) + print(f'concept_node 37311061 metric: {concept_node.get_metrics(md["id"])}') + + # Print the root node + root_nodes = cohort_concept_hierarchy.get_root_nodes() + root = [(n.name, n.code, n.get_metrics(md["id"])) for n in root_nodes] + leave_nodes = cohort_concept_hierarchy.get_leaf_nodes() + leaves = [(n.name, n.code, n.get_metrics(md["id"])) for n in leave_nodes] + print(f"Root: {root}", flush=True) + print(f"Leaves: {leaves}", flush=True) + for node in cohort_concept_hierarchy.iter_nodes(root_nodes[0].id, serialization=True): + print(node) + + hier_dict = cohort_concept_hierarchy.to_dict() + pprint.pprint(hier_dict, indent=2) + + + _, cohort_de_concept_hierarchy = cohort_data.get_concept_stats(concept_type='drug_exposure', + filter_count=500) + de_hier_dict = cohort_de_concept_hierarchy.to_dict() + pprint.pprint(de_hier_dict, indent=2) compare_stats = bias_obj.compare_cohorts(cohort_data.metadata['id'], cohort_data.metadata['id']) print(f'compare_stats: {compare_stats}') return diff --git a/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 b/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 index ff3be54..b832cd8 100644 --- a/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 +++ b/biasanalyzer/sql_templates/cohort_concept_prevalence_query.sql.j2 @@ -56,6 +56,5 @@ JOIN concept c ON ac.concept_id = c.concept_id WHERE ac.count_in_cohort > {{ filter_count }} - AND ({{ include_hierarchy }} = True OR ch.ancestor_concept_id = ch.descendant_concept_id) ORDER BY prevalence DESC; \ No newline at end of file diff --git a/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb b/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb index 2b284eb..0ee4282 100644 --- a/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb +++ 
b/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb @@ -96,12 +96,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b99aedde4936451e9c0b8e75f2bcc620", + "model_id": "af2f7ec9bd3544b486882c38ba0aa738", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Cohort creation: 0%| | 0/3 [00:00=3.8" +files = [ + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, +] + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "numpy" version = "1.24.4" @@ -1370,4 +1388,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.8.10,<3.12" -content-hash = "f06e2e5bcc2170425b33b4887b1d42d216812771cbb34c01b3e79a6c962f0524" +content-hash = "7ef5e6a3bec2bcef8429f74816408f554f0d021da19349481077a67065489833" diff --git a/pyproject.toml b/pyproject.toml index 5f63a93..7600f55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ ipytree = "^0.2.2" ipywidgets = "^8.1.5" jinja2 = "3.1.6" tqdm = "4.67.1" +networkx = "3.1" [tool.poetry.dev-dependencies] pytest = "^8.3.3" diff --git a/tests/query_based/test_hierarchical_prevalence.py b/tests/query_based/test_hierarchical_prevalence.py index e0ad55c..c0a7022 100644 --- a/tests/query_based/test_hierarchical_prevalence.py +++ b/tests/query_based/test_hierarchical_prevalence.py @@ -1,4 +1,7 @@ -import logging +import pytest +from functools import reduce +from 
biasanalyzer.concept import ConceptHierarchy, ConceptNode + def test_cohort_concept_hierarchical_prevalence(test_db, caplog): bias = test_db @@ -18,49 +21,199 @@ def test_cohort_concept_hierarchical_prevalence(test_db, caplog): # Test cohort object and methods assert cohort is not None, "Cohort creation failed" # test concept_type must be one of the supported OMOP domain name - caplog.clear() - with caplog.at_level(logging.ERROR): - concept_stats = cohort.get_concept_stats(concept_type='dummy_invalid') - assert 'Invalid concept_type' in caplog.text - assert concept_stats == {} + with pytest.raises(ValueError): + cohort.get_concept_stats(concept_type='dummy_invalid') # test vocab must be None to use the default vocab or one of the supported OMOP vocabulary id - caplog.clear() - with caplog.at_level(logging.ERROR): - concept_stats = cohort.get_concept_stats(vocab='dummy_invalid_vocab') - assert 'is not a valid vocabulary' in caplog.text - assert concept_stats == {} + with pytest.raises(ValueError): + cohort.get_concept_stats(vocab='dummy_invalid_vocab') # test the cohort does not have procedure_occurrence related concepts - concept_stats = cohort.get_concept_stats(concept_type='procedure_occurrence') - assert concept_stats == {'procedure_occurrence': []} - - include_hierarchy_flags = [True, False] - for flag in include_hierarchy_flags: - concept_stats = cohort.get_concept_stats(vocab='ICD10CM', include_hierarchy=flag) - assert concept_stats is not None, "Failed to fetch concept stats" - assert len(concept_stats) > 0, "No concept stats returned" - # check returned data with different include_hierarchy flag - if flag is True: - assert not all(s['ancestor_concept_id'] == s['descendant_concept_id'] - for s in concept_stats['condition_occurrence']), \ - "Some ancestor_concept_id and descendant_concept_id should differ when include_hierarchy is True" - else: - assert all(s['ancestor_concept_id'] == s['descendant_concept_id'] for s in - 
concept_stats['condition_occurrence']), \ - "ancestor_concept_id and descendant_concept_id must be equal when include_hierarchy is False" - # Check concept prevalence for overlaps - diabetes_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 1 and c['descendant_concept_id'] == 1), None) - assert diabetes_prevalence is not None, "Parent diabetes concept prevalence missing" - type1_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 2 and c['descendant_concept_id'] == 2), None) - assert type1_prevalence is not None, "Child type 1 diabetes concept prevalence missing" - type2_prevalence = next((c for c in concept_stats['condition_occurrence'] - if c['ancestor_concept_id'] == 3 and c['descendant_concept_id'] == 3), None) - assert type2_prevalence is not None, "Child type 2 diabetes concept prevalence missing" - print(f"type1_prevalence: {type1_prevalence['prevalence']}, type2_prevalence: {type2_prevalence['prevalence']}, " - f"diabetes_prevalence: {diabetes_prevalence['prevalence']}") - assert diabetes_prevalence['prevalence'] < type1_prevalence['prevalence'] + type2_prevalence['prevalence'], \ - ("Parent diabetes concept prevalence does not reflect overlap between type 1 and type 2 diabetes " - "children concept prevalence") + with pytest.raises(ValueError): + cohort.get_concept_stats(concept_type='procedure_occurrence') + + concept_stats, _ = cohort.get_concept_stats(vocab='ICD10CM', print_concept_hierarchy=True) + assert concept_stats is not None, "Failed to fetch concept stats" + assert len(concept_stats) > 0, "No concept stats returned" + # check returned data + assert not all(s['ancestor_concept_id'] == s['descendant_concept_id'] + for s in concept_stats['condition_occurrence']), \ + "Some ancestor_concept_id and descendant_concept_id should differ" + # Check concept prevalence for overlaps + diabetes_prevalence = next((c for c in 
concept_stats['condition_occurrence'] + if c['ancestor_concept_id'] == 1 and c['descendant_concept_id'] == 1), None) + assert diabetes_prevalence is not None, "Parent diabetes concept prevalence missing" + type1_prevalence = next((c for c in concept_stats['condition_occurrence'] + if c['ancestor_concept_id'] == 2 and c['descendant_concept_id'] == 2), None) + assert type1_prevalence is not None, "Child type 1 diabetes concept prevalence missing" + type2_prevalence = next((c for c in concept_stats['condition_occurrence'] + if c['ancestor_concept_id'] == 3 and c['descendant_concept_id'] == 3), None) + assert type2_prevalence is not None, "Child type 2 diabetes concept prevalence missing" + print(f"type1_prevalence: {type1_prevalence['prevalence']}, type2_prevalence: {type2_prevalence['prevalence']}, " + f"diabetes_prevalence: {diabetes_prevalence['prevalence']}") + assert diabetes_prevalence['prevalence'] < type1_prevalence['prevalence'] + type2_prevalence['prevalence'], \ + ("Parent diabetes concept prevalence does not reflect overlap between type 1 and type 2 diabetes " + "children concept prevalence") + +def test_identifier_normalization_and_cache(): + ConceptHierarchy.clear_cache() + # identifiers are normalized + assert ConceptHierarchy._normalize_identifier("2+1") == "1+2" + assert ConceptHierarchy._normalize_identifier("1+2+2") == "1+2" + + # fake minimal results to build hierarchy + results = [ + {"ancestor_concept_id": 1, "descendant_concept_id": 1, + "concept_name": "Diabetes", "concept_code": "DIA", + "count_in_cohort": 5, "prevalence": 0.5} + ] + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) + assert h1 is h2 # cache reuse + assert h1.identifier == "1" + +def test_union_and_cache_behavior(): + ConceptHierarchy.clear_cache() + results1 = [ + {"ancestor_concept_id": 1, "descendant_concept_id": 1, + "concept_name": "Diabetes", "concept_code": "DIA", + 
"count_in_cohort": 5, "prevalence": 0.5} + ] + results2 = [ + {"ancestor_concept_id": 2, "descendant_concept_id": 2, + "concept_name": "Hypertension", "concept_code": "HYP", + "count_in_cohort": 3, "prevalence": 0.3} + ] + + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results1) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(2, results2) + assert "1" in ConceptHierarchy._graph_cache + assert "2" in ConceptHierarchy._graph_cache + h12 = h1.union(h2) + h21 = h2.union(h1) + assert h12.identifier == "1+2" + assert h21.identifier == "1+2" + assert h12 is h21 + +def test_traversal_and_serialization(): + ConceptHierarchy.clear_cache() + results = [ + {"ancestor_concept_id": 1, "descendant_concept_id": 1, + "concept_name": "Root", "concept_code": "R", + "count_in_cohort": 5, "prevalence": 0.5}, + {"ancestor_concept_id": 1, "descendant_concept_id": 2, + "concept_name": "Child", "concept_code": "C", + "count_in_cohort": 2, "prevalence": 0.2} + ] + h = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) + + # roots + roots = h.get_root_nodes() + assert len(roots) == 1 + assert roots[0].name == "Root" + assert roots[0].get_metrics(1) == {"count": 5, "prevalence": 0.5} + children = roots[0].children + ch_names = [ch.name for ch in children] + assert ch_names == ["Child"] + # leaves + assert h.get_leaf_nodes(serialization=True) == [ + { + 'concept_id': 2, + 'concept_name': 'Child', + 'concept_code': 'C', + 'metrics': { + 'union': { + 'count': 2, + 'prevalence': 0.2 + }, + 'cohorts': { + '1': { + 'count': 2, 'prevalence': 0.2 + } + } + }, + 'parent_ids': [1] + } + ] + leaves = h.get_leaf_nodes() + assert len(leaves) == 1 + assert leaves[0].name == "Child" + parents = leaves[0].parents + par_names = [par.name for par in parents] + assert par_names == ["Root"] + + assert h.get_node(1, serialization=True) == { + "concept_id": 1, + "concept_name": "Root", + "concept_code": "R", + "metrics": { + "union": { + "count": 5, + "prevalence": 0.5 
+ }, + "cohorts": { + "1": { + "count": 5, + "prevalence": 0.5 + } + } + }, + "parent_ids": [] + } + + # graph traversal + with pytest.raises(ValueError): + # make sure to use list() to force generator execution + # test invalid root_id raises ValueError + list(h.iter_nodes(111, order="bfs")) + + with pytest.raises(ValueError): + # make sure to use list() to force generator execution + # test invalid order raises ValueError + list(h.iter_nodes(1, order="dummy")) + + bfs_nodes = [n.id for n in h.iter_nodes(1, order="bfs")] + assert bfs_nodes == [1, 2] + + # DFS traversal + dfs_nodes = [n.id for n in h.iter_nodes(1, order="dfs")] + assert set(dfs_nodes) == {1, 2} + + dfs_nodes = [n['concept_id'] for n in h.iter_nodes(1, order="dfs", serialization=True)] + assert set(dfs_nodes) == {1, 2} + + # serialization + serialized_root = h.get_root_nodes(serialization=True)[0] + assert serialized_root["concept_name"] == "Root" + assert "metrics" in serialized_root + + serialized_iter = list(h.iter_nodes(1, serialization=True)) + assert all(isinstance(n, dict) for n in serialized_iter) + assert serialized_iter[0]["concept_id"] == 1 + + with pytest.raises(ValueError): + h.to_dict(111) + + h_dict = h.to_dict(1) + assert h_dict == {'hierarchy': [{ + 'concept_id': 1, 'concept_name': 'Root', 'concept_code': 'R', + 'metrics': {'union': {'count': 5, 'prevalence': 0.5}, + 'cohorts': {'1': {'count': 5, 'prevalence': 0.5}}}, + 'parent_ids': [], + 'children': [{'concept_id': 2, 'concept_name': 'Child', 'concept_code': 'C', + 'metrics': {'union': {'count': 2, 'prevalence': 0.2}, + 'cohorts': {'1': {'count': 2, 'prevalence': 0.2}}}, + 'parent_ids': [1], 'children': []}]} + ]} + + h_dict = h.to_dict() + assert h_dict == {'hierarchy': [{ + 'concept_id': 1, 'concept_name': 'Root', 'concept_code': 'R', + 'metrics': {'union': {'count': 5, 'prevalence': 0.5}, + 'cohorts': {'1': {'count': 5, 'prevalence': 0.5}}}, + 'parent_ids': [], + 'children': [{'concept_id': 2, 'concept_name': 'Child', 
'concept_code': 'C', + 'metrics': {'union': {'count': 2, 'prevalence': 0.2}, + 'cohorts': {'1': {'count': 2, 'prevalence': 0.2}}}, + 'parent_ids': [1], 'children': []}]} + ]} diff --git a/tests/test_biasanalyzer_api.py b/tests/test_biasanalyzer_api.py index 0ed8ae3..6efcf5f 100644 --- a/tests/test_biasanalyzer_api.py +++ b/tests/test_biasanalyzer_api.py @@ -3,7 +3,7 @@ import logging import pytest from ipytree import Node - +from biasanalyzer.concept import ConceptHierarchy from biasanalyzer import __version__ @@ -106,6 +106,78 @@ def test_compare_cohort_with_no_action(caplog, fresh_bias_obj): fresh_bias_obj.compare_cohorts(1, 2) assert 'failed to create a valid cohort action object' in caplog.text +def test_cohorts_concept_stats_empty_input_cohorts(caplog, fresh_bias_obj): + caplog.clear() + with caplog.at_level(logging.INFO): + fresh_bias_obj.get_cohorts_concept_stats([]) + assert 'The input cohorts list is empty. At least one cohort id must be provided.' in caplog.text + +def test_cohorts_concept_stats_no_cohort_action(caplog, fresh_bias_obj): + caplog.clear() + with caplog.at_level(logging.INFO): + fresh_bias_obj.get_cohorts_concept_stats([1]) + assert 'failed to get concept prevalence stats for the union of cohorts' in caplog.text + +def test_cohorts_union_concept_stats(test_db): + ConceptHierarchy.clear_cache() + # Show what cohorts exist in the test DB and print cohorts and stats so we know what raw data looks like + cohorts_df = test_db.bias_db.conn.execute(""" + SELECT cohort_definition_id, COUNT(*) as n_subjects + FROM cohort + WHERE cohort_definition_id = 1 or cohort_definition_id = 2 + GROUP BY cohort_definition_id + ORDER BY cohort_definition_id + """).fetchdf() + print("Cohorts in DB:\n", cohorts_df.to_string(index=False), flush=True) + + # Show concept stats per cohort (aggregated for clarity) + stats_df = test_db.bias_db.conn.execute(""" + SELECT c.cohort_definition_id, + co.condition_concept_id, + COUNT(*) as n + FROM cohort c + JOIN 
condition_occurrence co + ON c.subject_id = co.person_id + WHERE c.cohort_definition_id = 1 or c.cohort_definition_id = 2 + GROUP BY c.cohort_definition_id, co.condition_concept_id + ORDER BY c.cohort_definition_id, co.condition_concept_id + """).fetchdf() + print("Concept stats per cohort:\n", stats_df.to_string(index=False), flush=True) + + union_result = test_db.get_cohorts_concept_stats([1, 2]) + print(f'union_result: {union_result}', flush=True) + union_result['hierarchy'] = sorted(union_result['hierarchy'], key=lambda x: x['concept_id']) + # NOTE: The union_result takes cohort_start_date and cohort_end_date into account + # when joining cohort with condition_occurrence for inclusion/exclusion criteria. + # That means counts may differ from the raw numbers above. For example: + # - Concept 4041664 appears 5 times in cohort 1 raw, but only 4 fall within + # the cohort window → {'1': {'count': 4}} + # - Concept 4041664 appears 5 times in cohort 2 raw, but only 1 falls within + # the window → {'2': {'count': 1}} + # - Concept 5 disappears entirely, because its single occurrence is outside + # the cohort date window. + # This explains why union_result values differ from the raw stats above. 
+ assert union_result == {'hierarchy': [ + {'concept_id': 316139, 'concept_name': 'Heart failure', 'concept_code': '84114007', + 'metrics': {'union': {'count': 4, 'prevalence': 0.45}, + 'cohorts': {'1': {'count': 2, 'prevalence': 0.4}, + '2': {'count': 2, 'prevalence': 0.5}}}, + 'parent_ids': [], 'children': []}, + {'concept_id': 4041664, 'concept_name': 'Difficulty breathing', 'concept_code': '230145002', + 'metrics': { + 'union': {'count': 5, 'prevalence': 0.525}, + 'cohorts': {'1': {'count': 4, 'prevalence': 0.8}, + '2': {'count': 1, 'prevalence': 0.25} + } + }, + 'parent_ids': [], 'children': []}, + {'concept_id': 37311061, 'concept_name': 'COVID-19', 'concept_code': '840539006', + 'metrics': {'union': {'count': 8, 'prevalence': 0.9}, + 'cohorts': {'1': {'count': 4, 'prevalence': 0.8}, + '2': {'count': 4, 'prevalence': 1.0}}}, + 'parent_ids': [], 'children': []}, + ]} + def test_get_domains_and_vocabularies_invalid(caplog, fresh_bias_obj): caplog.clear() with caplog.at_level(logging.INFO): diff --git a/tests/test_database.py b/tests/test_database.py index d2cb1ff..9944965 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -156,11 +156,8 @@ def test_get_cohort_concept_stats_handles_exception(caplog): db = BiasDatabase(":memory:") db.omop_cdm_db_url = 'duckdb' qry_builder = CohortQueryBuilder(cohort_creation=False) - caplog.clear() - with caplog.at_level(logging.ERROR): - result = db.get_cohort_concept_stats(123, qry_builder) - assert 'Error computing cohort concept stats' in caplog.text - assert result == {} + with pytest.raises(ValueError): + db.get_cohort_concept_stats(123, qry_builder) def test_get_cohort_attributes_handles_exception(): BiasDatabase._instance = None @@ -171,6 +168,5 @@ def test_get_cohort_attributes_handles_exception(): assert result_stats is None result = db.get_cohort_distributions(123, 'age') assert result is None - result = db.get_cohort_concept_stats(123, qry_builder) - assert result == {} - + with 
pytest.raises(ValueError): + db.get_cohort_concept_stats(123, qry_builder)