diff --git a/biasanalyzer/cohort_query_builder.py b/biasanalyzer/cohort_query_builder.py index c0a4e80..b995d3f 100644 --- a/biasanalyzer/cohort_query_builder.py +++ b/biasanalyzer/cohort_query_builder.py @@ -116,12 +116,38 @@ def render_event(event): if not domain or not domain["table"]: return "" - base_sql = f"SELECT person_id, event_start_date, event_end_date FROM ranked_events_{event['event_type']}" - conditions = [f"concept_id = {event['event_concept_id']}"] + # Handle event_instance, including negative values + rank_table = f"ranked_asc_{event['event_type']}" if "event_instance" in event and event["event_instance"] is not None: - conditions.append(f"event_instance >= {event['event_instance']}") + event_instance = int(event["event_instance"]) + abs_instance = abs(event_instance) + if event_instance < 0: + rank_table = f"ranked_desc_{event['event_type']}" + instance_condition = f" AND event_instance = {abs_instance}" + else: + instance_condition = "" + # Handle offset for cohort window + offset = event.get("offset", 0) + if offset == 0: + adjusted_start = "event_start_date" + adjusted_end = "event_end_date" + else: + # Apply offset to start_date for negative, end_date for positive + adjusted_start = f"event_start_date - INTERVAL '{abs(offset)} days'" if offset < 0 else "event_start_date" + adjusted_end = f"event_end_date + INTERVAL '{offset} days'" if offset > 0 else "event_end_date" + + base_sql = f""" + SELECT + person_id, + event_start_date, + event_end_date, + {adjusted_start} AS adjusted_start, + {adjusted_end} AS adjusted_end + FROM {rank_table} + WHERE concept_id = {event['event_concept_id']}{instance_condition} + """ - return f"{base_sql} WHERE {' AND '.join(conditions)}" + return base_sql @staticmethod @@ -163,7 +189,7 @@ def render_event_group(event_group, alias_prefix="evt"): """ # Then, union all events for qualifying person_ids combined_sql = f""" - SELECT person_id, event_start_date, event_end_date + SELECT person_id, event_start_date, 
event_end_date, adjusted_start, adjusted_end FROM ( {' UNION ALL '.join(f'({q})' for q in queries)} ) AS all_events @@ -174,13 +200,15 @@ def render_event_group(event_group, alias_prefix="evt"): return combined_sql elif event_group["operator"] == "OR": - return f"SELECT person_id, event_start_date, event_end_date FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or" + return (f"SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end " + f"FROM ({' UNION '.join(queries)}) AS {alias_prefix}_or") elif event_group["operator"] == "NOT": not_query = queries[0] # Return a query that selects all persons from a base table (e.g., person), # excluding those in the NOT subquery, while allowing dates from other criteria return f""" - SELECT p.person_id, NULL AS event_start_date, NULL AS event_end_date + SELECT p.person_id, NULL AS event_start_date, NULL AS event_end_date, + NULL AS adjusted_start, NULL AS adjusted_end FROM person p WHERE p.person_id NOT IN ( SELECT person_id FROM ({not_query}) AS {alias_prefix}_not @@ -199,14 +227,14 @@ def render_event_group(event_group, alias_prefix="evt"): if timestamp_event_index < non_timestamp_event_index: # timestamp needs to happen before non-timestamp event return f""" - SELECT person_id, event_start_date, event_end_date + SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end FROM ({queries[0]}) AS {alias_prefix}_0 WHERE event_start_date > DATE '{timestamp}' """ else: # non-timestamp event needs to happen before timestamp return f""" - SELECT person_id, event_start_date, event_end_date + SELECT person_id, event_start_date, event_end_date, adjusted_start, adjusted_end FROM ({queries[0]}) AS {alias_prefix}_0 WHERE event_start_date < DATE '{timestamp}' """ @@ -218,14 +246,16 @@ def render_event_group(event_group, alias_prefix="evt"): # Ensure both events contribute dates with temporal order and interval return f""" - SELECT {e1_alias}.person_id, {e1_alias}.event_start_date, 
{e1_alias}.event_end_date + SELECT {e1_alias}.person_id, {e1_alias}.event_start_date, {e1_alias}.event_end_date, + {e1_alias}.adjusted_start, {e1_alias}.adjusted_end FROM ({queries[0]}) AS {e1_alias} JOIN ({queries[1]}) AS {e2_alias} ON {e1_alias}.person_id = {e2_alias}.person_id AND {e1_alias}.event_start_date < {e2_alias}.event_start_date {interval_sql} UNION ALL - SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, {e2_alias}.event_end_date + SELECT {e2_alias}.person_id, {e2_alias}.event_start_date, {e2_alias}.event_end_date, + {e2_alias}.adjusted_start, {e2_alias}.adjusted_end FROM ({queries[1]}) AS {e2_alias} JOIN ({queries[0]}) AS {e1_alias} ON {e2_alias}.person_id = {e1_alias}.person_id @@ -277,7 +307,8 @@ def temporal_event_filter(self, event_groups, alias='c'): # events: # - event_type: drug_exposure # event_concept_id: 67890 - return (f"SELECT person_id, event_start_date, event_end_date FROM " + return (f"SELECT person_id, event_start_date, event_end_date, " + f"adjusted_start, adjusted_end FROM " f"({' UNION ALL '.join(filters)}) AS combined_events") # Single event group case with operator defined diff --git a/biasanalyzer/database.py b/biasanalyzer/database.py index e6e30e2..35499a4 100644 --- a/biasanalyzer/database.py +++ b/biasanalyzer/database.py @@ -210,7 +210,7 @@ def cohort_distribution_variables(self): def get_cohort_distributions(self, cohort_definition_id: int, variable: str): """ - Get age distribution statistics for a cohort from the cohort table. + Get distribution statistics for a cohort from the cohort table. 
""" try: if self._create_omop_table('person'): diff --git a/biasanalyzer/sql_templates/base.sql.j2 b/biasanalyzer/sql_templates/base.sql.j2 index 69e2494..49db9e7 100644 --- a/biasanalyzer/sql_templates/base.sql.j2 +++ b/biasanalyzer/sql_templates/base.sql.j2 @@ -6,8 +6,8 @@ domain_qualifying_events AS ( filtered_cohort AS ( SELECT c.person_id, {% if temporal_events %} - MIN(c.event_start_date) AS cohort_start_date, - MAX(c.event_end_date) AS cohort_end_date + MIN(c.adjusted_start) AS cohort_start_date, + MAX(c.adjusted_end) AS cohort_end_date {% else %} MIN(all_events.event_start_date) AS cohort_start_date, MAX(all_events.event_end_date) AS cohort_end_date diff --git a/biasanalyzer/sql_templates/cohort_creation_query.sql.j2 b/biasanalyzer/sql_templates/cohort_creation_query.sql.j2 index fdc6d4d..628eced 100644 --- a/biasanalyzer/sql_templates/cohort_creation_query.sql.j2 +++ b/biasanalyzer/sql_templates/cohort_creation_query.sql.j2 @@ -2,9 +2,9 @@ {% block domain_events %} {% if ranked_domains %} WITH -{% for domain_type, domain in ranked_domains.items() %} +{% for event_type, domain in ranked_domains.items() %} {% if domain.table %} -ranked_events_{{ domain_type }} AS ( +ranked_asc_{{ event_type }} AS ( SELECT person_id, {{ domain.concept_id }} AS concept_id, @@ -16,6 +16,18 @@ ranked_events_{{ domain_type }} AS ( ) AS event_instance FROM {{ domain.table }} ), +ranked_desc_{{ event_type }} AS ( + SELECT + person_id, + {{ domain.concept_id }} AS concept_id, + {{ domain.start_date }} AS event_start_date, + {{ domain.end_date }} AS event_end_date, + ROW_NUMBER() OVER ( + PARTITION BY person_id, {{ domain.concept_id }} + ORDER BY {{ domain.start_date }} DESC + ) AS event_instance + FROM {{ domain.table }} +), {% endif %} {% endfor %} {% endif %} @@ -25,7 +37,7 @@ ranked_events_{{ domain_type }} AS ( {% if inclusion_criteria.temporal_events %} {{ temporal_event_filter(inclusion_criteria.temporal_events) }} {% else %} - SELECT person_id + SELECT person_id, NULL AS 
event_start_date, NULL AS event_end_date, NULL AS adjusted_start, NULL AS adjusted_end FROM person p {% endif %} {% endblock %} diff --git a/tests/assets/cohort_creation/test_cohort_creation_negative_instance.yaml b/tests/assets/cohort_creation/test_cohort_creation_negative_instance.yaml new file mode 100644 index 0000000..dbe5e4b --- /dev/null +++ b/tests/assets/cohort_creation/test_cohort_creation_negative_instance.yaml @@ -0,0 +1,11 @@ +inclusion_criteria: + demographics: + gender: female + min_birth_year: 1970 + max_birth_year: 2000 + temporal_events: + - operator: AND + events: + - event_type: condition_occurrence + event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID) + event_instance: -1 # Last occurrence diff --git a/tests/assets/cohort_creation/test_cohort_creation_negative_instance_offset.yaml b/tests/assets/cohort_creation/test_cohort_creation_negative_instance_offset.yaml new file mode 100644 index 0000000..9feeb4d --- /dev/null +++ b/tests/assets/cohort_creation/test_cohort_creation_negative_instance_offset.yaml @@ -0,0 +1,12 @@ +inclusion_criteria: + demographics: + gender: female + min_birth_year: 1970 + max_birth_year: 2000 + temporal_events: + - operator: AND + events: + - event_type: condition_occurrence + event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID) + event_instance: -1 # Last occurrence + offset: 180 # 180 days after diff --git a/tests/assets/cohort_creation/test_cohort_creation_offset.yaml b/tests/assets/cohort_creation/test_cohort_creation_offset.yaml new file mode 100644 index 0000000..c6b75c9 --- /dev/null +++ b/tests/assets/cohort_creation/test_cohort_creation_offset.yaml @@ -0,0 +1,14 @@ +inclusion_criteria: + demographics: + gender: female + min_birth_year: 1970 + max_birth_year: 2000 + temporal_events: + - operator: AND + events: + - event_type: condition_occurrence + event_concept_id: 201826 # Type 2 diabetes (valid OMOP ID) + offset: 180 # 180 days after + - event_type: condition_occurrence + event_concept_id: 201826 
# Type 2 diabetes (valid OMOP ID) + offset: -730 # 2 years before diff --git a/tests/conftest.py b/tests/conftest.py index 4fc856c..1ae39cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,7 +123,10 @@ def test_db(): (2, 8532, 0, 0, 1996), -- Female, qualifying, not excluded due to not having cardiac surgery (3, 8532, 0, 0, 1996), -- Female, has cardiac surgery (4, 8507, 0, 0, 1980), -- Male, wrong gender - (5, 8532, 0, 0, 1980); -- Female, missing insulin + (5, 8532, 0, 0, 1980), -- Female, missing insulin + -- for offset and negative instance testing + (6, 8532, 0, 0, 1985), -- Female, multiple diabetes records, last one too early + (7, 8532, 0, 0, 1990); -- Female, diabetes record too recent """) # Insert mock concepts as needed @@ -141,7 +144,8 @@ def test_db(): (5, 'Fever', '2012-04-01', '2020-04-01', 'R50.9', 'ICD10CM', 'Condition'), (37311061, 'COVID-19', '2012-04-01', '2020-04-01', '840539006', 'SNOMED', 'Condition'), (4041664, 'Difficulty breathing', '2012-04-01', '2020-04-01', '230145002', 'SNOMED', 'Condition'), - (316139, 'Heart failure', '2012-04-01', '2020-04-01', '84114007', 'SNOMED', 'Condition'); + (316139, 'Heart failure', '2012-04-01', '2020-04-01', '84114007', 'SNOMED', 'Condition'), + (201826, 'Type 2 diabetes mellitus', '2012-04-01', '2020-04-01', '44054006', 'SNOMED', 'Condition'); """) # Insert hierarchical relationships as needed @@ -163,7 +167,8 @@ def test_db(): (1, 3, 1), -- Diabetes -> Type 2 (1, 4, 2), -- Diabetes -> Retinopathy (2, 4, 1), -- Type 1 -> Diabetes Retinopathy - (3, 4, 1); -- Type 2 -> Diabetes Retinopathy + (3, 4, 1), -- Type 2 -> Diabetes Retinopathy + (201826, 201826, 0); -- Type 2 diabetes SNOMED """) # Insert mock condition occurrences as needed @@ -199,7 +204,11 @@ def test_db(): (2, 201826, '2020-06-01', '2020-06-01'), -- Person 2: Diabetes (3, 201826, '2020-06-01', '2020-06-01'), -- Person 3: Diabetes (4, 201826, '2020-06-01', '2020-06-01'), -- Person 4: Diabetes - (5, 201826, '2020-06-01', 
'2020-06-01'); -- Person 5: Diabetes + (5, 201826, '2020-06-01', '2020-06-01'), -- Person 5: Diabetes + -- for negative event instance and offset testing + (6, 201826, '2017-01-01', '2017-01-01'), -- Person 6: Early diabetes record + (6, 201826, '2018-01-01', '2018-01-01'), -- Person 6: Last diabetes record, still early + (7, 201826, '2023-01-01', '2023-01-01'); -- Person 7: Recent diabetes record """) # Insert mock visit data @@ -220,7 +229,10 @@ def test_db(): (2, 9, 9202, '2020-06-10', '2020-06-10'), -- Person 2: Outpatient (3, 10, 9202, '2020-06-10', '2020-06-10'), -- Person 3: Outpatient (4, 11, 9202, '2020-06-10', '2020-06-10'), -- Person 4: Outpatient - (5, 12, 9202, '2020-06-10', '2020-06-10'); -- Person 5: Outpatient + (5, 12, 9202, '2020-06-10', '2020-06-10'), -- Person 5: Outpatient + -- New patients (no visits needed for exclusion testing) + (6, 13, 9202, '2018-01-10', '2018-01-10'), -- Person 6: Outpatient + (7, 14, 9202, '2023-01-10', '2023-01-10'); -- Person 7: Outpatient """) # Insert mock procedure_occurrence data for mixed domain testing @@ -234,7 +246,9 @@ def test_db(): (3, 3, 4048609, '2020-06-20'), -- Person 3: Blood test (3, 4, 619339, '2020-06-25'), -- Person 3: Cardiac surgery (exclusion) (4, 5, 4048609, '2020-06-20'), -- Person 4: Blood test - (5, 6, 4048609, '2020-06-20'); -- Person 5: Blood test + (5, 6, 4048609, '2020-06-20'), -- Person 5: Blood test + (6, 7, 4048609, '2018-01-15'), -- Person 6: Blood test + (7, 8, 4048609, '2023-01-15'); -- Person 7: Blood test """) # Insert mock procedure_occurrence data for mixed domain testing @@ -246,7 +260,9 @@ def test_db(): (1, 4285892, '2020-06-15', '2020-06-15'), -- Person 1: Insulin 14 days after (2, 4285892, '2020-06-15', '2020-06-15'), -- Person 2: Insulin (3, 4285892, '2020-06-15', '2020-06-15'), -- Person 3: Insulin - (4, 4285892, '2020-06-15', '2020-06-15'); -- Person 4: Insulin + (4, 4285892, '2020-06-15', '2020-06-15'), -- Person 4: Insulin + (6, 4285892, '2018-01-20', '2018-01-20'), 
-- Person 6: Insulin + (7, 4285892, '2023-01-20', '2023-01-20'); -- Person 7: Insulin -- Person 5: No insulin """) diff --git a/tests/query_based/test_cohort_creation.py b/tests/query_based/test_cohort_creation.py index 8211cad..faafa94 100644 --- a/tests/query_based/test_cohort_creation.py +++ b/tests/query_based/test_cohort_creation.py @@ -290,3 +290,406 @@ def close(self): result = fresh_bias_obj.create_cohort("test", "desc", "SELECT * FROM person", "test_user") assert result is None + + +import os +import datetime +import logging +import pytest +from sqlalchemy.exc import SQLAlchemyError +from numpy.ma.testutils import assert_equal +from biasanalyzer.models import DemographicsCriteria, TemporalEvent, TemporalEventGroup + + +def test_cohort_yaml_validation(test_db): + invalid_data = { + "gender": "female", + "min_birth_year": 2000, + "max_birth_year": 1999 # Invalid: less than min_birth_year + } + with pytest.raises(ValueError): + DemographicsCriteria(**invalid_data) + + invalid_data = { + "event_type": "date", + "event_concept_id": "dummy" + } + # validate date event_type must have a timestamp field + with pytest.raises(ValueError): + TemporalEvent(**invalid_data) + + invalid_data = { + "operator": "BEFORE", + "events": [ + {'event_type': 'condition_occurrence', + 'event_concept_id': 201826}, + {'event_type': 'drug_exposure', + 'event_concept_id': 4285892}, + ], + "interval": [100, 50] + } + # validate interval start must be smaller than interval end + with pytest.raises(ValueError): + TemporalEventGroup(**invalid_data) + + # validate interval must be either a list of 2 integers or a None + invalid_data["interval"] = [123] + with pytest.raises(ValueError): + TemporalEventGroup(**invalid_data) + + # validate NOT operator cannot have more than one event + invalid_data["operator"] = "NOT" + with pytest.raises(ValueError): + TemporalEventGroup(**invalid_data) + + # validate BEFORE operator must have two events + invalid_data["operator"] = "BEFORE" + del 
invalid_data["events"][1] + with pytest.raises(ValueError): + TemporalEventGroup(**invalid_data) + + +def test_cohort_creation_baseline(caplog, test_db): + bias = test_db + cohort = bias.create_cohort( + "COVID-19 patient", + "Cohort of young female patients", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), + "test_user" + ) + + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + cohort_id = cohort.cohort_id + assert bias.bias_db.get_cohort_definition(cohort_id)['name'] == "COVID-19 patient" + assert bias.bias_db.get_cohort_definition(cohort_id + 1) == {} + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + caplog.clear() + with caplog.at_level(logging.ERROR): + cohort.get_distributions('ethnicity') + assert "Distribution for variable 'ethnicity' is not available" in caplog.text + + assert len(cohort.get_distributions('age')) == 10, "Cohort get_distribution('age') does not return 10 age_bin items" + assert len(cohort.get_distributions('gender')) == 3, ("Cohort get_distribution('gender') does not return " + "3 gender_bin items") + + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 5) + assert_equal(patient_ids, {106, 108, 110, 111, 112}) + # select two patients to check for cohort_start_date and cohort_end_date automatically computed + patient_106 = next(item for item in cohort.data if item['subject_id'] == 106) + patient_108 = next(item for item in cohort.data if item['subject_id'] == 108) + + # Replace dates with actual values from your test data + assert_equal(patient_106['cohort_start_date'], datetime.date(2023, 3, 1), + "Incorrect cohort_start_date for patient 106") + 
assert_equal(patient_106['cohort_end_date'], datetime.date(2023, 3, 15), + "Incorrect cohort_end_date for patient 106") + assert_equal(patient_108['cohort_start_date'], datetime.date(2020, 4, 10), + "Incorrect cohort_start_date for patient 108") + assert_equal(patient_108['cohort_end_date'], datetime.date(2020, 4, 27), + "Incorrect cohort_end_date for patient 108") + + +def test_cohort_creation_study(test_db): + bias = test_db + cohort = bias.create_cohort( + "COVID-19 patient", + "Cohort of young female patients with COVID-19", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_condition_occurrence_config_study.yaml'), + "test_user" + ) + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 4) + assert_equal(patient_ids, {108, 110, 111, 112}) + + +def test_cohort_creation_study2(caplog, test_db): + bias = test_db + caplog.clear() + with caplog.at_level(logging.INFO): + cohort = bias.create_cohort( + "COVID-19 patient", + "Cohort of young female patients with no COVID-19", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_condition_occurrence_config_study2.yaml'), + "test_user", + delay=1 + ) + assert 'Simulating long-running task' in caplog.text + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert cohort.data is not None, "Cohort creation wrongly 
returned None data" + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 1) + assert_equal(patient_ids, {106}) + + +def test_cohort_creation_all(caplog, test_db): + bias = test_db + cohort = bias.create_cohort( + "COVID-19 patient", + "Cohort of young female patients with COVID-19 who have the condition with difficulty breathing 2 to 5 days " + "before a COVID diagnosis 3/15/20-12/11/20 AND have at least one emergency room visit or at least " + "two inpatient visits", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_condition_occurrence_config.yaml'), + "test_user" + ) + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + stats = cohort.get_stats() + assert stats is not None, "Created cohort's stats is None" + gender_stats = cohort.get_stats(variable='gender') + assert gender_stats is not None, "Created cohort's gender stats is None" + caplog.clear() + with caplog.at_level(logging.ERROR): + cohort.get_stats(variable='address') + assert 'is not available' in caplog.text + assert gender_stats is not None, "Created cohort's gender stats is None" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + patient_ids = set([item['subject_id'] for item in cohort.data]) + print(f'patient_ids: {patient_ids}', flush=True) + assert_equal(len(patient_ids), 2) + assert_equal(patient_ids, {108, 110}) + + +def test_cohort_creation_multiple_temporary_groups_with_no_operator(test_db): + bias = test_db + cohort = bias.create_cohort( + "Patients with COVID or other emergency conditions", + "Cohort of young female patients who either have COVID-19 with difficulty breathing 2 to 5 days " + "before a COVID diagnosis 3/15/20-12/11/20 OR have at 
least one emergency room visit or at least " + "two inpatient visits", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_multiple_temporal_groups_without_operator.yaml'), + "test_user" + ) + # Test cohort object and methods + patient_ids = set([item['subject_id'] for item in cohort.data]) + print(f'patient_ids: {patient_ids}', flush=True) + assert_equal(len(patient_ids), 2) + assert_equal(patient_ids, {108, 110}) + + +def test_cohort_creation_mixed_domains(test_db): + """ + Test cohort creation with mixed domains (condition, drug, visit, procedure). + """ + bias = test_db + cohort = bias.create_cohort( + "Female diabetes patients born between 1970 and 2000", + "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " + "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " + "with patients born after 1995 and with cardiac surgery excluded", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_config.yaml'), + "test_user" + ) + + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + print(f'metadata: {cohort.metadata}') + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + stats = cohort.get_stats() + assert stats is not None, "Created cohort's stats is None" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + patient_ids = set([item['subject_id'] for item in cohort.data]) + print(f'patient_ids: {patient_ids}', flush=True) + assert_equal(len(patient_ids), 3) + assert_equal(patient_ids, {1, 2, 6}) + start_dates = [item['cohort_start_date'] for item in cohort.data] + assert_equal(len(start_dates), 3) + assert_equal(start_dates, [datetime.date(2020, 6, 1), + datetime.date(2020, 6, 1), + 
datetime.date(2018, 1, 1)]) + end_dates = [item['cohort_end_date'] for item in cohort.data] + assert_equal(len(end_dates), 3) + assert_equal(end_dates, [datetime.date(2020, 6, 20), + datetime.date(2020, 6, 20), + datetime.date(2018, 1, 20)]) + + +def test_cohort_comparison(test_db): + bias = test_db + cohort_base = bias.create_cohort( + "COVID-19 patient", + "Cohort of young female patients", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_condition_occurrence_config_baseline.yaml'), + "test_user" + ) + cohort_study = bias.create_cohort( + "Female diabetes patients born between 1970 and 2000", + "Cohort of female patients with diabetes who had insulin prescribed 0-30 days after diagnosis " + "and have at least one outpatient or emergency visit and underwent a blood test before 12/31/2020, " + "with patients born after 1995 and with cardiac surgery excluded", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_config.yaml'), + "test_user" + ) + results = bias.compare_cohorts(cohort_base.cohort_id, cohort_study.cohort_id) + assert {'gender_hellinger_distance': 0.0} in results + assert any('age_hellinger_distance' in r for r in results) + + +def test_cohort_invalid(caplog, test_db): + caplog.clear() + with caplog.at_level(logging.INFO): + invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', + 'invalid_yaml_file.yml', + 'invalid_created_by') + assert 'cohort creation configuration file does not exist' in caplog.text + assert invalid_cohort is None + + caplog.clear() + with caplog.at_level(logging.INFO): + invalid_cohort = test_db.create_cohort('invalid_cohort', 'invalid_cohort', + os.path.join(os.path.dirname(__file__), '..', 'assets', 'config', + 'test_config.yaml'), 'invalid_created_by') + assert 'configuration yaml file is not valid' in caplog.text + assert invalid_cohort is None + + with caplog.at_level(logging.INFO): + invalid_cohort = 
test_db.create_cohort('invalid_cohort', 'invalid_cohort', + 'INVALID SQL QUERY STRING', + 'invalid_created_by') + assert 'Error executing query:' in caplog.text + assert invalid_cohort is None + + +def test_create_cohort_sqlalchemy_error(monkeypatch, fresh_bias_obj): + # Mock omop_db methods + class MockOmopDB: + def get_session(self): + return self # not used after error + + def execute_query(self, query): + raise SQLAlchemyError("Mocked SQLAlchemy error") + + def close(self): + pass + + class MockBiasDB: + def create_cohort_definition(self, *args, **kwargs): + pass + + def create_cohort_in_bulk(self, *args, **kwargs): + pass + + def close(self): + pass + + fresh_bias_obj.omop_cdm_db = MockOmopDB() + fresh_bias_obj.bias_db = MockBiasDB() + + result = fresh_bias_obj.create_cohort("test", "desc", "SELECT * FROM person", "test_user") + + assert result is None + + +def test_cohort_creation_negative_instance(test_db): + """ + Test cohort creation with negative event_instance (last occurrence of a condition). 
+ """ + bias = test_db + cohort = bias.create_cohort( + "Diabetes patients (last occurrence)", + "Cohort of female patients born 1970-2000 with the last Type 2 diabetes diagnosis", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_negative_instance.yaml'), + "test_user" + ) + + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.data is not None, "Cohort creation returned None data" + + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5, 6, 7 + assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) + + # Verify dates for a specific patient (e.g., patient 1 with last diabetes diagnosis) + patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) + assert_equal(patient_1['cohort_start_date'], datetime.date(2020, 6, 1), + "Incorrect cohort_start_date for patient 1 (last diabetes)") + assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 6, 1), + "Incorrect cohort_end_date for patient 1 (last diabetes)") + + +def test_cohort_creation_offset(test_db): + """ + Test cohort creation with non-zero positive and negative offsets. 
+ """ + bias = test_db + cohort = bias.create_cohort( + "Diabetes patients with offset", + "Cohort of female patients born 1970-2000 with Type 2 diabetes diagnosis, adjusted by +180 and -730 days", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_offset.yaml'), + "test_user" + ) + + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 6) # Female patients 1, 2, 3, 5, 6, 7 + assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) + + # Verify dates for a specific patient (e.g., patient 1 with offset) + patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) + # Diabetes on 2020-06-01: -730 days = 2018-06-02, +180 days = 2020-11-28 + assert_equal(patient_1['cohort_start_date'], datetime.date(2018, 6, 2), + "Incorrect cohort_start_date for patient 1 (with -730 day offset)") + assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 11, 28), + "Incorrect cohort_end_date for patient 1 (with +180 day offset)") + + +def test_cohort_creation_negative_instance_offset(test_db): + """ + Test cohort creation with negative event_instance and non-zero offset. 
+ """ + bias = test_db + cohort = bias.create_cohort( + "Diabetes patients (last occurrence with offset)", + "Cohort of female patients born 1970-2000 with the last Type 2 diabetes diagnosis, adjusted by +180 days", + os.path.join(os.path.dirname(__file__), '..', 'assets', 'cohort_creation', + 'test_cohort_creation_negative_instance_offset.yaml'), + "test_user" + ) + + # Test cohort object and methods + assert cohort is not None, "Cohort creation failed" + assert cohort.metadata is not None, "Cohort creation wrongly returned None metadata" + assert 'creation_info' in cohort.metadata, "Cohort creation does not contain 'creation_info' key" + assert cohort.data is not None, "Cohort creation wrongly returned None data" + + patient_ids = set([item['subject_id'] for item in cohort.data]) + assert_equal(len(patient_ids), 6) + assert_equal(patient_ids, {1, 2, 3, 5, 6, 7}) + + # Verify dates for a specific patient (e.g., patient 1 with last diabetes and offset) + patient_1 = next(item for item in cohort.data if item['subject_id'] == 1) + # Last diabetes on 2020-06-01: +180 days = 2020-11-28 + assert_equal(patient_1['cohort_start_date'], datetime.date(2020, 6, 1), + "Incorrect cohort_start_date for patient 1 (last diabetes)") + assert_equal(patient_1['cohort_end_date'], datetime.date(2020, 11, 28), + "Incorrect cohort_end_date for patient 1 (last diabetes with +180 day offset)")