Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions biasanalyzer/cohort_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,38 @@ def render_event(event):
if not domain or not domain["table"]:
return ""

base_sql = f"SELECT person_id, event_start_date, event_end_date FROM ranked_events_{event['event_type']}"
conditions = [f"concept_id = {event['event_concept_id']}"]
# Handle event_instance, including negative values
rank_table = f"ranked_asc_{event['event_type']}"
if "event_instance" in event and event["event_instance"] is not None:
conditions.append(f"event_instance >= {event['event_instance']}")
event_instance = int(event["event_instance"])
abs_instance = abs(event_instance)
if event_instance < 0:
rank_table = f"ranked_desc_{event['event_type']}"
instance_condition = f" AND event_instance = {abs_instance}"
else:
instance_condition = ""
# Handle offset for cohort window
offset = event.get("offset", 0)
if offset == 0:
adjusted_start = "event_start_date"
adjusted_end = "event_end_date"
else:
# Apply offset to start_date for negative, end_date for positive
adjusted_start = f"event_start_date - INTERVAL '{abs(offset)} days'" if offset < 0 else "event_start_date"
adjusted_end = f"event_end_date + INTERVAL '{offset} days'" if offset > 0 else "event_end_date"

base_sql = f"""
SELECT
person_id,
event_start_date,
event_end_date,
{adjusted_start} AS adjusted_start,
{adjusted_end} AS adjusted_end
FROM {rank_table}
WHERE concept_id = {event['event_concept_id']}{instance_condition}
"""

return f"{base_sql} WHERE {' AND '.join(conditions)}"
return base_sql


@staticmethod
Expand Down Expand Up @@ -163,7 +189,7 @@ def render_event_group(event_group, alias_prefix="evt"):
"""
# Then, union all events for qualifying person_ids
combined_sql = f"""
SELECT person_id, event_start_date, event_end_date
SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end
FROM (
{' UNION ALL '.join(f'({q})' for q in queries)}
) AS all_events
Expand All @@ -174,13 +200,15 @@ def render_event_group(event_group, alias_prefix="evt"):
return combined_sql

