From 311d9d16abd6f108f75f2053c7819875f96d44b8 Mon Sep 17 00:00:00 2001
From: John Wu
Date: Sun, 5 May 2024 15:29:07 -0500
Subject: [PATCH 1/2] mimic4_umse for DL4H

---
 pyhealth/datasets/mimic4_umse.py | 530 +++++++++++++++++++++++++++++++
 1 file changed, 530 insertions(+)
 create mode 100644 pyhealth/datasets/mimic4_umse.py

diff --git a/pyhealth/datasets/mimic4_umse.py b/pyhealth/datasets/mimic4_umse.py
new file mode 100644
index 00000000..d4deb3d3
--- /dev/null
+++ b/pyhealth/datasets/mimic4_umse.py
@@ -0,0 +1,530 @@
+from pandarallel import pandarallel
+pandarallel.initialize(progress_bar=True, nb_workers=8)
+import logging
+import re
+import pandas as pd
+import os
+import pickle
+import matplotlib.pyplot as plt
+from PIL import Image
+from pyhealth.datasets import BaseEHRDataset
+from pyhealth.datasets.utils import strptime
+from tqdm import tqdm
+from datetime import datetime
+from torchvision import transforms
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
+
+def get_img_path(images_path, dicom_id):
+    img_path = f"{dicom_id}.jpg"
+    img_path = os.path.join(images_path, img_path)
+    return img_path
+
+def get_img(path, transform):
+    img = Image.open(path)
+    img_tensor = transform(img)
+    return img_tensor
+
+def get_section(text, section_header="Past Medical History"):
+    # match everything after the header, up to the next "Some Header:" line
+    pattern = re.escape(section_header) + r"(.*?)(?=\n[A-Za-z ]+:|$)"
+    match = re.search(pattern, text, flags=re.DOTALL)
+    if match is None:
+        print(f"Section '{section_header}:' not found.")
+        return ""
+    # drop the leading colon that follows the header
+    return match.group(1)[1:]
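+# Illustrative sketch of get_section on a hypothetical note (values made up):
+#
+#   >>> note = "Past Medical History: HTN, CHF\nMedications on Admission: ASA"
+#   >>> get_section(note, "Past Medical History")
+#   ' HTN, CHF'
+#
+# The lookahead stops the match at the next "Some Header:" line, and the
+# leading colon after the header is stripped.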
+
+
+# UMSE explicitly models the following modalities:
+# 1. continuous time-series vital-sign and lab-measurement data
+# 2. clinical notes
+# 3. x-ray images
+# Each modality is stored as a list of EventUMSE (time, feature_type, value).
+
+
+# Set of all observed tuples of a single data type (e.g. lab measurements,
+# vital signs, x-rays) for a single patient.
+# patient_id: id of the patient the observations belong to
+# data_type: type of data ("lab", "chart", "note", or "xray")
+# observables: a list of EventUMSE
+class SetUMSE:
+    def __init__(self, patient_id, data_type, observables):
+        self.patient_id = patient_id
+        self.data_type = data_type
+        self.observables = observables
+
+# EventUMSE = namedtuple('EventUMSE', ['time', 'value', 'feature_type'])
+class EventUMSE:
+    def __init__(self, time, feature_type, value):
+        self.time = time
+        self.feature_type = feature_type
+        self.value = value
+
+class PatientUMSE:
+    def __init__(self, patient_id: str,
+                 notes: SetUMSE = None,
+                 lab: SetUMSE = None,
+                 chart: SetUMSE = None,
+                 x_rays: SetUMSE = None,
+                 birth_datetime: Optional[datetime] = None,
+                 death_datetime: Optional[datetime] = None,
+                 initial_admittime: Optional[datetime] = None,
+                 final_discharge_time: Optional[datetime] = None,
+                 gender=None,
+                 ethnicity=None,
+                 age=None,
+                 outcome_events=None):
+
+        self.patient_id = patient_id
+        self.birth_datetime = birth_datetime
+        self.death_datetime = death_datetime
+        self.admittime = initial_admittime
+        self.discharge_datetime = final_discharge_time
+        self.gender = gender
+        self.ethnicity = ethnicity
+        self.notes = notes
+        self.lab = lab
+        self.chart = chart
+        self.x_rays = x_rays
+        self.age = age
+        self.outcome_events = outcome_events
+        self.logger = logging.getLogger(__name__)
+
+    def info(self):
+        print(f"Patient ID: {self.patient_id}")
+        print(f"Birth Date: {self.birth_datetime}")
+        print(f"Death Date: {self.death_datetime}")
+        print(f"Age: {self.age}")
+        print(f"Gender: {self.gender}")
+        print(f"Ethnicity: {self.ethnicity}")
+        print("First Admittime:", self.admittime)
+        print("Final Discharge Time:", self.discharge_datetime)
+        print("Total Number of Notes:", len(self.notes.observables))
+        print("Total Number of X-rays:", len(self.x_rays.observables))
+        print("Total Number of Lab Measurements:", len(self.lab.observables))
+        print("Total Number of Chart Measurements:", len(self.chart.observables))
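+
+# Minimal sketch of how the three containers fit together (hypothetical values):
+#
+#   >>> e = EventUMSE(time=datetime(2130, 1, 1), feature_type="50912", value=1.2)
+#   >>> s = SetUMSE(patient_id="123", data_type="lab", observables=[e])
+#   >>> p = PatientUMSE(patient_id="123", lab=s)
+#   >>> len(p.lab.observables)
+#   1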
+
+# Extraction targets.
+# Time series:
+#   lab events: Hematocrit, Platelet, WBC, Bilirubin, pH, bicarbonate,
+#   Creatinine, Lactate, Potassium, and Sodium;
+#   chart events: heart rate, respiratory rate, diastolic and systolic
+#   blood pressure, temperature, and pulse oximetry.
+# Note sections to extract:
+#   Past Medical History, Medications on Admission, and Chief Medical
+#   Complaint (which may or may not exist in a given note).
+class MIMIC4UMSE(BaseEHRDataset):
+
+    def dev_mode(self, df):
+        if self.dev:
+            unique_patients = df['subject_id'].unique()
+            limited_patients = unique_patients[:self.dev_patients]
+            limited_df = df[df['subject_id'].isin(limited_patients)]
+            return limited_df
+        else:
+            return df
+
+    def get_item_ids(self, item_names, item_df):
+        # lower-case the labels once, replacing NA/NaN with an empty string
+        item_df['label'] = item_df['label'].str.lower().fillna('')
+        item_set = set()
+        for specific_label in item_names:
+            if specific_label.lower() in ["ph"]:
+                # exact match for very short labels such as "pH"
+                matching_ids = item_df[item_df["label"] == specific_label.lower()]['itemid'].to_list()
+            else:
+                # case-insensitive substring match
+                matching_ids = item_df[item_df["label"].str.contains(specific_label.lower())]['itemid'].to_list()
+            item_set = item_set.union(set(matching_ids))
+        return item_set
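+
+    # Matching sketch (hypothetical d_labitems rows): labels are matched by
+    # case-insensitive substring, except "pH", which is matched exactly so it
+    # does not pull in every label containing the letters "ph":
+    #
+    #   >>> items = pd.DataFrame({"itemid": ["1", "2"], "label": ["Platelet Count", "pH"]})
+    #   >>> ds.get_item_ids(["Platelet", "pH"], items)  # ds: a MIMIC4UMSE instance
+    #   {'1', '2'}   # (set order may vary)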
+    def __init__(
+        self,
+        root: str,
+        cxr_root: str,
+        note_root: str,
+        dataset_name: Optional[str] = None,
+        note_sections=["Past Medical History", "Medications on Admission", "Chief Medical Complaint"],
+        lab_events=["Hematocrit", "Platelet", "WBC", "Bilirubin", "pH", "bicarbonate", "Creatinine", "Lactate", "Potassium", "Sodium"],
+        # "pulseoxymetry" is spelled to match the label in d_items.csv
+        chart_events=["Heart Rate", "respiratory rate", "blood pressure", "temperature", "pulseoxymetry"],
+        outcome_events=["mortality", "intubation", "vasopressor", "icd"],
+        dev: bool = False,
+        use_parquet: bool = False,
+        use_relative_time=False,
+        time_unit="day",
+        exclude_negative_time=False,  # set True to exclude events recorded before the ED or ICU admission
+        concatenate_notes=True,  # True merges all note sections into one event; False keeps one event per section
+        dev_patients: int = 1000,  # number of patients to use in dev mode
+        **kwargs,
+    ):
+        if dataset_name is None:
+            dataset_name = self.__class__.__name__
+        self.root = root
+        self.cxr_root = cxr_root
+        self.note_root = note_root
+        self.hosp_path = os.path.join(self.root, "hosp")
+        self.icu_path = os.path.join(self.root, "icu")
+        self.dataset_name = dataset_name
+        self.logger = logging.getLogger(__name__)
+
+        # items to extract
+        self.note_sections = note_sections
+        self.lab_events = lab_events
+        self.chart_events = chart_events
+        self.outcome_events = outcome_events
+
+        # dataset processing details
+        self.image_transform = transforms.Compose([transforms.ToTensor()])
+        self.dev = dev
+        self.dev_patients = dev_patients
+        self.use_parquet = use_parquet
+        self.time_unit = time_unit
+        self.exclude_negative_time = exclude_negative_time
+        self.concatenate_notes = concatenate_notes
+
+        # read lab and chart event table mappings
+        lab_event_ids_df = pd.read_csv(os.path.join(self.hosp_path, "d_labitems.csv"), dtype={"itemid": str})
+        chart_event_ids_df = pd.read_csv(os.path.join(self.icu_path, "d_items.csv"), dtype={"itemid": str})
+
+        # sets of lab and chart event ids whose measurements we want to keep
+        self.lab_event_ids = self.get_item_ids(lab_events, lab_event_ids_df)
+        self.chart_event_ids = self.get_item_ids(chart_events, chart_event_ids_df)
+
+        # convert from id to label
+        self.to_lab_event_names = lab_event_ids_df.set_index("itemid").to_dict()["label"]
+        self.to_chart_event_names = chart_event_ids_df.set_index("itemid").to_dict()["label"]
+
+        self.logger.debug(f"Processing {self.dataset_name} base dataset...")
+
+        self.patients = self.process(**kwargs)
+        if use_relative_time:
+            self.patients = self.set_patient_occurence_time(self.time_unit)
+
+        # + 1 for the x-rays
+        self.num_feature_types = len(self.lab_event_ids) + len(self.chart_event_ids) + len(self.note_sections) + 1
+
+    def process(self, **kwargs) -> Dict[str, PatientUMSE]:
+        patients = dict()
+
+        # load patient info, then attach each modality
+        patients = self.parse_basic_info(patients)
+        patients = self.parse_lab_events(patients)
+        patients = self.parse_chart_events(patients)
+        patients = self.parse_notes(patients)
+        patients = self.parse_xrays(patients)
+
+        return patients
+
+    def add_observations_to_patients(self, patients: Dict[str, PatientUMSE], modality, df) -> Dict[str, PatientUMSE]:
+        for pid, set_umse in df.items():
+            assert pid == set_umse.patient_id
+            patient_id = set_umse.patient_id
+            if patient_id in patients:
+                if modality == "lab":
+                    patients[patient_id].lab = set_umse
+                elif modality == "chart":
+                    patients[patient_id].chart = set_umse
+                elif modality == "note":
+                    patients[patient_id].notes = set_umse
+                elif modality == "xray":
+                    patients[patient_id].x_rays = set_umse
+                else:
+                    raise AssertionError("Modality not recognized!")
+
+        return patients
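+    # With use_relative_time=True, every observation time t is rewritten as
+    # (t - admittime) expressed in `time_unit`s by set_patient_occurence_time;
+    # e.g. an x-ray taken 36 hours after admission gets time 1.5 with
+    # time_unit="day".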
+    def parse_basic_info(self, patients: Dict[str, PatientUMSE]) -> Dict[str, PatientUMSE]:
+        def process_patient(self, pid, p_info):
+            anchor_year = int(p_info["anchor_year"].values[0])
+            anchor_age = int(p_info["anchor_age"].values[0])
+            birth_year = anchor_year - anchor_age
+            admit_date = strptime(p_info["admittime"].values[0])  # first admission
+            discharge_date = strptime(p_info["dischtime"].values[-1])  # final discharge
+            p_outcome_events = {}
+
+            if "mortality" in self.outcome_events:
+                # label is 1 if the patient died in hospital, else 0
+                expired = int(1 in p_info["hospital_expire_flag"].values)
+                p_outcome_events["mortality"] = expired
+
+            # load observables
+            patient = PatientUMSE(
+                patient_id=pid,
+                notes=SetUMSE(pid, "note", []),
+                lab=SetUMSE(pid, "lab", []),
+                chart=SetUMSE(pid, "chart", []),
+                x_rays=SetUMSE(pid, "xray", []),
+                # no exact month, day, and time; use Jan 1st, 00:00:00
+                birth_datetime=strptime(str(birth_year)),
+                # no exact time; use 00:00:00
+                death_datetime=strptime(p_info["dod"].values[0]),
+                gender=p_info["gender"].values[0],
+                ethnicity=p_info["race"].values[0],
+                initial_admittime=admit_date,
+                final_discharge_time=discharge_date,
+                age=anchor_age,
+                outcome_events=p_outcome_events
+            )
+
+            return patient
+
+        print("Reading Patients and Admissions!")
+        patients_df = pd.read_csv(
+            os.path.join(self.hosp_path, "patients.csv"),
+            dtype={"subject_id": str},
+            nrows=self.dev_patients if self.dev else None
+        )
+        print("Total Number of Patient Records:", len(patients_df))
+        admissions_df = pd.read_csv(
+            os.path.join(self.hosp_path, "admissions.csv"),
+            dtype={"subject_id": str, "hadm_id": str}
+        )
+
+        # merge the two DataFrames
+        df = pd.merge(patients_df, admissions_df, on="subject_id", how="inner")
+        df = df.dropna(subset=["subject_id", "admittime", "dischtime"])
+        # sort by admission and discharge time
+        df = df.sort_values(["subject_id", "admittime", "dischtime"], ascending=True)
+
+        # group by patient
+        df_group = df.groupby("subject_id")
+        print("Parsing Basic Info!")
+        df_group = df_group.parallel_apply(
+            lambda x: process_patient(self, x.subject_id.unique()[0], x)
+        )
+
+        for pid, pat in df_group.items():
+            patients[pid] = pat
+
+        return patients
+
+    # Get all specified lab events for each patient
+    def parse_lab_events(self, patients: Dict[str, PatientUMSE]) -> Dict[str, PatientUMSE]:
+
+        def parse_lab_event(self, pid, p_info):
+            lab_measurements = []
+            for idx, row in p_info.iterrows():
+                # keep only measurements whose itemid is in the requested lab set
+                if str(row['itemid']) in self.lab_event_ids:
+                    # charttime is parsed to a datetime inside strptime
+                    charttime_datetime = strptime(row['charttime'])
+                    event = EventUMSE(time=charttime_datetime,
+                                      feature_type=str(row['itemid']),
+                                      value=row['valuenum'])
+                    lab_measurements.append(event)
+
+            lab_measurements = SetUMSE(pid, "lab", lab_measurements)
+            return lab_measurements
+
+        # read lab event data
+        print("Reading Lab Events!")
+        lab_events_df = None
+        if self.use_parquet:
+            path = os.path.join(self.hosp_path, "labevents.parquet")
+            if os.path.exists(path):
+                print("Loading Existing Parquet File!")
+                lab_events_df = pd.read_parquet(path)
+            else:
+                print("Creating New Parquet File!")
+                lab_events_df = pd.read_csv(os.path.join(self.hosp_path, "labevents.csv"))
+                lab_events_df.to_parquet(path, index=False)
+        else:
+            lab_events_df = pd.read_csv(os.path.join(self.hosp_path, "labevents.csv"))
+
+        print(f"Read {len(lab_events_df)} lab events.")
+        print("Parsing Lab Events!")
+        lab_events_df = lab_events_df.dropna(subset=["subject_id", "itemid", "valuenum", "charttime"])
+        lab_events_df = lab_events_df.sort_values(["subject_id", "itemid", "charttime"], ascending=True)
+        lab_events_df['subject_id'] = lab_events_df['subject_id'].astype(str)
+        lab_events_df = self.dev_mode(lab_events_df)
+        lab_events_df = lab_events_df.groupby("subject_id")
+        lab_events_df = lab_events_df.parallel_apply(lambda x: parse_lab_event(self, x.subject_id.unique()[0], x))
+        patients = self.add_observations_to_patients(patients, "lab", lab_events_df)
+        return patients
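+    # Each kept labevents row becomes one EventUMSE; e.g. a hypothetical row
+    # (itemid=50912, charttime="2130-01-01 08:00:00", valuenum=1.2) turns into
+    # EventUMSE(time=datetime(2130, 1, 1, 8), feature_type="50912", value=1.2).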
"itemid", "charttime"], ascending=True) + icu_chart_events_df['subject_id'] = icu_chart_events_df['subject_id'].astype(str) + icu_chart_events_df = self.dev_mode(icu_chart_events_df) + icu_chart_events_df = icu_chart_events_df.groupby("subject_id") + icu_chart_events_df = icu_chart_events_df.parallel_apply(lambda x: parse_chart_event(self, x.subject_id.unique()[0], x)) + patients = self.add_observations_to_patients(patients, "chart", icu_chart_events_df) + return patients + + def parse_notes(self, patients : Dict[str, PatientUMSE]) -> Dict[str, PatientUMSE]: + def parse_note(self, pid, p_info): + notes = [] + for idx, row in p_info.iterrows(): + # want (feature type, time, text) + text = row['text'] + if self.note_sections[0] == "all": + event = EventUMSE(time=row['charttime'], feature_type="note", value=text) + notes.append(event) + else: + if self.concatenate_notes: + combined_text = " " + for section in self.note_sections: + if section in text: + combined_text += "" + get_section(text.lower(), section.lower()) + event = EventUMSE(time=row['charttime'], feature_type="note", value=combined_text) + notes.append(event) + else: + for section in self.note_sections: + if section in text: + event = EventUMSE(time=row['charttime'], feature_type=section, value=get_section(text.lower(), section.lower())) + notes.append(event) + + notes = SetUMSE(pid, "note", notes) + return notes + + # Read Note Data + print("Reading Note Data!") + note_df = None + if self.use_parquet: + path = os.path.join(self.note_root, "discharge.parquet") + if os.path.exists(path): + print("Loading Existing Parquet Path!") + note_df = pd.read_parquet(path) + else: + print("Creating New Parquet File!") + note_df = pd.read_csv(os.path.join(self.note_root, "discharge.csv")) + note_df.to_parquet(path, index=False) + else: + note_df = pd.read_csv(os.path.join(self.note_root, "discharge.csv")) + note_df = note_df.dropna(subset=["subject_id", "text", "charttime"]) + print(f"Read {len(note_df)} note events.") + note_df = note_df.sort_values(["subject_id", "charttime"], ascending=True) + + note_df['subject_id'] = note_df['subject_id'].astype(str) + note_df = self.dev_mode(note_df) + note_df = note_df.groupby("subject_id") + print("Parsing Notes!") + note_df = note_df.parallel_apply(lambda x: parse_note(self, x.subject_id.unique()[0], x)) + + patients = self.add_observations_to_patients(patients, "note", note_df) + return patients + + def parse_xrays(self, patients : Dict[str, PatientUMSE]) -> Dict[str, PatientUMSE]: + def process_xray(self, pid, p_info): + xrays = [] + for idx, row in p_info.iterrows(): + # want ("xray", time, image) + dicom_id = row['dicom_id'] + image_path = get_img_path(os.path.join(self.cxr_root, "images"), dicom_id) + event = EventUMSE(time=row['StudyDateTime'], feature_type="xray", value=image_path) + xrays.append(event) + xrays = SetUMSE(pid, "xray", xrays) + return xrays + # read mimic-cxr metadata + print("Reading CXR metadata!") + cxr_jpg_meta_df = pd.read_csv(os.path.join(self.cxr_root, "mimic-cxr-2.0.0-metadata.csv")) + cxr_jpg_meta_df.StudyDate = cxr_jpg_meta_df.StudyDate.astype(str) + cxr_jpg_meta_df.StudyTime = cxr_jpg_meta_df.StudyTime.astype(str).str.split(".").str[0] + cxr_jpg_meta_df["StudyDateTime"] = pd.to_datetime(cxr_jpg_meta_df.StudyDate + cxr_jpg_meta_df.StudyTime, + format="%Y%m%d%H%M%S", + errors="coerce") + + cxr_df = cxr_jpg_meta_df[cxr_jpg_meta_df.StudyDateTime.isna()] + cxr_df = cxr_jpg_meta_df[["subject_id", "study_id", "dicom_id", "StudyDateTime"]] + cxr_df = 
cxr_df.dropna(subset=["subject_id", "dicom_id", "StudyDateTime"]) + cxr_df = cxr_df.sort_values(["subject_id", "StudyDateTime"], ascending=True) + print(f"Read {len(cxr_df)} x-ray events.") + cxr_df['subject_id'] = cxr_df['subject_id'].astype(str) + cxr_df = self.dev_mode(cxr_df) + cxr_df = cxr_df.groupby("subject_id") + print("Parsing X-rays!") + cxr_df = cxr_df.parallel_apply(lambda x: process_xray(self, x.subject_id.unique()[0], x)) + patients = self.add_observations_to_patients(patients, "xray", cxr_df) + + return patients + + def unit_factor(self, difference, unit): + if unit == 'hour': + return difference / 3600 + elif unit == 'day': + return difference / (3600 * 24) + elif unit == 'minute': + return difference / 60 + elif unit == 'second': + return difference + else: + raise ValueError("Unit not recognized") + + # t_occurence = t - t_admit + # t_current = t_discharge - t + def set_patient_occurence_time(self, unit='day'): + for patient_id, patient in tqdm(self.patients.items(), desc="Setting all Charttimes to Relative Time from Admittime"): + + if patient.notes: + for note in patient.notes.observables: + note_time = datetime.strptime(note.time, "%Y-%m-%d %H:%M:%S") if isinstance(note.time, str) else note.time + note.time = self.unit_factor((note_time - patient.admittime).total_seconds(), unit) + + if patient.x_rays: + for xray in patient.x_rays.observables: + xray_time = datetime.strptime(xray.time, "%Y-%m-%d %H:%M:%S") if isinstance(xray.time, str) else xray.time + xray.time = self.unit_factor((xray_time - patient.admittime).total_seconds(), unit) + + if patient.lab: + for lab_event in patient.lab.observables: + lab_event_time = datetime.strptime(lab_event.time, "%Y-%m-%d %H:%M:%S") if isinstance(lab_event.time, str) else lab_event.time + lab_event.time = self.unit_factor((lab_event_time - patient.admittime).total_seconds(), unit) + + if patient.chart: + for chart_event in patient.chart.observables: + chart_event_time = datetime.strptime(chart_event.time, "%Y-%m-%d %H:%M:%S") if isinstance(chart_event.time, str) else chart_event.time + chart_event.time = self.unit_factor((chart_event_time - patient.admittime).total_seconds(), unit) + + return self.patients + + + +def save_to_pkl(obj, path): + with open(path, 'wb') as f: + pickle.dump(obj, f) + + + +if __name__ == "__main__": + mimic_cxr_path = "/home/johnwu3/projects/serv4/Medical_Tri_Modal_Pilot/data/physionet.org/files/MIMIC-CXR" + mimic_cxr_jpg_path = "/home/johnwu3/projects/serv4/Medical_Tri_Modal_Pilot/data/physionet.org/files/MIMIC-CXR" + mimic_iv_path = "/home/johnwu3/projects/serv4/Medical_Tri_Modal_Pilot/data/physionet.org/files/MIMIC-IV/2.0/" + mimic_note_directory = "/home/johnwu3/projects/serv4/Medical_Tri_Modal_Pilot/data/physionet.org/files/mimic-iv-note/2.2/note" + dataset = MIMIC4UMSE(root=mimic_iv_path, + cxr_root=mimic_cxr_jpg_path, + note_root=mimic_note_directory, + dev=True, + dev_patients=2000, + use_parquet=True, + use_relative_time=True) \ No newline at end of file From 97a262ebdad1e4ab09f08d329d406d11af748324 Mon Sep 17 00:00:00 2001 From: John Wu Date: Thu, 17 Oct 2024 16:11:01 -0500 Subject: [PATCH 2/2] Attempts at task template and sample dataset --- pyhealth/data/__init__.py | 6 + pyhealth/data/cache.py | 126 +++++++ pyhealth/data/data.py | 2 +- pyhealth/data/data_v2.py | 198 +++++++++++ pyhealth/datasets/base_dataset_v2.py | 144 ++++++++ pyhealth/datasets/base_ehr_dataset.py | 2 +- pyhealth/datasets/eicu.py | 4 +- pyhealth/datasets/mimic3.py | 389 ++++++++++++++++++--- pyhealth/datasets/mimic4.py | 409 
+++++++++++++++++++++-- pyhealth/datasets/sample_dataset.py | 1 - pyhealth/datasets/sample_dataset_v2.py | 140 ++++++++ pyhealth/featurizers/__init__.py | 2 + pyhealth/featurizers/image.py | 22 ++ pyhealth/featurizers/value.py | 14 + pyhealth/tasks/drug_recommendation.py | 2 +- pyhealth/tasks/medical_coding.py | 176 ++++++++++ pyhealth/tasks/mortality_prediction.py | 2 +- pyhealth/tasks/readmission_prediction.py | 2 +- pyhealth/tasks/task_template.py | 68 ++++ test.py | 17 + 20 files changed, 1648 insertions(+), 78 deletions(-) create mode 100644 pyhealth/data/cache.py create mode 100644 pyhealth/data/data_v2.py create mode 100644 pyhealth/datasets/base_dataset_v2.py create mode 100644 pyhealth/datasets/sample_dataset_v2.py create mode 100644 pyhealth/featurizers/__init__.py create mode 100644 pyhealth/featurizers/image.py create mode 100644 pyhealth/featurizers/value.py create mode 100644 pyhealth/tasks/medical_coding.py create mode 100644 pyhealth/tasks/task_template.py create mode 100644 test.py diff --git a/pyhealth/data/__init__.py b/pyhealth/data/__init__.py index cae6153d..92361d68 100755 --- a/pyhealth/data/__init__.py +++ b/pyhealth/data/__init__.py @@ -3,3 +3,9 @@ Visit, Patient, ) + + +# from .data_v2 import ( +# Event, +# Patient, +# ) \ No newline at end of file diff --git a/pyhealth/data/cache.py b/pyhealth/data/cache.py new file mode 100644 index 00000000..967e956e --- /dev/null +++ b/pyhealth/data/cache.py @@ -0,0 +1,126 @@ +import msgpack +import os +from typing import Dict, List, Any +from pyhealth.data.data_v2 import Patient, Event +from datetime import datetime +def patient_default(obj): + if isinstance(obj, Patient): + return { + "__patient__": True, + "patient_id": obj.patient_id, + "birth_datetime": obj.birth_datetime.isoformat() if obj.birth_datetime else None, + "death_datetime": obj.death_datetime.isoformat() if obj.death_datetime else None, + "gender": obj.gender, + "ethnicity": obj.ethnicity, + "attr_dict": obj.attr_dict, + "events": [event_default(event) for event in obj.events] + } + elif isinstance(obj, datetime): + return obj.isoformat() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + +def event_default(obj): + if isinstance(obj, Event): + return { + "__event__": True, + "code": obj.code, + "table": obj.table, + "vocabulary": obj.vocabulary, + "visit_id": obj.visit_id, + "patient_id": obj.patient_id, + "timestamp": obj.timestamp.isoformat() if obj.timestamp else None, + "item_id": obj.item_id, + "attr_dict": obj.attr_dict + } + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + +def patient_hook(obj): + if "__patient__" in obj: + patient = Patient( + patient_id=obj["patient_id"], + birth_datetime=datetime.fromisoformat(obj["birth_datetime"]) if obj["birth_datetime"] else None, + death_datetime=datetime.fromisoformat(obj["death_datetime"]) if obj["death_datetime"] else None, + gender=obj["gender"], + ethnicity=obj["ethnicity"] + ) + patient.attr_dict = obj["attr_dict"] + for event_data in obj["events"]: + patient.add_event(event_hook(event_data)) + return patient + return obj + +def event_hook(obj): + if "__event__" in obj: + event = Event( + code=obj["code"], + table=obj["table"], + vocabulary=obj["vocabulary"], + visit_id=obj["visit_id"], + patient_id=obj["patient_id"], + timestamp=datetime.fromisoformat(obj["timestamp"]) if obj["timestamp"] else None, + item_id=obj["item_id"] + ) + event.attr_dict = obj["attr_dict"] + return event + return obj + +def write_msgpack_patients(data: 
Dict[str, Patient], filepath: str): + """ + Write a dictionary of Patient objects to a MessagePack file. + + Args: + data (Dict[str, Patient]): Dictionary with patient IDs as keys and Patient objects as values. + filepath (str): Path to the file where data will be written. + """ + with open(filepath, 'wb') as f: + msgpack.pack(data, f, default=patient_default) + +def read_msgpack_patients(filepath: str) -> Dict[str, Patient]: + """ + Read a dictionary of Patient objects from a MessagePack file. + + Args: + filepath (str): Path to the file to read from. + + Returns: + Dict[str, Patient]: Dictionary with patient IDs as keys and Patient objects as values. + """ + with open(filepath, 'rb') as f: + data = msgpack.unpack(f, object_hook=patient_hook) + return data + + +def write_msgpack(data: Dict[str, Any], filepath: str) -> None: + """ + Write a dictionary to a MessagePack file. + + Args: + data (Dict[str, Any]): The dictionary to be written. + filepath (str): The path to the file where data will be written. + """ + # Ensure the directory exists + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + # Open the file in binary write mode + with open(filepath, "wb") as f: + # Pack and write the data + packed = msgpack.packb(data, use_bin_type=True) + f.write(packed) + +def read_msgpack(filepath: str) -> Dict[str, Any]: + """ + Read a dictionary from a MessagePack file. + + Args: + filepath (str): The path to the file to be read. + + Returns: + Dict[str, Any]: The dictionary read from the file. + """ + # Open the file in binary read mode + with open(filepath, "rb") as f: + # Read and unpack the data + packed = f.read() + unpacked = msgpack.unpackb(packed, raw=False) + + return unpacked \ No newline at end of file diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index a8d9090e..464b6815 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -453,4 +453,4 @@ def __str__(self): for visit in self: visit_str = str(visit).replace("\n", "\n\t") lines.append(f"\t- {visit_str}") - return "\n".join(lines) + return "\n".join(lines) \ No newline at end of file diff --git a/pyhealth/data/data_v2.py b/pyhealth/data/data_v2.py new file mode 100644 index 00000000..b06012ca --- /dev/null +++ b/pyhealth/data/data_v2.py @@ -0,0 +1,198 @@ +from collections import OrderedDict +from datetime import datetime +from typing import Optional, List + +class Event: + """Contains information about a single event. + + An event can be anything from a diagnosis to a prescription or a lab test + that happened in a visit of a patient at a specific time. + + Args: + code: code of the event. E.g., "428.0" for congestive heart failure. + table: name of the table where the event is recorded. This corresponds + to the raw csv file name in the dataset. E.g., "DIAGNOSES_ICD". + vocabulary: vocabulary of the code. E.g., "ICD9CM" for ICD-9 diagnosis codes. + visit_id: unique identifier of the visit. + patient_id: unique identifier of the patient. + timestamp: timestamp of the event. Default is None. + event_type: type of the event. + item_id: unique identifier of the item. + **attr: optional attributes to add to the event as key=value pairs. + + Attributes: + attr_dict: Dict, dictionary of event attributes. Each key is an attribute + name and each value is the attribute's value. 
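+
+    Example (illustrative values only):
+        >>> event = Event(code="428.0", table="DIAGNOSES_ICD", vocabulary="ICD9CM",
+        ...               visit_id="v001", patient_id="p001", severity="high")
+        >>> event.attr_dict
+        {'severity': 'high'}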
+ """ + + def __init__( + self, + code: str = None, + table: str = None, + vocabulary: str = None, + visit_id: str = None, + patient_id: str = None, + timestamp: Optional[datetime] = None, + item_id: str = None, + **attr, + ): + assert timestamp is None or isinstance( + timestamp, datetime + ), "timestamp must be a datetime object" + self.code = code + self.table = table + self.vocabulary = vocabulary + self.visit_id = visit_id # we can remove the explicity of it if need be. + self.patient_id = patient_id + self.timestamp = timestamp + self.item_id = item_id + self.attr_dict = dict() # Event(...,visit_id=), don't make it explicit + self.attr_dict.update(attr) + + def __repr__(self): + return f"Event with {self.vocabulary} code {self.code} from table {self.table}" + + def __str__(self): + lines = list() + lines.append(f"Event from patient {self.patient_id} visit {self.visit_id}:") + lines.append(f"\t- Code: {self.code}") + lines.append(f"\t- Table: {self.table}") + lines.append(f"\t- Vocabulary: {self.vocabulary}") + lines.append(f"\t- Timestamp: {self.timestamp}") + lines.append(f"\t- Item ID: {self.item_id}") + for k, v in self.attr_dict.items(): + lines.append(f"\t- {k}: {v}") + return "\n".join(lines) + +class Patient: + """Contains information about a single patient. + + A patient is a person who is admitted at least once to a hospital or + a specific department. Each patient is associated with a list of events. + + Args: + patient_id: unique identifier of the patient. + birth_datetime: timestamp of patient's birth. Default is None. + death_datetime: timestamp of patient's death. Default is None. + gender: gender of the patient. Default is None. + ethnicity: ethnicity of the patient. Default is None. + **attr: optional attributes to add to the patient as key=value pairs. + + Attributes: + attr_dict: Dict, dictionary of patient attributes. Each key is an attribute + name and each value is the attribute's value. + events: Dict[str, List[Event]], dictionary of event lists. + Each key is a table name and each value is a list of events from that table. + """ + + def __init__( + self, + patient_id: str, + birth_datetime: Optional[datetime] = None, + death_datetime: Optional[datetime] = None, + gender=None, + ethnicity=None, + **attr, + ): + self.patient_id = patient_id + self.birth_datetime = birth_datetime + self.death_datetime = death_datetime + self.gender = gender + self.ethnicity = ethnicity + self.attr_dict = dict() + self.attr_dict.update(attr) + self.events = [] # Nested Dataframe -> PyArrow? + + def add_event(self, event: Event) -> None: + """Adds an event to the patient. + + If the event's table is not in the patient's event dictionary, it is + added as a new key. The event is then added to the list of events of + that table. + + Args: + event: event to add. + """ + assert event.patient_id == self.patient_id, "patient_id unmatched" + # table = event.table + # if table not in self.events: + # self.events[table] = list() + # self.events[table].append(event) + self.events.append(event) # + + def get_event_list(self, table: str) -> List[Event]: + """Returns a list of events from a specific table. + + If the table is not in the patient's event dictionary, an empty list + is returned. + + Args: + table: name of the table. + + Returns: + List of events from the specified table. 
+ """ + return [event for event in self.events if event.table == table] + + def get_code_list( + self, table: str, remove_duplicate: Optional[bool] = True + ) -> List[str]: + """Returns a list of codes from a specific table. + + If the table is not in the patient's event dictionary, an empty list + is returned. + + Args: + table: name of the table. + remove_duplicate: whether to remove duplicate codes + (but keep the relative order). Default is True. + + Returns: + List of codes from the specified table. + """ + event_list = self.get_event_list(table) + code_list = [event.code for event in event_list] + if remove_duplicate: + # remove duplicate codes but keep the order + code_list = list(dict.fromkeys(code_list)) + return code_list + + @property + def available_tables(self) -> List[str]: + """Returns a list of available tables for the patient. + + Returns: + List of available tables. + """ + tables = set() + for event in self.events: + tables.add(event.table) + return list(tables) + + @property + def num_events(self) -> int: + """Returns the total number of events for the patient. + + Returns: + Total number of events. + """ + return len(self.events) + + def __repr__(self): + return f"Patient {self.patient_id} with {self.num_events} events" + + def __str__(self): + lines = list() + lines.append(f"Patient {self.patient_id} with {self.num_events} events:") + lines.append(f"\t- Birth datetime: {self.birth_datetime}") + lines.append(f"\t- Death datetime: {self.death_datetime}") + lines.append(f"\t- Gender: {self.gender}") + lines.append(f"\t- Ethnicity: {self.ethnicity}") + lines.append(f"\t- Available tables: {self.available_tables}") + for k, v in self.attr_dict.items(): + lines.append(f"\t- {k}: {v}") + for event in self.events: + + event_str = str(event).replace("\n", "\n\t") + lines.append(f"\t- {event_str}") + return "\n".join(lines) \ No newline at end of file diff --git a/pyhealth/datasets/base_dataset_v2.py b/pyhealth/datasets/base_dataset_v2.py new file mode 100644 index 00000000..36de696a --- /dev/null +++ b/pyhealth/datasets/base_dataset_v2.py @@ -0,0 +1,144 @@ +import logging +import time +import os +from abc import ABC, abstractmethod +from collections import Counter +from copy import deepcopy +from typing import Dict, Callable, Tuple, Union, List, Optional +# from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.utils import MODULE_CACHE_PATH, DATASET_BASIC_TABLES +from pyhealth.datasets.utils import hash_str +from pyhealth.medcode import CrossMap +from pyhealth.utils import load_pickle, save_pickle +from pyhealth.data.cache import read_msgpack, read_msgpack_patients, write_msgpack, write_msgpack_patients # better to use msgpack than pickle +logger = logging.getLogger(__name__) + +INFO_MSG = """ +dataset.patients: patient_id -> + + + - events: List[Event] + - other patient-level info + + - code: str + - other event-level info +""" + + +# TODO: parse_tables is too slow + +# Let's add our twist, because we have to define some type of tables even if there aren't any. 
+class BaseDataset(ABC):
+    """Abstract base dataset class."""
+
+    def __init__(
+        self,
+        root: str,
+        dataset_name: Optional[str] = None,
+        tables: List[str] = None,
+        additional_dirs: Optional[Dict[str, str]] = {},
+        dev: bool = False,
+        refresh_cache: bool = False,
+        **kwargs,
+    ):
+        if dataset_name is None:
+            dataset_name = self.__class__.__name__
+        self.root = root
+        self.dataset_name = dataset_name
+        logger.debug(f"Processing {self.dataset_name} base dataset...")
+        self.tables = tables
+        self.tables_dir = {table: None for table in tables}  # tables located under root
+        if additional_dirs:
+            self.tables.extend(additional_dirs.keys())
+            self.tables_dir.update(additional_dirs)
+
+        self.dev = dev
+
+        # TODO: caching can be dataset specific, since subclasses may take unique args
+        # hash filename for cache
+        self.filepath = self.get_cache_path()
+        # check if cache exists or refresh_cache is True
+        if os.path.exists(self.filepath) and (not refresh_cache):
+            # load from cache
+            logger.debug(
+                f"Loaded {self.dataset_name} base dataset from {self.filepath}"
+            )
+            try:
+                self.patients = read_msgpack_patients(self.filepath)
+            except Exception:
+                raise ValueError("Please refresh your cache by setting refresh_cache=True")
+        else:
+            # load from raw data
+            logger.debug(f"Processing {self.dataset_name} base dataset...")
+            # parse tables
+            self.patients = self.process()
+            # save to cache
+            logger.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
+            write_msgpack_patients(self.patients, self.filepath)
+
+    def __str__(self):
+        return f"Base dataset {self.dataset_name}"
+
+    def __len__(self):
+        return len(self.patients)
+
+    # Every dataset must provide both a unique cache path and a process method.
+    @abstractmethod
+    def get_cache_path(self) -> str:
+        args_to_hash = (
+            [self.dataset_name, self.root]
+            + sorted(self.tables)
+            # + sorted(self.code_mapping.items())
+            + ["dev" if self.dev else "prod"]
+            + sorted([(k, v) for k, v in self.tables_dir.items()])
+        )
+        filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".msgpack"
+        return filename
+
+    @abstractmethod
+    def process(self) -> Dict:
+        raise NotImplementedError
+
+    @abstractmethod
+    def stat(self):
+        print(f"Statistics of {self.dataset_name}:")
+        return
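+
+    # A minimal subclass sketch (hypothetical), showing the methods every
+    # concrete dataset must implement:
+    #
+    #   class MyDataset(BaseDataset):
+    #       def get_cache_path(self) -> str:
+    #           return os.path.join(MODULE_CACHE_PATH, hash_str(self.root) + ".msgpack")
+    #       def process(self) -> Dict:
+    #           return {}  # {patient_id: Patient}, parsed from the raw tables
+    #       def stat(self):
+    #           print(f"{self.dataset_name}: {len(self.patients)} patients")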
+
+    # @property
+    # def default_task(self) -> Optional[TaskTemplate]:
+    #     return None
+
+    # def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset:
+    #     """Processes the base dataset to generate the task-specific sample dataset.
+    #     """
+    #     # TODO: cache?
+    #     if task is None:
+    #         # assert default tasks exist in attr
+    #         assert self.default_task is not None, "No default tasks found"
+    #         task = self.default_task
+
+    #     # load from raw data
+    #     logger.debug(f"Setting task for {self.dataset_name} base dataset...")
+
+    #     samples = []
+    #     for patient_id, patient in tqdm(
+    #         self.patients.items(), desc=f"Generating samples for {task.task_name}"
+    #     ):
+    #         samples.extend(task(patient))
+
+    #     sample_dataset = SampleDataset(
+    #         samples,
+    #         input_schema=task.input_schema,
+    #         output_schema=task.output_schema,
+    #         dataset_name=self.dataset_name,
+    #         task_name=task,
+    #     )
+    #     return sample_dataset
+
diff --git a/pyhealth/datasets/base_ehr_dataset.py b/pyhealth/datasets/base_ehr_dataset.py
index e5760b70..99cdb6f0 100644
--- a/pyhealth/datasets/base_ehr_dataset.py
+++ b/pyhealth/datasets/base_ehr_dataset.py
@@ -78,7 +78,7 @@ def __init__(
         refresh_cache: bool = False,
     ):
         """Loads tables into a dict of patients and saves it to cache."""
-
+        # NOTE: legacy v1 pipeline; v2 datasets inherit from base_dataset_v2.BaseDataset
         if code_mapping is None:
             code_mapping = {}
diff --git a/pyhealth/datasets/eicu.py b/pyhealth/datasets/eicu.py
index 0c46e989..8976bffe 100644
--- a/pyhealth/datasets/eicu.py
+++ b/pyhealth/datasets/eicu.py
@@ -5,11 +5,11 @@
 from tqdm import tqdm
 from datetime import datetime

-from pyhealth.data import Event, Visit, Patient
+from pyhealth.data import Event, Patient
 from pyhealth.datasets import BaseEHRDataset
 from pyhealth.datasets.utils import strptime, padyear

-# TODO: add other tables
+# TODO: add other tables, change to Patient -> Event structure
 class eICUDataset(BaseEHRDataset):
diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py
index e26e60ec..0ec181e8 100644
--- a/pyhealth/datasets/mimic3.py
+++ b/pyhealth/datasets/mimic3.py
@@ -1,16 +1,22 @@
 import os
 from typing import Optional, List, Dict, Tuple, Union
-
+import time
 import pandas as pd
-
-from pyhealth.data import Event, Visit, Patient
-from pyhealth.datasets import BaseEHRDataset
+from pandarallel import pandarallel
+from tqdm import tqdm
+from pyhealth.data.data_v2 import Event, Patient
+# from pyhealth.datasets import BaseEHRDataset
+from pyhealth.datasets.base_dataset_v2 import BaseDataset
+from pyhealth.tasks.medical_coding import MIMIC3ICD9Coding
 from pyhealth.datasets.utils import strptime
-
+from pyhealth.medcode import CrossMap
+from copy import deepcopy
+from pyhealth.datasets.utils import MODULE_CACHE_PATH, DATASET_BASIC_TABLES
+from pyhealth.datasets.utils import hash_str
 # TODO: add other tables

-class MIMIC3Dataset(BaseEHRDataset):
+class MIMIC3Dataset(BaseDataset):
     """Base dataset for MIMIC-III dataset.

     The MIMIC-III dataset is a large dataset of de-identified health records of ICU
@@ -27,6 +33,7 @@ class MIMIC3Dataset(BaseEHRDataset):
         for patients.
     - LABEVENTS: contains laboratory measurements (MIMIC3_ITEMID code)
         for patients
+    - NOTEEVENTS: contains discharge summaries

     Args:
         dataset_name: name of the dataset.
@@ -67,6 +74,68 @@ class MIMIC3Dataset(BaseEHRDataset):
     >>> dataset.stat()
     >>> dataset.info()
     """
+    def __init__(self, root: str,
+                 dataset_name: Optional[str] = None,
+                 tables: List[str] = None,
+                 additional_dirs: Optional[Dict[str, str]] = {},
+                 code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None,
+                 **kwargs):
+        self.code_mapping = code_mapping if code_mapping is not None else {}
+        self.code_mapping_tools = self._load_code_mapping_tools()
+        self.code_vocs = {}
+        super().__init__(root, dataset_name, tables, additional_dirs, **kwargs)
+
+    def get_cache_path(self):
+        args_to_hash = (
+            [self.dataset_name, self.root]
+            + sorted(self.tables)
+            + sorted(self.code_mapping.items())
+            + ["dev" if self.dev else "prod"]
+            + sorted([(k, v) for k, v in self.tables_dir.items()])
+        )
+        filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".pkl"
+        filepath = os.path.join(MODULE_CACHE_PATH, filename)
+        return filepath
+
+    def parse_tables(self) -> Dict[str, Patient]:
+        """Parses the tables in `self.tables` and returns a dict of patients.
+
+        Will be called in `self.__init__()` if the cache file does not exist or
+        refresh_cache is True.
+
+        This function will first call `self.parse_basic_info()` to parse the
+        basic patient information, and then call `self.parse_[table_name]()` to
+        parse the table with name `table_name`. Both `self.parse_basic_info()` and
+        `self.parse_[table_name]()` should be implemented in the subclass.
+
+        Returns:
+            A dict mapping patient_id to `Patient` object.
+        """
+        pandarallel.initialize(progress_bar=False)
+
+        patients: Dict[str, Patient] = dict()
+        tic = time.time()
+        patients = self.parse_basic_info(patients)
+        print(f"finish basic patient information parsing : {time.time() - tic}s")
+
+        for table in self.tables:
+            try:
+                tic = time.time()
+                parse_method = getattr(self, f"parse_{table.lower()}")
+            except AttributeError:
+                raise NotImplementedError(f"Parser for table {table} is not implemented yet.")
+            patients = parse_method(patients)
+            print(f"finish parsing {table} : {time.time() - tic}s")
+
+        return patients
+
+    def process(self):
+        # process the data
+        patients = self.parse_tables()
+        # convert codes
+        patients = self._convert_code_in_patient_dict(patients)
+        return patients

     def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
         """Helper function which parses PATIENTS and ADMISSIONS tables.
@@ -110,20 +179,10 @@ def basic_unit(p_id, p_info): gender=p_info["GENDER"].values[0], ethnicity=p_info["ETHNICITY"].values[0], ) - # load visits - for v_id, v_info in p_info.groupby("HADM_ID"): - visit = Visit( - visit_id=v_id, - patient_id=p_id, - encounter_time=strptime(v_info["ADMITTIME"].values[0]), - discharge_time=strptime(v_info["DISCHTIME"].values[0]), - discharge_status=v_info["HOSPITAL_EXPIRE_FLAG"].values[0], - ) - # add visit - patient.add_visit(visit) + # Remove the loop that created Visit objects return patient - # parallel apply + # parallel apply, df_group = df_group.parallel_apply( lambda x: basic_unit(x.SUBJECT_ID.unique()[0], x) ) @@ -177,7 +236,7 @@ def diagnosis_unit(p_id, p_info): table=table, vocabulary="ICD9CM", visit_id=v_id, - patient_id=p_id, + patient_id=p_id ) events.append(event) return events @@ -235,7 +294,7 @@ def procedure_unit(p_id, p_info): table=table, vocabulary="ICD9PROC", visit_id=v_id, - patient_id=p_id, + patient_id=p_id ) events.append(event) return events @@ -293,7 +352,7 @@ def prescription_unit(p_id, p_info): vocabulary="NDC", visit_id=v_id, patient_id=p_id, - timestamp=strptime(timestamp), + timestamp=strptime(timestamp) ) events.append(event) return events @@ -348,7 +407,7 @@ def lab_unit(p_id, p_info): vocabulary="MIMIC3_ITEMID", visit_id=v_id, patient_id=p_id, - timestamp=strptime(timestamp), + timestamp=strptime(timestamp) ) events.append(event) return events @@ -362,30 +421,272 @@ def lab_unit(p_id, p_info): patients = self._add_events_to_patient_dict(patients, group_df) return patients + def parse_noteevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: + """Helper function which parses NOTEEVENTS table. -if __name__ == "__main__": - dataset = MIMIC3Dataset( - root="https://storage.googleapis.com/pyhealth/mimiciii-demo/1.4/", - tables=[ - "DIAGNOSES_ICD", - "PROCEDURES_ICD", - "PRESCRIPTIONS", - "LABEVENTS", - ], - code_mapping={"NDC": "ATC"}, - dev=True, - refresh_cache=True, - ) - dataset.stat() - dataset.info() + Will be called in `self.parse_tables()` + + Docs: + - NOTEEVENTS: https://mimic.mit.edu/docs/iii/tables/noteevents/ + + Args: + patients: a dict of `Patient` objects indexed by patient_id. + + Returns: + The updated patients dict. + """ + table = "NOTEEVENTS" + # read table + df = pd.read_csv( + os.path.join(self.root, f"{table}.csv"), + dtype={"SUBJECT_ID": str, "HADM_ID": str}, + ) + # drop records of the other patients + df = df[df["SUBJECT_ID"].isin(patients.keys())] + # drop rows with missing values + df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "TEXT"]) + # sort by charttime + df = df.sort_values(["SUBJECT_ID", "HADM_ID", "CHARTTIME"], ascending=True) + # group by patient and visit + group_df = df.groupby("SUBJECT_ID") + # parallel unit for note (per patient) + def note_unit(p_id, p_info): + events = [] + for v_id, v_info in p_info.groupby("HADM_ID"): + for _, row in v_info.iterrows(): + event = Event( + code=row["TEXT"], + table=table, + vocabulary="note", + visit_id=v_id, + patient_id=p_id, + timestamp=strptime(row["CHARTTIME"]) + ) + events.append(event) + return events + + # parallel apply + group_df = group_df.parallel_apply( + lambda x: note_unit(x.SUBJECT_ID.unique()[0], x) + ) + + # summarize the results + patients = self._add_events_to_patient_dict(patients, group_df) + return patients + + + def _convert_code_in_event(self, event: Event) -> List[Event]: + """Helper function which converts the code for a single event. + + Note that an event may be mapped to multiple events after code conversion. 
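+        For example, with code_mapping={"NDC": "ATC"}, a single NDC
+        prescription event may map to several ATC events, one per target code
+        returned by the CrossMap tool.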
+ + Will be called in `self._convert_code_in_patient()`. + + Args: + event: an `Event` object. + + Returns: + A list of `Event` objects after code conversion. + """ + src_vocab = event.vocabulary + if src_vocab in self.code_mapping: + target = self.code_mapping[src_vocab] + if isinstance(target, tuple): + tgt_vocab, kwargs = target + source_kwargs = kwargs.get("source_kwargs", {}) + target_kwargs = kwargs.get("target_kwargs", {}) + else: + tgt_vocab = self.code_mapping[src_vocab] + source_kwargs = {} + target_kwargs = {} + code_mapping_tool = self.code_mapping_tools[f"{src_vocab}_{tgt_vocab}"] + mapped_code_list = code_mapping_tool.map( + event.code, source_kwargs=source_kwargs, target_kwargs=target_kwargs + ) + mapped_event_list = [deepcopy(event) for _ in range(len(mapped_code_list))] + for i, mapped_event in enumerate(mapped_event_list): + mapped_event.code = mapped_code_list[i] + mapped_event.vocabulary = tgt_vocab + + # update the code vocs + for key, value in self.code_vocs.items(): + if value == src_vocab: + self.code_vocs[key] = tgt_vocab + + return mapped_event_list + # TODO: should normalize the code here + return [event] + + + def _convert_code_in_patient(self, patient: Patient) -> Patient: + """Helper function which converts the codes for a single patient. + + Will be called in `self._convert_code_in_patient_dict()`. + + Args: + patient:a `Patient` object. + + Returns: + The updated `Patient` object. + """ + # for visit in patient: + # for table in visit.available_tables: + # all_mapped_events = [] + # for event in visit.get_event_list(table): + all_mapped_events = [] + for event in patient.events: + # an event may be mapped to multiple events after code conversion + mapped_events: List[Event] + mapped_events = self._convert_code_in_event(event) + all_mapped_events.extend(mapped_events) + # visit.set_event_list(table, all_mapped_events) + patient.events = all_mapped_events + return patient + + def _convert_code_in_patient_dict( + self, + patients: Dict[str, Patient], + ) -> Dict[str, Patient]: + """Helper function which converts the codes for all patients. + + The codes to be converted are specified in `self.code_mapping`. + + Will be called in `self.__init__()` after `self.parse_tables()`. + + Args: + patients: a dict mapping patient_id to `Patient` object. + + Returns: + The updated patient dict. + """ + for p_id, patient in tqdm(patients.items(), desc="Mapping codes"): + patients[p_id] = self._convert_code_in_patient(patient) + return patients + + + def _add_events_to_patient_dict( + self, + patient_dict: Dict[str, Patient], + group_df: pd.DataFrame, + ) -> Dict[str, Patient]: + """Helper function which adds the events column of a df.groupby object to the patient dict. + + Will be called at the end of each `self.parse_[table_name]()` function. + + Args: + patient_dict: a dict mapping patient_id to `Patient` object. + group_df: a df.groupby object, having two columns: patient_id and events. + - the patient_id column is the index of the patient + - the events column is a list of objects + + Returns: + The updated patient dict. + """ + for _, events in group_df.items(): + for event in events: + + patient_dict = self._add_event_to_patient_dict(patient_dict, event) + return patient_dict + + @staticmethod + def _add_event_to_patient_dict( + patient_dict: Dict[str, Patient], + event: Event, + ) -> Dict[str, Patient]: + """Helper function which adds an event to the patient dict. + + Will be called in `self._add_events_to_patient_dict`. 
+ + Note that if the patient of the event is not in the patient dict, or the + visit of the event is not in the patient, this function will do nothing. + + Args: + patient_dict: a dict mapping patient_id to `Patient` object. + event: an event to be added to the patient dict. + + Returns: + The updated patient dict. + """ + patient_id = event.patient_id + try: + patient_dict[patient_id].add_event(event) + except KeyError: + pass + return patient_dict + + def stat(self) -> str: + """Returns some statistics of the base dataset.""" + lines = list() + lines.append("") + lines.append(f"Statistics of base dataset (dev={self.dev}):") + lines.append(f"\t- Dataset: {self.dataset_name}") + lines.append(f"\t- Number of patients: {len(self.patients)}") + # num_visits = [len(p) for p in self.patients.values()] # ask Zhenbang if it even makes sense to writ ea function like this? + # lines.append(f"\t- Number of visits: {sum(num_visits)}") + # lines.append( + # f"\t- Number of visits per patient: {sum(num_visits) / len(num_visits):.4f}" + # ) + for table in self.tables: + num_events = [ + len(p.get_event_list(table)) for p in self.patients.values() + ] + lines.append( + f"\t- Number of events in {table}: " + f"{sum(num_events) :.4f}" + ) + lines.append("") + print("\n".join(lines)) + return "\n".join(lines) + + def _load_code_mapping_tools(self) -> Dict[str, CrossMap]: + """Helper function which loads code mapping tools CrossMap for code mapping. + + Will be called in `self.__init__()`. + + Returns: + A dict whose key is the source and target code vocabulary and + value is the `CrossMap` object. + """ + code_mapping_tools = {} + for s_vocab, target in self.code_mapping.items(): + if isinstance(target, tuple): + assert len(target) == 2 + assert type(target[0]) == str + assert type(target[1]) == dict + assert target[1].keys() <= {"source_kwargs", "target_kwargs"} + t_vocab = target[0] + else: + t_vocab = target + # load code mapping from source to target + code_mapping_tools[f"{s_vocab}_{t_vocab}"] = CrossMap(s_vocab, t_vocab) + return code_mapping_tools + + +def main(): + root = "/srv/local/data/jw3/physionet.org/files/mimic-iii-clinical-database-1.4" # dataset = MIMIC3Dataset( - # root="/srv/local/data/physionet.org/files/mimiciii/1.4", - # tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"], - # dev=True, - # code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, + # root=root, + # dataset_name="mimic3", + # tables=[ + # "DIAGNOSES_ICD", + # "PROCEDURES_ICD", + # "PRESCRIPTIONS", + # "LABEVENTS", + # "NOTEEVENTS" + # ], + # code_mapping={"NDC": "ATC"}, + # dev=False, # refresh_cache=False, # ) - # print(dataset.stat()) - # print(dataset.available_tables) - # print(list(dataset.patients.values())[4]) + # dataset.stat() + mimic3_coding = MIMIC3ICD9Coding(refresh_cache= False) + print(len(mimic3_coding.samples)) + sample_dataset = mimic3_coding.to_torch_dataset() + # print(sample_dataset[0]) + + + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 99c950da..31896b38 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,16 +1,22 @@ import os -from typing import Optional, List, Dict, Union, Tuple - +import time import pandas as pd - -from pyhealth.data import Event, Visit, Patient -from pyhealth.datasets import BaseEHRDataset +from tqdm import tqdm +from typing import Optional, List, Dict, Union, Tuple +from pandarallel import pandarallel +from pyhealth.data.data_v2 import Event, Patient +from 
pyhealth.datasets.base_dataset_v2 import BaseDataset from pyhealth.datasets.utils import strptime +from pyhealth.medcode import CrossMap +from copy import deepcopy +from pyhealth.datasets.utils import MODULE_CACHE_PATH, DATASET_BASIC_TABLES +from pyhealth.datasets.utils import hash_str +from pyhealth.tasks.medical_coding import MIMIC4ICD9Coding +# TODO: add other tables, pyspark or pyarrow for preprocessing. -# TODO: add other tables +class MIMIC4Dataset(BaseDataset): -class MIMIC4Dataset(BaseEHRDataset): """Base dataset for MIMIC-IV dataset. The MIMIC-IV dataset is a large dataset of de-identified health records of ICU @@ -69,7 +75,61 @@ class MIMIC4Dataset(BaseEHRDataset): >>> dataset.stat() >>> dataset.info() """ + def __init__(self, root: str, + dataset_name: Optional[str] = "MIMIC-IV", + tables : List[str] = None, + additional_dirs : Optional[Dict[str, str]] = {}, + code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None, + **kwargs): + self.code_mapping = code_mapping + self.code_mapping_tools = self._load_code_mapping_tools() + self.code_vocs = {} + super().__init__(root, dataset_name, tables, additional_dirs, **kwargs) + + + def get_cache_path(self): + args_to_hash = ( + [self.dataset_name, self.root] + + sorted(self.tables) + + sorted(self.code_mapping.items()) + + ["dev" if self.dev else "prod"] + + sorted([(k, v) for k, v in self.tables_dir.items()]) + ) + filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".pkl" + filepath = os.path.join(MODULE_CACHE_PATH, filename) # questions here (?) + return filepath + + def process(self): + # process the data + patients = self.parse_tables() + # convert codes + patients = self._convert_code_in_patient_dict(patients) + return patients + + def _load_code_mapping_tools(self) -> Dict[str, CrossMap]: + """Helper function which loads code mapping tools CrossMap for code mapping. + + Will be called in `self.__init__()`. + + Returns: + A dict whose key is the source and target code vocabulary and + value is the `CrossMap` object. + """ + code_mapping_tools = {} + for s_vocab, target in self.code_mapping.items(): + if isinstance(target, tuple): + assert len(target) == 2 + assert type(target[0]) == str + assert type(target[1]) == dict + assert target[1].keys() <= {"source_kwargs", "target_kwargs"} + t_vocab = target[0] + else: + t_vocab = target + # load code mapping from source to target + code_mapping_tools[f"{s_vocab}_{t_vocab}"] = CrossMap(s_vocab, t_vocab) + return code_mapping_tools + def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: """Helper functions which parses patients and admissions tables. 
@@ -120,17 +180,7 @@ def basic_unit(p_id, p_info): ethnicity=p_info["race"].values[0], anchor_year_group=p_info["anchor_year_group"].values[0], ) - # load visits - for v_id, v_info in p_info.groupby("hadm_id"): - visit = Visit( - visit_id=v_id, - patient_id=p_id, - encounter_time=strptime(v_info["admittime"].values[0]), - discharge_time=strptime(v_info["dischtime"].values[0]), - discharge_status=v_info["hospital_expire_flag"].values[0], - ) - # add visit - patient.add_visit(visit) + return patient # parallel apply @@ -421,13 +471,320 @@ def hcpcsevents_unit(p_id, p_info): patients = self._add_events_to_patient_dict(patients, group_df) return patients + + + def parse_discharge(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: + table = "discharge" # hardcoded + df = pd.read_csv(os.path.join(self.tables_dir[table], f"{table}.csv"), + dtype={"subject_id": str, "hadm_id": str}) + df = df.dropna(subset=["subject_id", "hadm_id", "text", "charttime"]) + df = df.sort_values(["subject_id", "hadm_id"], ascending=True) + group_df = df.groupby("subject_id") + def discharge_unit(p_id, p_info): + events = [] + for v_id, v_info in p_info.groupby("hadm_id"): + for text in v_info["text"]: + event = Event( + code=text, + table=table, + vocabulary="text", + visit_id=v_id, + patient_id=p_id, + timestamp=strptime(v_info["charttime"].values[0]) + ) + events.append(event) + return events + group_df = group_df.parallel_apply( + lambda x: discharge_unit(x.subject_id.unique()[0], x) + ) + patients = self._add_events_to_patient_dict(patients, group_df) + return patients + + + def transform_study_datetime(self, date_str, time_str): + # Extract year, month, and day from date_str + year = date_str[:4] + month = date_str[4:6] + day = date_str[6:] + + # Extract hours, minutes, and seconds from time_str + time_parts = time_str.split('.') + time_main = time_parts[0].zfill(6) + hours = time_main[:2] + minutes = time_main[2:4] + seconds = time_main[4:] + + # Combine into the desired format + formatted_datetime = f"{year}-{month}-{day} {hours}:{minutes}:{seconds}" + + return formatted_datetime + + + def parse_cxr(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: + table = "cxr" + cxr_file = "mimic-cxr-2.0.0-metadata" + # hardcoded + df = pd.read_csv(os.path.join(self.tables_dir[table], f"{cxr_file}.csv"), + dtype={"subject_id": str, "hadm_id": str}) + + # combine date and time to create timestamp + df = df.dropna(subset=["subject_id", "study_id", "dicom_id"]) + df.StudyDate = df.StudyDate.astype(str) + df.StudyTime = df.StudyTime.astype(str) + # process all the dates and times + df['StudyDateTime'] = df.apply(lambda row: self.transform_study_datetime(str(row['StudyDate']), str(row['StudyTime'])), axis=1) + df = df.sort_values(["subject_id", "study_id"], ascending=True) + + group_df = df.groupby("subject_id") + + def cxr_unit(p_id, p_info): + events = [] + for v_id, v_info in p_info.groupby("study_id"): + for dicom_id, timestamp in zip(v_info["dicom_id"], v_info["StudyDateTime"]): + # print(timestamp) + event = Event( + code=dicom_id, # used for the dicom_id pathing + table=table, + vocabulary="dicom_id", + visit_id=v_id, + patient_id=p_id, + timestamp=strptime(timestamp) + ) + events.append(event) + return events + group_df = group_df.parallel_apply(lambda x: cxr_unit(x.subject_id.unique()[0], x)) + patients = self._add_events_to_patient_dict(patients, group_df) + return patients + + def parse_tables(self) -> Dict[str, Patient]: + """Parses the tables in `self.tables` and return a dict of patients. 
+    def parse_tables(self) -> Dict[str, Patient]:
+        """Parses the tables in `self.tables` and returns a dict of patients.
+
+        Will be called in `self.__init__()` if the cache file does not exist or
+        refresh_cache is True.
+
+        This function will first call `self.parse_basic_info()` to parse the
+        basic patient information, and then call `self.parse_[table_name]()` to
+        parse the table with name `table_name`. Both `self.parse_basic_info()` and
+        `self.parse_[table_name]()` should be implemented in the subclass.
+
+        Returns:
+            A dict mapping patient_id to `Patient` object.
+        """
+        pandarallel.initialize(progress_bar=False)
+
+        patients: Dict[str, Patient] = dict()
+        tic = time.time()
+        patients = self.parse_basic_info(patients)
+        print(f"finish basic patient information parsing : {time.time() - tic}s")
+
+        for table in self.tables:
+            tic = time.time()
+            # look up the parser outside the call so a genuine AttributeError
+            # raised while parsing is not misreported as a missing parser
+            parse_method = getattr(self, f"parse_{table.lower()}", None)
+            if parse_method is None:
+                raise NotImplementedError(f"Parser for table {table} is not implemented yet.")
+            patients = parse_method(patients)
+            print(f"finish parsing {table} : {time.time() - tic}s")
+
+        return patients
+
+    def stat(self) -> str:
+        """Returns some statistics of the base dataset."""
+        lines = list()
+        lines.append("")
+        lines.append(f"Statistics of base dataset (dev={self.dev}):")
+        lines.append(f"\t- Dataset: {self.dataset_name}")
+        lines.append(f"\t- Number of patients: {len(self.patients)}")
+        # TODO: revisit visit-level statistics now that visits are flattened
+        # into per-patient event lists.
+        for table in self.tables:
+            num_events = [
+                len(p.get_event_list(table)) for p in self.patients.values()
+            ]
+            lines.append(
+                f"\t- Number of events in {table}: {sum(num_events)}"
+            )
+        lines.append("")
+        print("\n".join(lines))
+        return "\n".join(lines)
+
+    @property
+    def available_tables(self) -> List[str]:
+        """Returns a list of available tables for the dataset.
+
+        Returns:
+            List of available tables.
+        """
+        tables = []
+        for patient in self.patients.values():
+            tables.extend(patient.available_tables)
+        return list(set(tables))
+
+    # util funcs
+    def _add_events_to_patient_dict(
+        self,
+        patient_dict: Dict[str, Patient],
+        group_df: pd.DataFrame,
+    ) -> Dict[str, Patient]:
+        """Helper function which adds the events column of a df.groupby object to the patient dict.
+
+        Will be called at the end of each `self.parse_[table_name]()` function.
+
+        Args:
+            patient_dict: a dict mapping patient_id to `Patient` object.
+            group_df: a df.groupby object, having two columns: patient_id and events.
+                - the patient_id column is the index of the patient
+                - the events column is a list of `Event` objects
+
+        Returns:
+            The updated patient dict.
+        """
+        for _, events in group_df.items():
+            for event in events:
+                patient_dict = self._add_event_to_patient_dict(patient_dict, event)
+        return patient_dict
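An illustrative sketch of the groupby/apply contract the parsers above rely on: each parse_[table] groups a table by subject_id and maps every group to a list of events, which _add_events_to_patient_dict then merges into the patient dict. The data below is a toy stand-in, tuples stand in for Event objects, and plain .apply stands in for parallel_apply:

import pandas as pd

df = pd.DataFrame({
    "subject_id": ["10000032", "10000032", "10000084"],
    "hadm_id": ["22595853", "22595853", "23052089"],
    "icd_code": ["4019", "2724", "E785"],
})
# one list of (patient_id, code) pairs per patient, keyed by subject_id
events_per_patient = df.groupby("subject_id").apply(
    lambda g: [(g.name, v) for v in g["icd_code"]]
)
for _, events in events_per_patient.items():
    print(events)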
+ """ + patient_id = event.patient_id + try: + patient_dict[patient_id].add_event(event) + except KeyError: + pass + return patient_dict + + def _convert_code_in_event(self, event: Event) -> List[Event]: + """Helper function which converts the code for a single event. + + Note that an event may be mapped to multiple events after code conversion. + + Will be called in `self._convert_code_in_patient()`. + + Args: + event: an `Event` object. + + Returns: + A list of `Event` objects after code conversion. + """ + src_vocab = event.vocabulary + if src_vocab in self.code_mapping: + target = self.code_mapping[src_vocab] + if isinstance(target, tuple): + tgt_vocab, kwargs = target + source_kwargs = kwargs.get("source_kwargs", {}) + target_kwargs = kwargs.get("target_kwargs", {}) + else: + tgt_vocab = self.code_mapping[src_vocab] + source_kwargs = {} + target_kwargs = {} + code_mapping_tool = self.code_mapping_tools[f"{src_vocab}_{tgt_vocab}"] + mapped_code_list = code_mapping_tool.map( + event.code, source_kwargs=source_kwargs, target_kwargs=target_kwargs + ) + mapped_event_list = [deepcopy(event) for _ in range(len(mapped_code_list))] + for i, mapped_event in enumerate(mapped_event_list): + mapped_event.code = mapped_code_list[i] + mapped_event.vocabulary = tgt_vocab + + # update the code vocs + for key, value in self.code_vocs.items(): + if value == src_vocab: + self.code_vocs[key] = tgt_vocab + + return mapped_event_list + # TODO: should normalize the code here + return [event] + + + def _convert_code_in_patient(self, patient: Patient) -> Patient: + """Helper function which converts the codes for a single patient. + + Will be called in `self._convert_code_in_patient_dict()`. + + Args: + patient:a `Patient` object. + + Returns: + The updated `Patient` object. + """ + # for visit in patient: + # for table in visit.available_tables: + # all_mapped_events = [] + # for event in visit.get_event_list(table): + all_mapped_events = [] + for event in patient.events: + # an event may be mapped to multiple events after code conversion + mapped_events: List[Event] + mapped_events = self._convert_code_in_event(event) + all_mapped_events.extend(mapped_events) + # visit.set_event_list(table, all_mapped_events) + patient.events = all_mapped_events + return patient + + def _convert_code_in_patient_dict( + self, + patients: Dict[str, Patient], + ) -> Dict[str, Patient]: + """Helper function which converts the codes for all patients. + + The codes to be converted are specified in `self.code_mapping`. + + Will be called in `self.__init__()` after `self.parse_tables()`. + + Args: + patients: a dict mapping patient_id to `Patient` object. + + Returns: + The updated patient dict. 
+ """ + for p_id, patient in tqdm(patients.items(), desc="Mapping codes"): + patients[p_id] = self._convert_code_in_patient(patient) + return patients + + + +def main(): + # dataset = MIMIC4Dataset( + # root="/srv/local/data/jw3/physionet.org/files/MIMIC-IV/2.0/hosp", + # tables=["diagnoses_icd","procedures_icd"], + # code_mapping={"NDC": "ATC"}, + # refresh_cache=False, + # dev=False, + # additional_dirs={"discharge" : "/srv/local/data/jw3/physionet.org/files/MIMIC-IV/2.0/note", + # "cxr" : "/srv/local/data/jw3/physionet.org/files/MIMIC-CXR"} + # ) + # dataset.stat() + + # print(dataset.available_tables) + task = MIMIC4ICD9Coding(dataset=None, refresh_cache=False) + print(len(task.samples)) + print(task.samples[0]["icd_codes"]) + # task.process() + + # sample_dataset = task.to_torch_dataset() + if __name__ == "__main__": - dataset = MIMIC4Dataset( - root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", - tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents", "hcpcsevents"], - code_mapping={"NDC": "ATC"}, - refresh_cache=False, - ) - dataset.stat() - dataset.info() + main() diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 1b7d28ae..bb8edc0f 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -6,7 +6,6 @@ from pyhealth.datasets.utils import list_nested_levels, flatten_list - class SampleBaseDataset(Dataset): """Sample base dataset class. diff --git a/pyhealth/datasets/sample_dataset_v2.py b/pyhealth/datasets/sample_dataset_v2.py new file mode 100644 index 00000000..09bfd23b --- /dev/null +++ b/pyhealth/datasets/sample_dataset_v2.py @@ -0,0 +1,140 @@ +from typing import Dict, List, Optional + +from torch.utils.data import Dataset + +import sys +sys.path.append('.') + +from pyhealth.featurizers import ImageFeaturizer, ValueFeaturizer + + +class SampleDataset(Dataset): + """Sample dataset class. + """ + + def __init__( + self, + samples: List[Dict], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + ): + if dataset_name is None: + dataset_name = "" + if task_name is None: + task_name = "" + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.dataset_name = dataset_name + self.task_name = task_name + self.transform = None + # TODO: get rid of input_info + # self.input_info: Dict = self.validate() + # # self.build() + + # def validate(self): + # input_keys = set(self.input_schema.keys()) + # output_keys = set(self.output_schema.keys()) + # for s in self.samples: + # assert input_keys.issubset(s.keys()), \ + # "Input schema does not match samples." + # assert output_keys.issubset(s.keys()), \ + # "Output schema does not match samples." + # input_info = {} + # # get label signal info + # input_info["label"] = {"type": str, "dim": 0} + # return input_info + + def build(self): + for k, v in self.input_schema.items(): + if v == "image": + self.input_schema[k] = ImageFeaturizer() + else: + self.input_schema[k] = ValueFeaturizer() + for k, v in self.output_schema.items(): + if v == "image": + self.output_schema[k] = ImageFeaturizer() + else: + self.output_schema[k] = ValueFeaturizer() + return + + def __getitem__(self, index) -> Dict: + """Returns a sample by index. + + Returns: + Dict, a dict with patient_id, visit_id/record_id, and other task-specific + attributes as key. Conversion to index/tensor will be done + in the model. 
+ """ + out = {} + for k, v in self.samples[index].items(): + if k in self.input_schema: + out[k] = self.input_schema[k].encode(v) + elif k in self.output_schema: + out[k] = self.output_schema[k].encode(v) + else: + out[k] = v + + if self.transform is not None: + out = self.transform(out) + + return out + + def set_transform(self, transform): + """Sets the transform for the dataset. + + Args: + transform: a callable transform function. + """ + self.transform = transform + return + + def __str__(self): + """Prints some information of the dataset.""" + return f"Sample dataset {self.dataset_name} {self.task_name}" + + def __len__(self): + """Returns the number of samples in the dataset.""" + return len(self.samples) + + +if __name__ == "__main__": + samples = [ + { + "id": "0", + "single_vector": [1, 2, 3], + "list_codes": ["505800458", "50580045810", "50580045811"], + "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]], + "list_list_codes": [ + ["A05B", "A05C", "A06A"], + ["A11D", "A11E"] + ], + "list_list_vectors": [ + [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]], + [[7.7, 8.5, 9.4]], + ], + "image": "data/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png", + "text": "This is a sample text", + "label": 1, + }, + ] + + dataset = SampleDataset( + samples=samples, + input_schema={ + "id": "str", + "single_vector": "vector", + "list_codes": "list", + "list_vectors": "list", + "list_list_codes": "list", + "list_list_vectors": "list", + "image": "image", + "text": "text", + }, + output_schema={ + "label": "label" + } + ) + print(dataset[0]) \ No newline at end of file diff --git a/pyhealth/featurizers/__init__.py b/pyhealth/featurizers/__init__.py new file mode 100644 index 00000000..7ad6268b --- /dev/null +++ b/pyhealth/featurizers/__init__.py @@ -0,0 +1,2 @@ +from .image import ImageFeaturizer +from .value import ValueFeaturizer \ No newline at end of file diff --git a/pyhealth/featurizers/image.py b/pyhealth/featurizers/image.py new file mode 100644 index 00000000..95d407d3 --- /dev/null +++ b/pyhealth/featurizers/image.py @@ -0,0 +1,22 @@ +import PIL.Image +import torchvision.transforms as transforms + + +class ImageFeaturizer: + + def __init__(self): + self.transform = transforms.Compose([transforms.ToTensor()]) + + def encode(self, value): + image = PIL.Image.open(value) + image.load() # to avoid "Too many open files" errors + image = self.transform(image) + return image + + +if __name__ == "__main__": + sample_image = "/srv/local/data/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png" + featurizer = ImageFeaturizer() + print(featurizer) + print(type(featurizer)) + print(featurizer.encode(sample_image)) \ No newline at end of file diff --git a/pyhealth/featurizers/value.py b/pyhealth/featurizers/value.py new file mode 100644 index 00000000..8c1eb2a5 --- /dev/null +++ b/pyhealth/featurizers/value.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class ValueFeaturizer: + + def encode(self, value): + return value + + +if __name__ == "__main__": + featurizer = ValueFeaturizer() + print(featurizer) + print(featurizer.encode(2)) \ No newline at end of file diff --git a/pyhealth/tasks/drug_recommendation.py b/pyhealth/tasks/drug_recommendation.py index d423d51d..616b2fa8 100644 --- a/pyhealth/tasks/drug_recommendation.py +++ b/pyhealth/tasks/drug_recommendation.py @@ -1,4 +1,4 @@ -from pyhealth.data import Patient, Visit +from pyhealth.data import Patient def drug_recommendation_mimic3_fn(patient: Patient): diff --git a/pyhealth/tasks/medical_coding.py 
diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py
new file mode 100644
index 00000000..c2743b48
--- /dev/null
+++ b/pyhealth/tasks/medical_coding.py
@@ -0,0 +1,176 @@
+import logging
+from tqdm import tqdm
+from pyhealth.data.data_v2 import Patient, Event
+from pyhealth.tasks.task_template import TaskTemplate
+from pyhealth.datasets.sample_dataset_v2 import SampleDataset
+from pyhealth.datasets.base_dataset_v2 import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class MIMIC3ICD9Coding(TaskTemplate):
+    def __init__(self, dataset: BaseDataset = None, cache_dir: str = "./cache", refresh_cache: bool = False):
+        super().__init__(
+            task_name="mimic3_icd9_coding",
+            input_schema={"text": "str"},
+            output_schema={"icd_codes": "List[str]"},
+            dataset=dataset,
+            cache_dir=cache_dir,
+            refresh_cache=refresh_cache
+        )
+
+    def process(self):
+        logger.debug(f"Setting task for {self.dataset.dataset_name} base dataset...")
+        samples = []
+        for patient_id, patient in tqdm(
+            self.dataset.patients.items(), desc=f"Generating samples for {self.task_name}"
+        ):
+            samples.extend(self.sample(patient))
+        return samples
+
+    def sample(self, patient: Patient):
+        # concatenate all note events into one text; collect ICD-9 labels
+        texts = []
+        icd_codes = set()
+        for event in patient.events:
+            if event.table == "NOTEEVENTS":
+                texts.append(event.code)
+            if event.table == "DIAGNOSES_ICD":
+                icd_codes.add(event.code)
+            if event.table == "PROCEDURES_ICD":
+                icd_codes.add(event.code)
+        text = " ".join(texts)  # keep a separator between concatenated notes
+        if text == "" or not icd_codes:
+            return []
+        return [{"text": text, "icd_codes": list(icd_codes)}]
+
+
+class MIMIC4ICD9Coding(TaskTemplate):
+    def __init__(self, dataset: BaseDataset = None, cache_dir: str = "./cache", refresh_cache: bool = False):
+        super().__init__(
+            task_name="mimic4_icd9_coding",
+            input_schema={"text": "str"},
+            output_schema={"icd_codes": "List[str]"},
+            dataset=dataset,
+            cache_dir=cache_dir,
+            refresh_cache=refresh_cache
+        )
+
+    def sample(self, patient: Patient):
+        texts = []
+        icd_codes = set()
+        for event in patient.events:
+            if event.table == "discharge":
+                texts.append(event.code)
+            if event.vocabulary == "ICD9CM":
+                if event.table == "diagnoses_icd":
+                    icd_codes.add(event.code)
+                if event.table == "procedures_icd":
+                    icd_codes.add(event.code)
+        text = " ".join(texts)
+        if text == "" or not icd_codes:
+            return []
+        return [{"text": text, "icd_codes": list(icd_codes)}]
+
+    def process(self):
+        # load from raw data
+        logger.debug(f"Setting task for {self.dataset.dataset_name} base dataset...")
+        samples = []
+        for patient_id, patient in tqdm(
+            self.dataset.patients.items(), desc=f"Generating samples for {self.task_name}"
+        ):
+            samples.extend(self.sample(patient))
+        return samples
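A toy walk-through of the sample() logic above: a patient's discharge-note events are concatenated into one text and the ICD-9 diagnosis/procedure codes become the label set. SimpleNamespace objects stand in for the real Event class, and the event contents are hypothetical:

from types import SimpleNamespace as Ev

events = [
    Ev(table="discharge", vocabulary="text", code="Pt admitted with ..."),
    Ev(table="diagnoses_icd", vocabulary="ICD9CM", code="4019"),
    Ev(table="procedures_icd", vocabulary="ICD9CM", code="3961"),
]
texts, icd_codes = [], set()
for e in events:
    if e.table == "discharge":
        texts.append(e.code)
    if e.vocabulary == "ICD9CM" and e.table in ("diagnoses_icd", "procedures_icd"):
        icd_codes.add(e.code)
# one sample per patient: the full note text plus the deduplicated label set
print([{"text": " ".join(texts), "icd_codes": sorted(icd_codes)}])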
+
+class MIMIC4ICD10Coding(TaskTemplate):
+    def __init__(self, dataset: BaseDataset = None, cache_dir: str = "./cache", refresh_cache: bool = False):
+        super().__init__(
+            task_name="mimic4_icd10_coding",
+            input_schema={"text": "str"},
+            output_schema={"icd_codes": "List[str]"},
+            dataset=dataset,
+            cache_dir=cache_dir,
+            refresh_cache=refresh_cache
+        )
+
+    def sample(self, patient: Patient):
+        texts = []
+        icd_codes = set()
+        for event in patient.events:
+            if event.table == "discharge":
+                texts.append(event.code)
+            if event.vocabulary == "ICD10CM":
+                if event.table == "diagnoses_icd":
+                    icd_codes.add(event.code)
+                if event.table == "procedures_icd":
+                    icd_codes.add(event.code)
+        text = " ".join(texts)
+        if text == "" or not icd_codes:
+            return []
+        return [{"text": text, "icd_codes": list(icd_codes)}]
+
+    def process(self):
+        # load from raw data
+        logger.debug(f"Setting task for {self.dataset.dataset_name} base dataset...")
+        samples = []
+        for patient_id, patient in tqdm(
+            self.dataset.patients.items(), desc=f"Generating samples for {self.task_name}"
+        ):
+            samples.extend(self.sample(patient))
+        return samples
diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py
index 3661540b..b2562576 100644
--- a/pyhealth/tasks/mortality_prediction.py
+++ b/pyhealth/tasks/mortality_prediction.py
@@ -1,4 +1,4 @@
-from pyhealth.data import Patient, Visit
+from pyhealth.data import Patient
 
 def mortality_prediction_mimic3_fn(patient: Patient):
diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py
index 9d04b5eb..e43ac5ca 100644
--- a/pyhealth/tasks/readmission_prediction.py
+++ b/pyhealth/tasks/readmission_prediction.py
@@ -1,4 +1,4 @@
-from pyhealth.data import Patient, Visit
+from pyhealth.data import Patient
 
 # TODO: time_window cannot be passed in to base_dataset
diff --git a/pyhealth/tasks/task_template.py b/pyhealth/tasks/task_template.py
new file mode 100644
index 00000000..21e50dc6
--- /dev/null
+++ b/pyhealth/tasks/task_template.py
@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Dict, List, Any
+import os
+from pyhealth.data.cache import read_msgpack, write_msgpack
+from pyhealth.datasets.sample_dataset_v2 import SampleDataset
+from pyhealth.datasets.base_dataset_v2 import BaseDataset
+from pyhealth.datasets.utils import hash_str
+
+
+@dataclass
+class TaskTemplate(ABC):
+    task_name: str
+    input_schema: Dict[str, str]
+    output_schema: Dict[str, str]
+    dataset: BaseDataset = None
+    cache_dir: str = "./cache"
+    refresh_cache: bool = False
+    samples: List[Any] = field(default_factory=list, init=False)
+
+    def __post_init__(self):
+        self.cache_path = self.get_cache_path()
+        if os.path.exists(self.cache_path) and not self.refresh_cache:
+            try:
+                self.samples = read_msgpack(self.cache_path)
+                print(f"Loaded {self.task_name} task data from {self.cache_path}")
+                return
+            except Exception:
+                # fall through and rebuild the samples from the dataset
+                print(f"Failed to load cache for {self.task_name}. Processing from scratch.")
+        if self.dataset is None:
+            raise ValueError("Dataset is required when cache doesn't exist or refresh_cache is True")
+        self.samples = self.process()
+        # save to cache
+        os.makedirs(os.path.dirname(self.cache_path), exist_ok=True)
+        write_msgpack(self.samples, self.cache_path)
+        print(f"Saved {self.task_name} task data to {self.cache_path}")
+
+    def get_cache_path(self) -> str:
+        # derive a short, stable hash from the schemas so different schema
+        # versions of the same task get different cache files
+        schema_str = f"input_{'-'.join(self.input_schema.values())}_output_{'-'.join(self.output_schema.values())}"
+        hash_object = hash_str(schema_str)
+        hash_num = int(hash_object, 16)
+        short_hash = str(hash_num)[-10:]
+        cache_filename = f"{self.task_name}_{short_hash}.msgpack"
+        return os.path.join(self.cache_dir, cache_filename)
+
+    @abstractmethod
+    def process(self) -> List[Any]:
+        raise NotImplementedError
+
+    def to_torch_dataset(self) -> SampleDataset:
+        dataset_name = self.dataset.dataset_name if self.dataset else "Unknown"
+        return SampleDataset(
+            self.samples,
+            input_schema=self.input_schema,
+            output_schema=self.output_schema,
+            dataset_name=dataset_name,
+            task_name=self.task_name,
+        )
+
+    @classmethod
+    def from_cache(cls, task_name: str, input_schema: Dict[str, str], output_schema: Dict[str, str], cache_dir: str = "./cache"):
+        # __post_init__ loads the cache when it exists and raises ValueError
+        # when it does not (no dataset is supplied to rebuild from)
+        try:
+            return cls(task_name, input_schema, output_schema, dataset=None, cache_dir=cache_dir)
+        except ValueError:
+            raise FileNotFoundError(f"Cache file not found for {task_name}")
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..65b1fac3
--- /dev/null
+++ b/test.py
@@ -0,0 +1,17 @@
+import pyhealth.datasets.mimic3 as mimic3
+import pyhealth.datasets.mimic4 as mimic4
+import time
+
+def time_function(func, name):
+    start_time = time.time()
+    func()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"{name} execution time: {execution_time:.2f} seconds")
+
+if __name__ == "__main__":
+    print("Starting MIMIC-III processing...")
+    time_function(mimic3.main, "MIMIC-III")
+
+    print("\nStarting MIMIC-IV processing...")
+    time_function(mimic4.main, "MIMIC-IV")
\ No newline at end of file
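An end-to-end sketch of the TaskTemplate contract defined in task_template.py above: subclass it, implement process(), and let __post_init__ handle the msgpack cache. The task class, schemas, and dataset object below are hypothetical:

from pyhealth.tasks.task_template import TaskTemplate

class ToyCountTask(TaskTemplate):
    def process(self):
        # one sample per patient; real tasks build text/label dicts here
        return [{"patient_id": p_id, "n_events": len(p.events)}
                for p_id, p in self.dataset.patients.items()]

# first run: processes and writes ./cache/toy_count_task_<hash>.msgpack
# task = ToyCountTask("toy_count_task", {"n_events": "int"},
#                     {"patient_id": "str"}, dataset=mimic4)
# later runs can skip building the dataset entirely:
# task = ToyCountTask.from_cache("toy_count_task", {"n_events": "int"},
#                                {"patient_id": "str"})

Because the cache filename mixes the task name with a hash of the schemas, changing either the input or output schema invalidates the old cache automatically.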