Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,22 @@ concept ID. All clinical events in OMOP, such as conditions, drug exposures, pro
represented as concepts. You can get patient counts and prevalence associated with each concept by accessing
the method `get_concept_stats()` with a code snippet example shown below.
```angular2html
cohort_concepts = baseline_cohort_data.get_concept_stats(concept_type='condition_occurrence')
cohort_concepts, cohort_concept_hierarchy = baseline_cohort_data.get_concept_stats(concept_type='condition_occurrence')
print(pd.DataFrame(cohort_concepts["condition_occurrence"]))
print(f"returned cohort_concept_hierarchy object converted to dict: {cohort_concept_hierarchy.to_dict()}")
```
The returned cohort_concept_hierarchy object stores concept hierarchical relationships with concept nodes indexed
to allow quick information retrieval of a concept node and provides hierarchy traversal methods for concept hierarchy
navigation. For more details, refer to the corresponding tutorial notebook [BiasAnalyzerCohortConceptTutorial.ipynb](https://github.com/VACLab/BiasAnalyzer/blob/main/notebooks/BiasAnalyzerCohortConceptTutorial.ipynb).
- There is also an API method `get_cohorts_concept_stats(list_of_cohort_ids, concept_type='condition_occurrence', filter_count=0, vocab=None)`
that enables users to explore the union of concept prevalences over multiple cohorts to facilitate potential cohort
selection bias exploration. An example code snippet is shown below to illustrate how to use this method.
```angular2html
cohort_list = [baseline_cohort_data.cohort_id, study_cohort_data.cohort_id]
aggregated_cohort_metrics_dict = bias.get_cohorts_concept_stats(cohort_list)
print('Aggregated concept prevalence metrics over the baseline and study cohorts are:')
print(aggregated_cohort_metrics_dict)
```
- There is also an API method that enables users to compare distributions of two cohorts by calling `bias.compare_cohorts(cohort1_id, cohort2_id)`
where cohort1_id and cohort2_id are integers and can be obtained from metadata of a cohort object. Currently,
only hellinger distances between distributions of two cohorts are computed.
Expand Down
26 changes: 26 additions & 0 deletions biasanalyzer/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from pydantic import ValidationError
from typing import List
from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase
from biasanalyzer.cohort import CohortAction
from biasanalyzer.config import load_config
Expand Down Expand Up @@ -158,6 +159,31 @@ def create_cohort(self, cohort_name: str, cohort_desc: str, query_or_yaml_file:
return None


    def get_cohorts_concept_stats(self, cohorts: List[int],
                                  concept_type: str='condition_occurrence',
                                  filter_count: int=0,
                                  vocab=None):
        """
        Compute concept statistics such as concept prevalence in a union of multiple cohorts.
        :param cohorts: list of cohort ids
        :param concept_type: concept type to consider with default "condition_occurrence"
        :param filter_count: filtering out those concepts with less than this count. Default is 0 meaning no filtering
        :param vocab: vocabulary to consider with default None meaning using the default vocabulary corresponding to
            the domain instead as defined in DOMAIN_MAPPING variable in models.py
        :return: dict serialization of the merged ConceptHierarchy over the union of the input
            cohorts (the delegate calls ConceptHierarchy.to_dict()), or None when the input
            list is empty or a cohort action could not be created
        """
        # guard: reduce() in the delegate requires at least one cohort id
        if not cohorts:
            notify_users('The input cohorts list is empty. At least one cohort id must be provided.')
            return None
        c_action = self._set_cohort_action()
        if c_action:
            # delegate the per-cohort queries and hierarchy merge to CohortAction
            return c_action.get_cohorts_concept_stats(cohorts, concept_type=concept_type, filter_count=filter_count,
                                                      vocab=vocab)
        else:
            notify_users('failed to get concept prevalence stats for the union of cohorts')
            return None


def compare_cohorts(self, cohort_id1, cohort_id2):
c_action = self._set_cohort_action()
if c_action:
Expand Down
23 changes: 20 additions & 3 deletions biasanalyzer/cohort.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from sqlalchemy.exc import SQLAlchemyError
from functools import reduce
import duckdb
import pandas as pd
from datetime import datetime
from tqdm.auto import tqdm
from pydantic import ValidationError
from typing import List
from biasanalyzer.models import CohortDefinition
from biasanalyzer.config import load_cohort_creation_config
from biasanalyzer.database import OMOPCDMDatabase, BiasDatabase
from biasanalyzer.utils import hellinger_distance, clean_string, notify_users
from biasanalyzer.cohort_query_builder import CohortQueryBuilder
from biasanalyzer.concept import ConceptHierarchy


class CohortData:
Expand Down Expand Up @@ -51,7 +54,7 @@ def get_distributions(self, variable):
return self.bias_db.get_cohort_distributions(self.cohort_id, variable)

def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0,
vocab=None, include_hierarchy=False):
vocab=None, print_concept_hierarchy=False):
"""
Get cohort concept statistics such as concept prevalence
"""
Expand All @@ -60,8 +63,9 @@ def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0,
concept_type=concept_type,
filter_count=filter_count,
vocab=vocab,
include_hierarchy=include_hierarchy)
return cohort_stats
print_concept_hierarchy=print_concept_hierarchy)
return (cohort_stats,
ConceptHierarchy.build_concept_hierarchy_from_results(self.cohort_id, cohort_stats[concept_type]))