elif event_group["operator"] == "OR":
return f"SELECT person_id, event_start_date, event_end_date FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or"
return (f"SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end "
f"FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or")
elif event_group["operator"] == "NOT":
not_query = queries[0]
# Return a query that selects all persons from a base table (e.g., person),
# excluding those in the NOT subquery, while allowing dates from other criteria
return f"""
SELECT p.person_id, NULL AS event_start_date, NULL AS event_end_date
SELECT p.person_id, NULL AS event_start_date, NULL AS event_end_date,
NULL AS adjusted_start, NULL AS adjusted_end,
FROM person p
WHERE p.person_id NOT IN (
SELECT person_id FROM ({not_query}) AS {alias_prefix}_not
Expand All @@ -199,14 +227,14 @@ def render_event_group(event_group, alias_prefix="evt"):
if timestamp_event_index < non_timestamp_event_index:
# timestamp needs to happen before non-timestamp event
return f"""
SELECT person_id, event_start_date, event_end_date
SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end
FROM ({queries[0]}) AS {alias_prefix}_0
WHERE event_start_date > DATE '{timestamp}'
"""
else:
# non-timestamp event needs to happen before timestamp
return f"""
SELECT person_id, event_start_date, event_end_date
SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end
FROM ({queries[0]}) AS {alias_prefix}_0
WHERE event_start_date < DATE '{timestamp}'
"""
Expand All @@ -218,14 +246,16 @@ def render_event_group(event_group, alias_prefix="evt"):

# Ensure both events contribute dates with temporal order and interval
return f"""
SELECT {e1_alias}.person_id, {e1_alias}.event_start_date, {e1_alias}.event_end_date
SELECT {e1_alias}.person_id, {e1_alias}.event_start_date, {e1_alias}.event_end_date,
{e1_alias}.adjusted_start, {e1_alias}.adjusted_end
FROM ({queries[0]}) AS {e1_alias}
JOIN ({queries[1]}) AS {e2_alias}
ON {e1_alias}.person_id = {e2_alias}.person_id
AND {e1_alias}.event_start_date < {e2_alias}.event_start_date
{interval_sql}
UNION ALL
SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, {e2_alias}.event_end_date
SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, {e2_alias}.event_end_date,
{e2_alias}.adjusted_start, {e2_alias}.adjusted_end
FROM ({queries[1]}) AS {e2_alias}
JOIN ({queries[0]}) AS {e1_alias}
ON {e2_alias}.person_id = {e1_alias}.person_id
Expand Down Expand Up @@ -277,7 +307,8 @@ def temporal_event_filter(self, event_groups, alias='c'):
# events:
# - event_type: drug_exposure
# event_concept_id: 67890
return (f"SELECT person_id, event_start_date, event_end_date FROM "
return (f"SELECT person_id, event_start_date, event_end_date, "
f"adjusted_start, adjusted_end FROM "
f"({' UNION ALL '.join(filters)}) AS combined_events")

# Single event group case with operator defined
Expand Down
2 changes: 1 addition & 1 deletion biasanalyzer/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def cohort_distribution_variables(self):

def get_cohort_distributions(self, cohort_definition_id: int, variable: str):
"""
Get age distribution statistics for a cohort from the cohort table.
Get distribution statistics for a cohort from the cohort table.
"""
try:
if self._create_omop_table('person'):
Expand Down
4 changes: 2 additions & 2 deletions biasanalyzer/sql_templates/base.sql.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ domain_qualifying_events AS (
filtered_cohort AS (
SELECT c.person_id,
{% if temporal_events %}
MIN(c.event_start_date) AS cohort_start_date,
MAX(c.event_end_date) AS cohort_end_date
MIN(c.adjusted_start) AS cohort_start_date,
MAX(c.adjusted_end) AS cohort_end_date
{% else %}
MIN(all_events.event_start_date) AS cohort_start_date,
MAX(all_events.event_end_date) AS cohort_end_date
Expand Down
18 changes: 15 additions & 3 deletions biasanalyzer/sql_templates/cohort_creation_query.sql.j2
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
{% block domain_events %}
{% if ranked_domains %}
WITH
{% for domain_type, domain in ranked_domains.items() %}
{% for event_type, domain in ranked_domains.items() %}
{% if domain.table %}
ranked_events_{{ domain_type }} AS (
ranked_asc_{{ event_type }} AS (
SELECT
person_id,
{{ domain.concept_id }} AS concept_id,
Expand All @@ -16,6 +16,18 @@ ranked_events_{{ domain_type }} AS (
) AS event_instance
FROM {{ domain.table }}
),
ranked_desc_{{ event_type }} AS (
SELECT
person_id,
{{ domain.concept_id }} AS concept_id,
{{ domain.start_date }} AS event_start_date,
{{ domain.end_date }} AS event_end_date,
ROW_NUMBER() OVER (
PARTITION BY person_id, {{ domain.concept_id }}
ORDER BY {{ domain.start_date }} DESC
) AS event_instance
FROM {{ domain.table }}
),
{% endif %}
{% endfor %}
{% endif %}
Expand All @@ -25,7 +37,7 @@ ranked_events_{{ domain_type }} AS (
{% if inclusion_criteria.temporal_events %}
{{ temporal_event_filter(inclusion_criteria.temporal_events) }}
{% else %}
SELECT person_id
SELECT person_id, NULL AS event_start_date, NULL AS event_end_date, NULL AS adjusted_start, NULL AS adjusted_end
FROM person p
{% endif %}
{% endblock %}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
inclusion_criteria:
demographics:
gender: female
min_birth_year: 1970
max_birth_year: 2000
temporal_events:
- operator: AND
events:
- event_type: condition_occurrence
event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID)
event_instance: -1 # Last occurrence
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
inclusion_criteria:
demographics:
gender: female
min_birth_year: 1970
max_birth_year: 2000
temporal_events:
- operator: AND
events:
- event_type: condition_occurrence
event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID)
event_instance: -1 # Last occurrence
offset: 180 # 180 days after
14 changes: 14 additions & 0 deletions tests/assets/cohort_creation/test_cohort_creation_offset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
inclusion_criteria:
demographics:
gender: female
min_birth_year: 1970
max_birth_year: 2000
temporal_events:
- operator: AND
events:
- event_type: condition_occurrence
event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID)
offset: 180 # 180 days after
- event_type: condition_occurrence
event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID)
offset: -730 # 2 years before
30 changes: 23 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def test_db():
(2, 8532, 0, 0, 1996), -- Female, qualifying, not excluded due to not having cardiac surgery
(3, 8532, 0, 0, 1996), -- Female, has cardiac surgery
(4, 8507, 0, 0, 1980), -- Male, wrong gender
(5, 8532, 0, 0, 1980); -- Female, missing insulin
(5, 8532, 0, 0, 1980), -- Female, missing insulin
-- for offset and negative instance testing
(6, 8532, 0, 0, 1985), -- Female, multiple diabetes records, last one too early
(7, 8532, 0, 0, 1990); -- Female, diabetes record too recent
""")

# Insert mock concepts as needed
Expand All @@ -141,7 +144,8 @@ def test_db():
(5, 'Fever', '2012-04-01', '2020-04-01', 'R50.9', 'ICD10CM', 'Condition'),
(37311061, 'COVID-19', '2012-04-01', '2020-04-01', '840539006', 'SNOMED', 'Condition'),
(4041664, 'Difficulty breathing', '2012-04-01', '2020-04-01', '230145002', 'SNOMED', 'Condition'),
(316139, 'Heart failure', '2012-04-01', '2020-04-01', '84114007', 'SNOMED', 'Condition');
(316139, 'Heart failure', '2012-04-01', '2020-04-01', '84114007', 'SNOMED', 'Condition'),
(201826, 'Type 2 diabetes mellitus', '2012-04-01', '2020-04-01', '44054006', 'SNOMED', 'Condition');
""")

# Insert hierarchical relationships as needed
Expand All @@ -163,7 +167,8 @@ def test_db():
(1, 3, 1), -- Diabetes -> Type 2
(1, 4, 2), -- Diabetes -> Retinopathy
(2, 4, 1), -- Type 1 -> Diabetes Retinopathy
(3, 4, 1); -- Type 2 -> Diabetes Retinopathy
(3, 4, 1), -- Type 2 -> Diabetes Retinopathy
(201826, 201826, 0); -- Type 2 diabetes SNOMED
""")

# Insert mock condition occurrences as needed
Expand Down Expand Up @@ -199,7 +204,11 @@ def test_db():
(2, 201826, '2020-06-01', '2020-06-01'), -- Person 2: Diabetes
(3, 201826, '2020-06-01', '2020-06-01'), -- Person 3: Diabetes
(4, 201826, '2020-06-01', '2020-06-01'), -- Person 4: Diabetes
(5, 201826, '2020-06-01', '2020-06-01'); -- Person 5: Diabetes
(5, 201826, '2020-06-01', '2020-06-01'), -- Person 5: Diabetes
-- for negative event instance and offset testing
(6, 201826, '2017-01-01', '2017-01-01'), -- Person 6: Early diabetes record
(6, 201826, '2018-01-01', '2018-01-01'), -- Person 6: Last diabetes record, still early
(7, 201826, '2023-01-01', '2023-01-01'); -- Person 7: Recent diabetes record
""")

# Insert mock visit data
Expand All @@ -220,7 +229,10 @@ def test_db():
(2, 9, 9202, '2020-06-10', '2020-06-10'), -- Person 2: Outpatient
(3, 10, 9202, '2020-06-10', '2020-06-10'), -- Person 3: Outpatient
(4, 11, 9202, '2020-06-10', '2020-06-10'), -- Person 4: Outpatient
(5, 12, 9202, '2020-06-10', '2020-06-10'); -- Person 5: Outpatient
(5, 12, 9202, '2020-06-10', '2020-06-10'), -- Person 5: Outpatient
-- New patients (no visits needed for exclusion testing)
(6, 13, 9202, '2018-01-10', '2018-01-10'), -- Person 6: Outpatient
(7, 14, 9202, '2023-01-10', '2023-01-10'); -- Person 7: Outpatient
""")

# Insert mock procedure_occurrence data for mixed domain testing
Expand All @@ -234,7 +246,9 @@ def test_db():
(3, 3, 4048609, '2020-06-20'), -- Person 3: Blood test
(3, 4, 619339, '2020-06-25'), -- Person 3: Cardiac surgery (exclusion)
(4, 5, 4048609, '2020-06-20'), -- Person 4: Blood test
(5, 6, 4048609, '2020-06-20'); -- Person 5: Blood test
(5, 6, 4048609, '2020-06-20'), -- Person 5: Blood test
(6, 7, 4048609, '2018-01-15'), -- Person 6: Blood test
(7, 8, 4048609, '2023-01-15'); -- Person 7: Blood test
""")

# Insert mock procedure_occurrence data for mixed domain testing
Expand All @@ -246,7 +260,9 @@ def test_db():
(1, 4285892, '2020-06-15', '2020-06-15'), -- Person 1: Insulin 14 days after
(2, 4285892, '2020-06-15', '2020-06-15'), -- Person 2: Insulin
(3, 4285892, '2020-06-15', '2020-06-15'), -- Person 3: Insulin
(4, 4285892, '2020-06-15', '2020-06-15'); -- Person 4: Insulin
(4, 4285892, '2020-06-15', '2020-06-15'), -- Person 4: Insulin
(6, 4285892, '2018-01-20', '2018-01-20'), -- Person 6: Insulin
(7, 4285892, '2023-01-20', '2023-01-20'); -- Person 7: Insulin
-- Person 5: No insulin
""")

Expand Down
Loading