def __del__(self):
Expand Down Expand Up @@ -148,6 +152,19 @@ def create_cohort(self, cohort_name: str, description: str, query_or_yaml_file:
omop_session.close()
return None

    def get_cohorts_concept_stats(self, cohorts: List[int],
                                  concept_type: str = 'condition_occurrence',
                                  filter_count: int = 0,
                                  vocab=None):
        """
        Compute per-cohort concept prevalence stats and merge them into one hierarchy.

        Queries BiasDatabase once per cohort id, builds a ConceptHierarchy from
        each result's rows for the requested concept_type, then folds all the
        hierarchies together via ConceptHierarchy.union.
        :param cohorts: list of cohort ids to aggregate over
        :param concept_type: concept domain key with default "condition_occurrence"
        :param filter_count: minimum count threshold with 0 meaning no filtering
        :param vocab: vocabulary id, or None for the domain's default vocabulary
        :return: dict serialization (ConceptHierarchy.to_dict()) of the merged hierarchy
        """
        # one stats dict per cohort, each queried independently from BiasDatabase
        cohort_concept_stats = [self.bias_db.get_cohort_concept_stats(c, self._query_builder,
                                                                      concept_type=concept_type,
                                                                      filter_count=filter_count,
                                                                      vocab=vocab)
                                for c in cohorts]
        # a missing concept_type key yields an empty hierarchy for that cohort
        hierarchies = [ConceptHierarchy.build_concept_hierarchy_from_results(c, c_stats.get(concept_type, []))
                       for c, c_stats in zip(cohorts, cohort_concept_stats)]
        # NOTE(review): reduce assumes a non-empty cohorts list; the API layer
        # (BIAS.get_cohorts_concept_stats) guards against the empty case before calling.
        return reduce(lambda h1, h2: h1.union(h2), hierarchies).to_dict()

def compare_cohorts(self, cohort_id_1: int, cohort_id_2: int):
"""
Compare the distributions of two cohorts in BiasDatabase.
Expand Down
9 changes: 3 additions & 6 deletions biasanalyzer/cohort_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, cohort_creation=True):
except ModuleNotFoundError: # pragma: no cover
template_path = os.path.join(os.path.dirname(__file__), "sql_templates")

print(f'template_path: {template_path}, cohort_creation: {cohort_creation}')
print(f'template_path: {template_path}')
self.env = Environment(loader=FileSystemLoader(template_path), extensions=['jinja2.ext.do'])
if cohort_creation:
self.env.globals.update(
Expand Down Expand Up @@ -71,15 +71,13 @@ def build_query_cohort_creation(self, cohort_config: dict) -> str:
temporal_events=temporal_events
)

def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_count: int, vocab: str,
include_hierarchy: bool) -> str:
def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_count: int, vocab: str) -> str:
"""
Build a SQL query for concept prevalence statistics for a given domain and cohort.
:param concept_type: Domain from DOMAIN_MAPPING (e.g., 'condition_occurrence').
:param cid: Cohort definition ID.
:param filter_count: Minimum count threshold for concepts with 0 meaning no filtering
:param vocab: Vocabulary ID. Defaults to domain-specific vocabulary as defined in DOMAIN_MAPPING if set to None
:param include_hierarchy: Include concept hierarchy in results or not
:return: The rendered SQL query
:raises ValueError if concept_type is invalid
"""
Expand All @@ -100,8 +98,7 @@ def build_concept_prevalence_query(self, concept_type: str, cid: int, filter_cou
start_date_column=DOMAIN_MAPPING[concept_type]["start_date"],
cid=cid,
filter_count=filter_count,
vocab=effective_vocab,
include_hierarchy=include_hierarchy
vocab=effective_vocab
)

@staticmethod
Expand Down
205 changes: 205 additions & 0 deletions biasanalyzer/concept.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from collections import deque
from typing import List, Optional, Union

import networkx as nx


class ConceptNode:
    """Lightweight view over a single concept vertex in a ConceptHierarchy.

    Stores only the concept id plus a back-reference to the owning hierarchy;
    every attribute read is delegated to the hierarchy's underlying graph, so
    node data is never duplicated.
    """

    def __init__(self, concept_id: int, ch: "ConceptHierarchy"):
        self.id = concept_id
        self._ch = ch  # owning ConceptHierarchy (source of all node data)

    @property
    def _attrs(self) -> dict:
        # attribute mapping stored on this node in the backing graph
        return self._ch.graph.nodes[self.id]

    @property
    def name(self) -> str:
        return self._attrs["concept_name"]

    @property
    def code(self) -> str:
        return self._attrs["concept_code"]

    @property
    def parents(self) -> List["ConceptNode"]:
        graph = self._ch.graph
        return [ConceptNode(pid, self._ch) for pid in graph.predecessors(self.id)]

    @property
    def children(self) -> List["ConceptNode"]:
        graph = self._ch.graph
        return [ConceptNode(cid, self._ch) for cid in graph.successors(self.id)]

    def get_metrics(self, cohort_id: Union[int, str]) -> dict:
        """Return this concept's metrics for one cohort ({} when absent)."""
        return self._attrs.get("metrics", {}).get(str(cohort_id), {})

    def get_union_metrics(self) -> dict:
        """Aggregate metrics across all cohorts: summed count, mean prevalence."""
        per_cohort = self._attrs.get("metrics", {}).values()
        total_count = 0
        prevalences = []
        for cohort_metrics in per_cohort:
            total_count += cohort_metrics["count"]
            prevalences.append(cohort_metrics["prevalence"])
        mean_prevalence = sum(prevalences) / len(prevalences) if prevalences else 0.0
        return {"count": total_count, "prevalence": mean_prevalence}

    def to_dict(self, include_children: bool = True) -> dict:
        """
        Serialize this node into a dict. Optionally include nested children.
        """
        serialized = {
            "concept_id": self.id,
            "concept_name": self.name,
            "concept_code": self.code,
            "metrics": {
                "union": self.get_union_metrics(),
                "cohorts": self._attrs.get("metrics", {}),
            },
            "parent_ids": list(self._ch.graph.predecessors(self.id)),
        }
        if include_children:
            serialized["children"] = [child.to_dict(include_children=True)
                                      for child in self.children]
        return serialized


class ConceptHierarchy:
    """
    Directed acyclic concept hierarchy backed by a networkx DiGraph.

    Nodes are OMOP concept ids carrying concept_name/concept_code metadata plus
    per-cohort prevalence metrics; edges point from ancestor to descendant.
    Instances are cached class-wide, keyed by a normalized identifier, so
    repeated builds/unions for the same cohort(s) reuse one object.
    """
    # NOTE(review): the cache key is the cohort id(s) only; results built with
    # different filter_count/vocab parameters for the same cohort id would
    # return the stale cached hierarchy — call clear_cache() when parameters
    # change. The cache is also unbounded.
    _graph_cache = {}

    def __init__(self, input_g: nx.DiGraph, identifier: str):
        """
        :param input_g: directed graph with ancestor -> descendant edges
        :param identifier: cohort id string, or "+"-joined ids for a union
        """
        self.graph = input_g
        self.identifier = ConceptHierarchy._normalize_identifier(identifier)

    @staticmethod
    def _normalize_identifier(identifier: str) -> str:
        """Canonicalize an identifier: trim whitespace; for union identifiers
        split on '+', deduplicate and sort so '2+1' and '1+2' share one cache
        entry."""
        if "+" not in identifier:
            return identifier.strip()
        parts = [p.strip() for p in identifier.split("+") if p and p.strip() != ""]
        return "+".join(sorted(set(parts)))

    @classmethod
    def build_concept_hierarchy_from_results(cls, cohort_id: int, results: List[dict]):
        """
        Build a concept hierarchy managed by networkx from the list of dicts
        returned by the concept prevalence SQL, with cache management.
        :param cohort_id: cohort id to get the concept hierarchy for
        :param results: list of dicts from the prevalence SQL; each row must carry
            ancestor_concept_id, descendant_concept_id, concept_name,
            concept_code, count_in_cohort and prevalence keys
        :return: ConceptHierarchy object (cached per cohort id)
        """
        identifier = str(cohort_id)  # fixed 'identifer' misspelling
        if identifier in cls._graph_cache:
            return cls._graph_cache[identifier]

        # collect per-concept metadata and metrics keyed by descendant concept id;
        # metadata is recorded once per concept, metrics keep the last row seen
        metrics_by_concept = {}
        node_metadata = {}
        for row in results:
            cid = row["descendant_concept_id"]
            if cid not in node_metadata:
                node_metadata[cid] = {
                    "concept_name": row["concept_name"],
                    "concept_code": row["concept_code"],
                }
            metrics_by_concept[cid] = {
                "count": row["count_in_cohort"],
                "prevalence": row["prevalence"],
            }

        graph = nx.DiGraph()
        # add nodes with metadata + metrics namespaced by cohort identifier
        for cid, meta in node_metadata.items():
            graph.add_node(cid, **meta, metrics={identifier: metrics_by_concept[cid]})

        # add ancestor -> descendant edges, skipping null ids and self-loops.
        # NOTE(review): an ancestor id never seen as a descendant is auto-created
        # by add_edge without concept_name/concept_code attributes; ConceptNode
        # accessors would raise KeyError on such nodes — confirm the SQL always
        # emits every ancestor as a descendant row too.
        for row in results:
            anc = row["ancestor_concept_id"]
            desc = row["descendant_concept_id"]
            if anc and desc and anc != desc:
                graph.add_edge(anc, desc)

        hierarchy = cls(graph, identifier)
        cls._graph_cache[identifier] = hierarchy
        return hierarchy

    @classmethod
    def clear_cache(cls):
        """Drop all cached hierarchies (single-cohort and union entries)."""
        cls._graph_cache.clear()

    def get_node(self, concept_id: int, serialization: bool = False):
        """
        Look up one concept node by id.
        :param concept_id: concept id to retrieve
        :param serialization: return a dict (without children) instead of a ConceptNode
        :return: ConceptNode or dict, or None when the id is not in the hierarchy.
            (Fix: a missing id with serialization=True previously raised
            AttributeError by calling to_dict() on None.)
        """
        if concept_id not in self.graph.nodes:
            return None
        concept_node = ConceptNode(concept_id, self)
        return concept_node.to_dict(include_children=False) if serialization else concept_node

    def get_root_nodes(self, serialization: bool = False) -> List:
        """Return all nodes with no parents (in-degree 0)."""
        roots = [n for n in self.graph.nodes if self.graph.in_degree(n) == 0]
        root_nodes = [ConceptNode(r, self) for r in roots]
        if serialization:
            return [rn.to_dict(include_children=False) for rn in root_nodes]
        return root_nodes

    def get_leaf_nodes(self, serialization: bool = False) -> List:
        """Return all nodes with no children (out-degree 0)."""
        leaves = [n for n in self.graph.nodes if self.graph.out_degree(n) == 0]
        # renamed loop variable from ambiguous single-letter 'l'
        leaf_nodes = [ConceptNode(leaf, self) for leaf in leaves]
        if serialization:
            return [ln.to_dict(include_children=False) for ln in leaf_nodes]
        return leaf_nodes

    def iter_nodes(self, root_id: int, order: str = "bfs",
                   serialization: bool = False):
        """
        Iterate nodes reachable from root_id in BFS or DFS order.

        Each node is yielded at most once (fix: without a visited set, a DAG
        concept with multiple parents was enqueued and yielded once per
        parent path).
        :raises ValueError: if root_id is absent or order is not 'bfs'/'dfs'
        """
        if root_id not in self.graph.nodes:
            raise ValueError(f"Root node {root_id} not found in graph.")
        if order not in ("bfs", "dfs"):
            raise ValueError("order must be 'bfs' or 'dfs'")

        frontier = deque([root_id])
        # BFS pops from the left of the deque, DFS from the right
        pop_next = frontier.popleft if order == "bfs" else frontier.pop
        visited = {root_id}
        while frontier:
            node_id = pop_next()
            if serialization:
                yield ConceptNode(node_id, self).to_dict(include_children=False)
            else:
                yield ConceptNode(node_id, self)
            for succ in self.graph.successors(node_id):
                if succ not in visited:
                    visited.add(succ)
                    frontier.append(succ)

    def union(self, other: "ConceptHierarchy") -> "ConceptHierarchy":
        """Merge two hierarchies into a new one, aggregating per-cohort metrics.
        The merged result is cached under the normalized combined identifier.
        (Fix: the docstring was previously a stray string literal after
        executable code rather than a docstring.)"""
        new_ident = ConceptHierarchy._normalize_identifier(
            f"{self.identifier}+{other.identifier}"
        )
        if new_ident in ConceptHierarchy._graph_cache:
            return ConceptHierarchy._graph_cache[new_ident]

        composed_graph = nx.compose(self.graph, other.graph)
        # merge per-cohort metric dicts for every node; on a shared cohort key
        # `other` wins, matching nx.compose's attribute precedence
        for n in composed_graph.nodes:
            metrics_self = self.graph.nodes.get(n, {}).get("metrics", {})
            metrics_other = other.graph.nodes.get(n, {}).get("metrics", {})
            composed_graph.nodes[n]["metrics"] = {**metrics_self, **metrics_other}

        new_hierarchy = ConceptHierarchy(composed_graph, new_ident)
        ConceptHierarchy._graph_cache[new_ident] = new_hierarchy
        return new_hierarchy

    def to_dict(self, root_id: Optional[int] = None) -> dict:
        """
        Convert the concept hierarchy or a sub-hierarchy to a nested dict structure
        :param root_id: if provided, return the sub-hierarchy rooted at this concept_id;
            if None, return the whole hierarchy with all roots.
        :return: nested dict representation of the hierarchy or sub-hierarchy
        :raises ValueError: if root_id is given but not present in the graph
        """
        if root_id is not None:
            if root_id not in self.graph:
                raise ValueError(f"Input concept id {root_id} not found in the concept hierarchy graph")
            return {"hierarchy": [ConceptNode(root_id, self).to_dict()]}
        return {"hierarchy": [r.to_dict() for r in self.get_root_nodes()]}
Loading