diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 58358241..8c0446e0 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -10,6 +10,7 @@ Datasets datasets/pyhealth.datasets.SampleEHRDataset datasets/pyhealth.datasets.SampleSignalDataset datasets/pyhealth.datasets.MIMIC3Dataset + datasets/pyhealth.datasets.MIMICExtractDataset datasets/pyhealth.datasets.MIMIC4Dataset datasets/pyhealth.datasets.eICUDataset datasets/pyhealth.datasets.OMOPDataset diff --git a/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst new file mode 100644 index 00000000..d38b9cfc --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst @@ -0,0 +1,15 @@ +pyhealth.datasets.MIMICExtractDataset +=================================== + +The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer to `doc `_ for more information. We process this database into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis. + +.. autoclass:: pyhealth.datasets.MIMICExtractDataset + :members: + :undoc-members: + :show-inheritance: + + + + + + \ No newline at end of file diff --git a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst index 1c140905..108f3001 100644 --- a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst +++ b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst @@ -1,7 +1,7 @@ pyhealth.datasets.OMOPDataset =================================== -We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. We it into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis. +We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. The raw data is processed into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis. .. autoclass:: pyhealth.datasets.OMOPDataset :members: diff --git a/docs/api/models.rst b/docs/api/models.rst index 310d0249..817faeb0 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -15,6 +15,7 @@ We implement the following models for supporting multiple healthcare predictive models/pyhealth.models.GAMENet models/pyhealth.models.MICRON models/pyhealth.models.SafeDrug + models/pyhealth.models.MoleRec models/pyhealth.models.Deepr models/pyhealth.models.ContraWR models/pyhealth.models.SparcNet diff --git a/docs/api/models/pyhealth.models.MoleRec.rst b/docs/api/models/pyhealth.models.MoleRec.rst new file mode 100644 index 00000000..541d315f --- /dev/null +++ b/docs/api/models/pyhealth.models.MoleRec.rst @@ -0,0 +1,14 @@ +pyhealth.models.MoleRec +=================================== + +The separate callable MoleRecLayer and the complete MoleRec model. + +.. autoclass:: pyhealth.models.MoleRecLayer + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.MoleRec + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/log.rst b/docs/log.rst index 711f35fa..aec41122 100644 --- a/docs/log.rst +++ b/docs/log.rst @@ -4,66 +4,66 @@ We track the new development here: **May 9, 2023** -.. code-block:: bash +.. code-block:: rst - 1. add MIMIC-Extract dataset `#136 `_ + 1. add MIMIC-Extract dataset `#136` 2. add new maintainer members for pyhealth: Junyi Gao and Benjamin Danek **May 6, 2023** -.. code-block:: bash +.. code-block:: rst - 1. add new parser functions (admissionDx, diagnosisStrings) and prediction tasks for eICU dataset `#148 `_ + 1. add new parser functions (admissionDx, diagnosisStrings) and prediction tasks for eICU dataset `#148` **Apr 27, 2023** -.. code-block:: bash +.. code-block:: rst - 1. add MoleRec model (WWW'23) for drug recommendation `#122 `_ + 1. add MoleRec model (WWW'23) for drug recommendation `#122` **Apr 26, 2023** -.. code-block:: bash +.. code-block:: rst - 1. fix bugs in GRASP model `#141 `_ - 2. add pandas install <2 constraints `#135 `_ - 3. add hcpcsevents table process in MIMIC4 dataset `#134 `_ + 1. fix bugs in GRASP model `#141` + 2. add pandas install <2 constraints `#135` + 3. add hcpcsevents table process in MIMIC4 dataset `#134` **Apr 10, 2023** -.. code-block:: bash +.. code-block:: rst 1. fix Ambiguous datetime usage in eICU (https://github.com/sunlabuiuc/PyHealth/pull/132) **Mar 26, 2023** -.. code-block:: bash +.. code-block:: rst 1. add the entire uncertainty quantification module (https://github.com/sunlabuiuc/PyHealth/pull/111) **Feb 26, 2023** -.. code-block:: bash +.. code-block:: rst 1. add 6 EHR predictiom model: Adacare, Concare, Stagenet, TCN, Grasp, Agent **Feb 24, 2023** -.. code-block:: bash +.. code-block:: rst 1. add unittest for omop dataset - 2. add github action triggered manually, check #104 + 2. add github action triggered manually, check `#104` **Feb 19, 2023** -.. code-block:: bash +.. code-block:: rst 1. add unittest for eicu dataset 2. add ISRUC dataset (and task function) for signal learning **Feb 12, 2023** -.. code-block:: bash +.. code-block:: rst 1. add unittest for mimiciii, mimiciv 2. add SHHS datasets for sleep staging task @@ -71,7 +71,7 @@ We track the new development here: **Feb 08, 2023** -.. code-block:: bash +.. code-block:: rst 1. complete the biosignal data support, add ContraWR [1] model for general purpose biosignal classification task ([1] Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun. "Self-supervised eeg representation learning for automatic sleep staging." @@ -79,46 +79,46 @@ We track the new development here: **Feb 07, 2023** -.. code-block:: bash +.. code-block:: rst 1. Support signal dataset processing and split: add SampleSignalDataset, BaseSignalDataset. Use SleepEDFcassette dataset as the first signal dataset. Use example/sleep_staging_sleepEDF_contrawr.py 2. rename the dataset/ parts: previous BaseDataset becomes BaseEHRDataset and SampleDatast becomes SampleEHRDataset. Right now, BaseDataset will be inherited by BaseEHRDataset and BaseSignalDataset. SampleBaseDataset will be inherited by SampleEHRDataset and SampleSignalDataset. **Feb 06, 2023** -.. code-block:: bash +.. code-block:: rst 1. improve readme style 2. add the pyhealth live 06 and 07 link to pyhealth live **Feb 01, 2023** -.. code-block:: bash +.. code-block:: rst 1. add unittest of PyHealth MedCode and Tokenizer **Jan 26, 2023** -.. code-block:: bash +.. code-block:: rst 1. accelerate MIMIC-IV, eICU and OMOP data loading by using multiprocessing (pandarallel) **Jan 25, 2023** -.. code-block:: bash +.. code-block:: rst 1. accelerate the MIMIC-III data loading process by using multiprocessing (pandarallel) **Jan 24, 2023** -.. code-block:: bash +.. code-block:: rst - 1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue #71. + 1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue `#71`. 2. update the pyhealth live schedule **Jan 22, 2023** -.. code-block:: bash +.. code-block:: rst 1. Fix the list of list of vector problem in RNN, Transformer, RETAIN, and CNN 2. Add initialization examples for RNN, Transformer, RETAIN, CNN, and Deepr @@ -128,34 +128,34 @@ We track the new development here: **Jan 21, 2023** -.. code-block:: bash +.. code-block:: rst 1. Added a new model, Deepr (models.Deepr) **Jan 20, 2023** -.. code-block:: bash +.. code-block:: rst 1. add the pyhealth live 05 2. add slack channel invitation in pyhealth live page **Jan 13, 2023** -.. code-block:: bash +.. code-block:: rst 1. add the pyhealth live 03 and 04 video link to the nagivation 2. add future pyhealth live schedule **Jan 8, 2023** -.. code-block:: bash +.. code-block:: rst 1. Changed BaseModel.add_feature_transform_layer in models/base_model.py so that it accepts special_tokens if necessary 2. fix an int/float bug in dataset checking (transform int to float and then process them uniformly) **Dec 26, 2022** -.. code-block:: bash +.. code-block:: rst 1. add examples to pyhealth.data, pyhealth.datasets 2. improve jupyter notebook tutorials 0, 1, 2 @@ -163,7 +163,7 @@ We track the new development here: **Dec 21, 2022** -.. code-block:: bash +.. code-block:: rst 1. add the development logs to the navigation 2. add the pyhealth live schedule to the nagivation diff --git a/docs/requirements.txt b/docs/requirements.txt index c7016fe4..959c5f4c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ Sphinx==5.2.3 -sphinx-automodapi> +sphinx-automodapi sphinx-autodoc-annotation sphinx_last_updated_by_git sphinxcontrib-spelling diff --git a/pyhealth/metrics/early_prediction_score.py b/pyhealth/metrics/early_prediction_score.py new file mode 100644 index 00000000..dc5e5364 --- /dev/null +++ b/pyhealth/metrics/early_prediction_score.py @@ -0,0 +1,113 @@ +import numpy as np +from typing import Dict, Optional + + +def calculate_confusion_matrix_value_result(outcome_pred, outcome_true): + outcome_pred = 1 if outcome_pred > 0.5 else 0 + if outcome_pred == 1 and outcome_true == 1: + return "tp" + elif outcome_pred == 0 and outcome_true == 0: + return "tn" + elif outcome_pred == 1 and outcome_true == 0: + return "fp" + elif outcome_pred == 0 and outcome_true == 1: + return "fn" + else: + raise ValueError("Unknown value occurred") + +def calculate_es(los_true, threshold, penalty, case="tp"): + metric = 0.0 + if case == "tp": + if los_true >= threshold: # predict correct in early stage + metric = 1 + else: + metric = los_true / threshold + elif case == "fn": + if los_true >= threshold: # predict wrong in early stage + metric = 0 + else: + metric = los_true / threshold - 1 + elif case == "tn": + metric = 0.0 + elif case == "fp": + metric = penalty # penalty term + return metric + + +def early_prediction_score( + y_true_outcome: np.ndarray, + y_true_los: np.ndarray, + y_prob: np.ndarray, + late_threshold: Optional[float] = None, + fp_penalty: Optional[float] = -0.1 +) -> Dict[str, float]: + """Computes early prediction score for binary classification. + + Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling + Using Electronic Health Records in Intensive Care: Choosing the Best Model for + COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023. + + Args: + y_true_outcome: True target outcome of shape (n_samples,). + y_true_los: Time to true target outcome of shape (n_samples,). + y_prob: Predicted probabilities of shape (n_samples,). + late_threshold: Threshold gamma for late prediction penalties. Default is 0.5 * + mean(y_true_los). + fp_penalty: Penalty term for false positive predictions. Default is -0.1. + + Returns: + Dictionary of metrics whose keys are the metric names and values are + the metric values. + + Examples: + >>> from pyhealth.metrics import early_prediction_score + >>> y_true_outcome = np.array([0, 0, 1, 1]) + >>> y_true_los = np.array([5, 3, 8, 1]) + >>> y_prob = np.array([0.1, 0.4, 0.7, 0.8]) + >>> early_prediction_score(y_true_outcome, y_true_los, y_prob) + {'score': 0.5952380952380952, 'late_threshold': 2.125, 'fp_penalty': 0.1} + """ + metric = [] + metric_optimal = [] + num_records = len(y_prob) + + if late_threshold is None: + late_threshold = 0.5 * np.mean(y_true_los) + + for i in range(num_records): + cur_outcome_pred = y_prob[i] + cur_outcome_true = y_true_outcome[i] + cur_los_true = y_true_los[i] + prediction_result = calculate_confusion_matrix_value_result(cur_outcome_pred, cur_outcome_true) + prediction_result_optimal = calculate_confusion_matrix_value_result(cur_outcome_true, cur_outcome_true) + metric.append( + calculate_es( + cur_los_true, + late_threshold, + penalty=fp_penalty, + case=prediction_result, + ) + ) + metric_optimal.append( + calculate_es( + cur_los_true, + late_threshold, + penalty=fp_penalty, + case=prediction_result_optimal, + ) + ) + metric = np.array(metric) + metric_optimal = np.array(metric_optimal) + result = 0.0 + if metric_optimal.sum() > 0.0: + result = metric.sum() / metric_optimal.sum() + result = max(result, -1.0) + if isinstance(result, np.float64): + result = result.item() + return {"score": result, 'late_threshold': late_threshold, 'fp_penalty': fp_penalty} + +if __name__ == "__main__": + y_true_outcome = np.array([0, 1, 1, 1]) + y_true_los = np.array([5, 3, 8, 1]) + y_prob = np.array([0.1, 0.4, 0.7, 0.8]) + print(early_prediction_score(y_true_outcome, y_true_los, y_prob)) \ No newline at end of file diff --git a/pyhealth/metrics/osmae.py b/pyhealth/metrics/osmae.py new file mode 100644 index 00000000..23375a7d --- /dev/null +++ b/pyhealth/metrics/osmae.py @@ -0,0 +1,103 @@ + +import numpy as np +from typing import Dict, Optional + + +def calculate_outcome_prediction_result(outcome_pred, outcome_true): + outcome_pred = 1 if outcome_pred > 0.5 else 0 + return "true" if outcome_pred == outcome_true else "false" + + +def calculate_epsilon(los_true, threshold, large_los): + """ + epsilon is the decay term + """ + if los_true <= threshold: + return 1 + else: + return max(0, (los_true - large_los) / (threshold - large_los)) + + +def calculate_osmae(los_pred, los_true, large_los, threshold, case="true"): + if case == "true": + epsilon = calculate_epsilon(los_true, threshold, large_los) + return epsilon * np.abs(los_pred - los_true) + elif case == "false": + epsilon = calculate_epsilon(los_true, threshold, large_los) + return epsilon * (max(0, large_los - los_pred) + max(0, large_los - los_true)) + else: + raise ValueError("case must be 'true' or 'false'") + + +def osmae_score( + y_true_outcome: np.ndarray, + y_true_los: np.ndarray, + y_prob_outcome: np.ndarray, + y_pred_los: np.ndarray, + large_los: Optional[float] = None, + threshold: Optional[float] = None +) -> Dict[str, float]: + """Computes outcome-specific mean absolute error for length-of-stay prediction. + + Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling + Using Electronic Health Records in Intensive Care: Choosing the Best Model for + COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023. + + Args: + y_true_outcome: True target outcome of shape (n_samples,). + y_true_los: Time to true target outcome of shape (n_samples,). + y_prob: Predicted probabilities of shape (n_samples,). + y_prob_outcome: Predicted outcome probabilities of shape (n_samples,). + y_pred_los: Predicted length-of-stay of shape (n_samples,). + large_los: Largest length-of-stay E, default is 95% percentile of maximum of total + length-of-stay. + threshold: Threshold gamma for late prediction penalties. Default is 0.5 * + mean(y_true_los). + + Returns: + Dictionary of metrics whose keys are the metric names and values are + the metric values. + + Examples: + >>> from pyhealth.metrics import osmae_score + >>> y_true_outcome = np.array([0, 0, 1, 1]) + >>> y_true_los = np.array([5, 3, 8, 1]) + >>> y_prob_outcome = np.array([0.1, 0.4, 0.7, 0.2]) + >>> y_pred_los = np.array([10, 5, 7, 3]) + >>> osmae_score(y_true_outcome, y_true_los, y_prob_outcome, y_pred_los) + {'osmae': 4.0638297872340425, 'large_los': 8, 'threshold': 2.125} + """ + if large_los is None: + large_los = np.sort(y_true_los)[int(0.95 * len(y_true_los))] + + if threshold is None: + threshold = 0.5 * np.mean(y_true_los) + + metric = [] + num_records = len(y_prob_outcome) + for i in range(num_records): + cur_outcome_pred = y_prob_outcome[i] + cur_los_pred = y_pred_los[i] + cur_outcome_true = y_true_outcome[i] + cur_los_true = y_true_los[i] + prediction_result = calculate_outcome_prediction_result( + cur_outcome_pred, cur_outcome_true + ) + metric.append( + calculate_osmae( + cur_los_pred, + cur_los_true, + large_los, + threshold, + case=prediction_result, + ) + ) + result = np.array(metric) + return {"osmae": result.mean(axis=0).item(), "large_los": large_los, "threshold": threshold} + +if __name__ == "__main__": + y_true_outcome = np.array([0, 0, 1, 1]) + y_true_los = np.array([5, 3, 8, 1]) + y_prob_outcome = np.array([0.1, 0.4, 0.7, 0.2]) + y_pred_los = np.array([10, 5, 7, 3]) + print(osmae_score(y_true_outcome, y_true_los, y_prob_outcome, y_pred_los)) \ No newline at end of file diff --git a/pyhealth/metrics/regression.py b/pyhealth/metrics/regression.py new file mode 100644 index 00000000..9db6266c --- /dev/null +++ b/pyhealth/metrics/regression.py @@ -0,0 +1,101 @@ +from typing import Dict, List, Optional + +import numpy as np +import sklearn.metrics as sklearn_metrics + +def calculate_ccc(y_true, y_pred): + """ + This function calculates the concordance correlation coefficient (CCC) between two vectors + :param y_true: real data + :param y_pred: estimated data + :return: CCC + :rtype: float + """ + cor = np.corrcoef(y_true, y_pred)[0][1] + mean_true = np.mean(y_true) + mean_pred = np.mean(y_pred) + var_true = np.var(y_true) + var_pred = np.var(y_pred) + sd_true = np.std(y_true) + sd_pred = np.std(y_pred) + numerator = 2 * cor * sd_true * sd_pred + denominator = var_true + var_pred + (mean_true - mean_pred) ** 2 + return numerator / denominator + +def regression_metrics_fn( + y_true: np.ndarray, + y_pred: np.ndarray, + metrics: Optional[List[str]] = None +) -> Dict[str, float]: + """Computes metrics for regression. + + User can specify which metrics to compute by passing a list of metric names. + The accepted metric names are: + - mse: mean squared error + - mae: mean absolute error + - mape: mean absolute percentage error + - rmse: root mean squared error + - ccc: concordance correlation coefficient + - r2: R^2 score + If no metrics are specified, mse, mae and r2 are computed by default. + + This function calls sklearn.metrics functions to compute the metrics. For + more information on the metrics, please refer to the documentation of the + corresponding sklearn.metrics functions. + + Args: + y_true: True target values of shape (n_samples,). + y_pred: Predicted values of shape (n_samples,). + metrics: List of metrics to compute. Default is ["mse", "mae", "r2"]. + + Returns: + Dictionary of metrics whose keys are the metric names and values are + the metric values. + + Examples: + >>> from pyhealth.metrics import binary_metrics_fn + >>> y_true = np.array([1, 3, 2, 4]) + >>> y_pred = np.array([1.1, 2.4, 1.35, 2.8]) + >>> binary_metrics_fn(y_true, y_prob, metrics=["mse", "mae", "r2"]) + {'mse': 0.1475, 'mae': 0.275, 'r2': 0.6923076923076923} + """ + if metrics is None: + metrics = ["mse", "mae", "r2"] + + output = {} + for metric in metrics: + if metric == "mse": + mse = sklearn_metrics.mean_squared_error(y_true, y_pred) + output["mse"] = mse + elif metric == "mae": + mae = sklearn_metrics.mean_absolute_error(y_true, y_pred) + output["mae"] = mae + elif metric == "mape": + mape = sklearn_metrics.mean_absolute_percentage_error(y_true, y_pred) + output["mape"] = mape + elif metric == "rmse": + rmse = np.sqrt(sklearn_metrics.mean_squared_error(y_true, y_pred)) + output["rmse"] = rmse + elif metric == "ccc": + ccc = calculate_ccc(y_true, y_pred) + output["ccc"] = ccc + elif metric == "r2": + r2 = sklearn_metrics.r2_score(y_true, y_pred) + output["r2"] = r2 + else: + raise ValueError(f"Unknown metric for regression: {metric}") + return output + + +if __name__ == "__main__": + all_metrics = [ + "mse", + "mae", + "r2", + "ccc", + "rmse", + "mape" + ] + y_true = np.random.randint(0,10, size=100000) + y_pred = np.random.random(size=100000)*10 + print(regression_metrics_fn(y_true, y_pred, metrics=all_metrics)) diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index 378f4498..b081eb5a 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -348,6 +348,8 @@ class AdaCare(BaseModel): >>> print(ret) { 'loss': tensor(0.7167, grad_fn=), + 'feature_importance: [tesnor of shape (batch_size, time_step, feature_dim), ...], + 'conv_feature_importance: [tesnor of shape (batch_size, time_step, 3*kernel_size), ...], 'y_prob': tensor([[0.5009], [0.4779]], grad_fn=), 'y_true': tensor([[0.], [1.]]), 'logit': tensor([[ 0.0036], [-0.0886]], grad_fn=) @@ -438,7 +440,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: loss: a scalar tensor representing the loss. feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, features) representing the feature importance. - conv_feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, 3*kernal_size) + conv_feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, 3*kernel_size) representing the convolutional feature importance. y_prob: a tensor representing the predicted probabilities. y_true: a tensor representing the true labels. @@ -518,6 +520,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: y_prob = self.prepare_y_prob(logits) results = { "loss": loss, + "feature_importance": feature_importance, + "conv_feature_importance": conv_feature_importance, "y_prob": y_prob, "y_true": y_true, "logit": logits, diff --git a/pyhealth/models/agent.py b/pyhealth/models/agent.py index 187fedb9..24b8b164 100644 --- a/pyhealth/models/agent.py +++ b/pyhealth/models/agent.py @@ -368,6 +368,8 @@ class Agent(BaseModel): >>> print(ret) { 'loss': tensor(1.4059, grad_fn=), + 'loss_task': tensor(0.6931, grad_fn=), + 'loss_RL': tensor(0.7128, grad_fn=), 'y_prob': tensor([[0.4861], [0.5348]], grad_fn=), 'y_true': tensor([[0.], [1.]]), 'logit': tensor([[-0.0556], [0.1392]], grad_fn=) @@ -660,6 +662,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: y_prob = self.prepare_y_prob(logits) results = { "loss": loss, + 'loss_task': loss_task, + 'loss_RL': loss_rl, "y_prob": y_prob, "y_true": y_true, "logit": logits, diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index a6d844ff..8655aeee 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -220,18 +220,21 @@ def get_output_size(self, label_tokenizer: Tokenizer) -> int: output_size: the output size of the model. """ output_size = label_tokenizer.get_vocabulary_size() - if self.mode == "binary": + if self.mode == "binary" or self.mode == "early_mortality": assert output_size == 2 output_size = 1 return output_size - def get_loss_function(self) -> Callable: + def get_loss_function(self, **args) -> Callable: """Gets the default loss function using `self.mode`. The default loss functions are: - binary: `F.binary_cross_entropy_with_logits` - multiclass: `F.cross_entropy` - multilabel: `F.binary_cross_entropy_with_logits` + - early_mortality: TimeAwareLoss + - regression: `F.mse_loss` + - multi_target: a list of loss functions with the same format as above Returns: The default loss function. @@ -242,32 +245,44 @@ def get_loss_function(self) -> Callable: return F.cross_entropy elif self.mode == "multilabel": return F.binary_cross_entropy_with_logits + elif self.mode == "regression": + return F.mse_loss + elif self.mode == "multi_target": + assert args["targets"] is type(list) + loss_list = [] + for target in args["targets"]: + loss_list.append(self.get_loss_function(target)) + elif self.mode == "early_mortality": + assert "outcome_pred" in args.keys() and "outcome_true" in args.keys() and "los_true" in args.keys() + return TimeAwareLoss() else: raise ValueError("Invalid mode: {}".format(self.mode)) def prepare_labels( self, - labels: Union[List[str], List[List[str]]], + labels: Union[List[str], List[List[str]], List[List[List[str]]]], label_tokenizer: Tokenizer, ) -> torch.Tensor: """Prepares the labels for model training and evaluation. This function converts the labels to different formats depending on the mode. The default formats are: - - binary: a tensor of shape (batch_size, 1) + - binary, early_mortality: a tensor of shape (batch_size, 1) - multiclass: a tensor of shape (batch_size,) - multilabel: a tensor of shape (batch_size, num_labels) + - regression: a tensor of shape (batch_size, 1) + - multi_target: a list of tensors with the same format as above Args: labels: the raw labels from the samples. It should be - - a list of str for binary and multiclass classificationa + - a list of str for binary and multiclass classification - a list of list of str for multilabel classification label_tokenizer: the label tokenizer. Returns: labels: the processed labels. """ - if self.mode in ["binary"]: + if self.mode in ["binary"] or self.mode in ["early_mortality"]: labels = label_tokenizer.convert_tokens_to_indices(labels) labels = torch.FloatTensor(labels).unsqueeze(-1) elif self.mode in ["multiclass"]: @@ -281,9 +296,14 @@ def prepare_labels( # convert to multihot num_labels = label_tokenizer.get_vocabulary_size() labels = batch_to_multihot(labels_index, num_labels) + elif self.mode in ["regression"]: + labels = torch.FloatTensor(labels).unsqueeze(-1) + elif self.mode in ["multi_target"]: + labels = [self.prepare_labels(label, label_tokenizer) for label in labels] else: raise NotImplementedError - labels = labels.to(self.device) + if self.mode not in ["multi_target"]: + labels = labels.to(self.device) return labels def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor: @@ -291,13 +311,15 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor: This function converts the predicted logits to predicted probabilities depending on the mode. The default formats are: - - binary: a tensor of shape (batch_size, 1) with values in [0, 1], + - binary, early_mortality: a tensor of shape (batch_size, 1) with values in [0, 1], which is obtained with `torch.sigmoid()` - multiclass: a tensor of shape (batch_size, num_classes) with values in [0, 1] and sum to 1, which is obtained with `torch.softmax()` - multilabel: a tensor of shape (batch_size, num_labels) with values in [0, 1], which is obtained with `torch.sigmoid()` + - regression: a tensor of shape (batch_size, 1) + - multi_target: a list of tensors with the same format as above Args: logits: the predicted logit tensor. @@ -305,12 +327,64 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor: Returns: y_prob: the predicted probability tensor. """ - if self.mode in ["binary"]: + if self.mode in ["binary"] or self.mode in ["early_mortality"]: y_prob = torch.sigmoid(logits) elif self.mode in ["multiclass"]: y_prob = F.softmax(logits, dim=-1) elif self.mode in ["multilabel"]: y_prob = torch.sigmoid(logits) + elif self.mode in ["regression"]: + y_prob = logits + elif self.mode in ["multi_target"]: + y_prob = [self.prepare_y_prob(logit) for logit in logits] else: raise NotImplementedError return y_prob + +class TimeAwareLoss(nn.Module): + """Computes time-aware loss for early mortality prediction. + + Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling + Using Electronic Health Records in Intensive Care: Choosing the Best Model for + COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023. + + Args: + decay_rate: decay rate for the los . + reward_factor: reward factor for the correct predictions. + + Returns: + Dictionary of metrics whose keys are the metric names and values are + the metric values. + + Examples: + >>> from pyhealth.metrics import early_prediction_score + >>> y_true_outcome = np.array([0, 0, 1, 1]) + >>> y_true_los = np.array([5, 3, 8, 1]) + >>> y_prob = np.array([0.1, 0.4, 0.7, 0.8]) + >>> early_prediction_score(y_true_outcome, y_true_los, y_prob) + {'score': 0.5952380952380952, 'late_threshold': 2.125, 'fp_penalty': 0.1} + """ + def __init__(self, decay_rate=0.1, reward_factor=0.1): + super(TimeAwareLoss, self).__init__() + self.bce = nn.BCELoss(reduction='none') + self.decay_rate = decay_rate + self.reward_factor = reward_factor + + def forward(self, outcome_pred, outcome_true, los_true): + """Return the loss value of time-aware loss. + + Args: + outcome_pred: the predicted outcome + outcome_true: the true outcome + los_true: the true length of stay at the prediction time + + Returns: + y_prob: the predicted probability tensor. + """ + los_weights = torch.exp(-self.decay_rate * los_true) # Exponential decay + loss_unreduced = self.bce(outcome_pred, outcome_true) + + reward_term = (los_true * torch.abs(outcome_true - outcome_pred)).mean() # Reward term + loss = (loss_unreduced * los_weights).mean()-self.reward_factor * reward_term # Weighted loss + + return torch.clamp(loss, min=0) \ No newline at end of file diff --git a/pyhealth/models/concare.py b/pyhealth/models/concare.py index 62fb0ea2..b5297870 100644 --- a/pyhealth/models/concare.py +++ b/pyhealth/models/concare.py @@ -737,6 +737,8 @@ class ConCare(BaseModel): >>> print(ret) { 'loss': tensor(9.5541, grad_fn=), + 'loss_task': tensor(9.5541, grad_fn=), + 'loss_decov': tensor(0., grad_fn=), 'y_prob': tensor([[0.5323], [0.5363]], grad_fn=), 'y_true': tensor([[1.], [0.]]), 'logit': tensor([[0.1293], [0.1454]], grad_fn=) @@ -925,6 +927,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: y_prob = self.prepare_y_prob(logits) results = { "loss": loss, + "loss_task": loss_task, + "loss_decov": decov_loss, "y_prob": y_prob, "y_true": y_true, 'logit': logits, diff --git a/pyhealth/models/molerec.py b/pyhealth/models/molerec.py index dfa92a07..0b5b4667 100644 --- a/pyhealth/models/molerec.py +++ b/pyhealth/models/molerec.py @@ -297,7 +297,6 @@ def __init__( ) print(e) - self.hidden_size = hidden_size self.coef, self.target_ddi = coef, target_ddi GNN_para = { diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index ea0fe84f..6da3d3b5 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -346,6 +346,7 @@ class StageNet(BaseModel): >>> print(ret) { 'loss': tensor(0.7111, grad_fn=), + 'distance': [tensor of shape (batch_size, time_step)], ...], 'y_prob': tensor([[0.4815], [0.4991]], grad_fn=), 'y_true': tensor([[1.], @@ -532,6 +533,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: y_prob = self.prepare_y_prob(logits) results = { "loss": loss, + 'distance': distance, "y_prob": y_prob, "y_true": y_true, "logit": logits, diff --git a/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py b/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py new file mode 100644 index 00000000..0b88a270 --- /dev/null +++ b/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py @@ -0,0 +1,329 @@ +from pyhealth.data import Patient + + +def categorize_los(days: int): + """Categorizes length of stay into 10 categories. + + One for ICU stays shorter than a day, seven day-long categories for each day of + the first week, one for stays of over one week but less than two, + and one for stays of over two weeks. + + Args: + days: int, length of stay in days + + Returns: + category: int, category of length of stay + """ + # ICU stays shorter than a day + if days < 1: + return 0 + # each day of the first week + elif 1 <= days <= 7: + return days + # stays of over one week but less than two + elif 7 < days <= 14: + return 8 + # stays of over two weeks + else: + return 9 + + +def length_of_stay_prediction_mimic3_fn(patient: Patient): + """Processes a single patient for the length-of-stay prediction task. + + Length of stay prediction aims at predicting the length of stay (in days) of the + current hospital visit based on the clinical information from the visit + (e.g., conditions and procedures). + + Args: + patient: a Patient object. + + Returns: + samples: a list of samples, each sample is a dict with patient_id, visit_id, + and other task-specific attributes as key. + + Note that we define the task as a multi-class classification task. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> mimic3_base = MIMIC3Dataset( + ... root="/srv/local/data/physionet.org/files/mimiciii/1.4", + ... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], + ... code_mapping={"ICD9CM": "CCSCM"}, + ... ) + >>> from pyhealth.tasks import length_of_stay_prediction_mimic3_fn + >>> mimic3_sample = mimic3_base.set_task(length_of_stay_prediction_mimic3_fn) + >>> mimic3_sample.samples[0] + [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 4}] + """ + samples = [] + + for visit in patient: + + conditions = visit.get_code_list(table="DIAGNOSES_ICD") + procedures = visit.get_code_list(table="PROCEDURES_ICD") + drugs = visit.get_code_list(table="PRESCRIPTIONS") + # exclude: visits without condition, procedure, or drug code + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + los_days = (visit.discharge_time - visit.encounter_time).days + los_category = categorize_los(los_days) + + if visit.discharge_status not in [0, 1]: + mortality_label = 0 + else: + mortality_label = int(visit.discharge_status) + + # TODO: should also exclude visit with age < 18 + samples.append( + { + "visit_id": visit.visit_id, + "patient_id": patient.patient_id, + "conditions": [conditions], + "procedures": [procedures], + "drugs": [drugs], + "label": los_category, + "outcome_label": mortality_label, + } + ) + # no cohort selection + return samples + + +def length_of_stay_prediction_mimic4_fn(patient: Patient): + """Processes a single patient for the length-of-stay prediction task. + + Length of stay prediction aims at predicting the length of stay (in days) of the + current hospital visit based on the clinical information from the visit + (e.g., conditions and procedures). + + Args: + patient: a Patient object. + + Returns: + samples: a list of samples, each sample is a dict with patient_id, visit_id, + and other task-specific attributes as key. + + Note that we define the task as a multi-class classification task. + + Examples: + >>> from pyhealth.datasets import MIMIC4Dataset + >>> mimic4_base = MIMIC4Dataset( + ... root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + ... tables=["diagnoses_icd", "procedures_icd"], + ... code_mapping={"ICD10PROC": "CCSPROC"}, + ... ) + >>> from pyhealth.tasks import length_of_stay_prediction_mimic4_fn + >>> mimic4_sample = mimic4_base.set_task(length_of_stay_prediction_mimic4_fn) + >>> mimic4_sample.samples[0] + [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 2}] + """ + samples = [] + + for visit in patient: + + conditions = visit.get_code_list(table="diagnoses_icd") + procedures = visit.get_code_list(table="procedures_icd") + drugs = visit.get_code_list(table="prescriptions") + # exclude: visits without condition, procedure, or drug code + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + los_days = (visit.discharge_time - visit.encounter_time).days + los_category = categorize_los(los_days) + + if visit.discharge_status not in [0, 1]: + mortality_label = 0 + else: + mortality_label = int(visit.discharge_status) + + # TODO: should also exclude visit with age < 18 + samples.append( + { + "visit_id": visit.visit_id, + "patient_id": patient.patient_id, + "conditions": [conditions], + "procedures": [procedures], + "drugs": [drugs], + "label": los_category, + "outcome_label": mortality_label, + } + ) + # no cohort selection + return samples + + +def length_of_stay_prediction_eicu_fn(patient: Patient): + """Processes a single patient for the length-of-stay prediction task. + + Length of stay prediction aims at predicting the length of stay (in days) of the + current hospital visit based on the clinical information from the visit + (e.g., conditions and procedures). + + Args: + patient: a Patient object. + + Returns: + samples: a list of samples, each sample is a dict with patient_id, visit_id, + and other task-specific attributes as key. + + Note that we define the task as a multi-class classification task. + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> eicu_base = eICUDataset( + ... root="/srv/local/data/physionet.org/files/eicu-crd/2.0", + ... tables=["diagnosis", "medication"], + ... code_mapping={}, + ... dev=True + ... ) + >>> from pyhealth.tasks import length_of_stay_prediction_eicu_fn + >>> eicu_sample = eicu_base.set_task(length_of_stay_prediction_eicu_fn) + >>> eicu_sample.samples[0] + [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 5}] + """ + samples = [] + + for visit in patient: + + conditions = visit.get_code_list(table="diagnosis") + procedures = visit.get_code_list(table="physicalExam") + drugs = visit.get_code_list(table="medication") + # exclude: visits without condition, procedure, or drug code + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + los_days = (visit.discharge_time - visit.encounter_time).days + los_category = categorize_los(los_days) + + if visit.discharge_status not in ["Alive", "Expired"]: + mortality_label = 0 + else: + mortality_label = 0 if visit.discharge_status == "Alive" else 1 + + # TODO: should also exclude visit with age < 18 + samples.append( + { + "visit_id": visit.visit_id, + "patient_id": patient.patient_id, + "conditions": [conditions], + "procedures": [procedures], + "drugs": [drugs], + "label": los_category, + "outcome_label": mortality_label, + } + ) + # no cohort selection + return samples + + +def length_of_stay_prediction_omop_fn(patient: Patient): + """Processes a single patient for the length-of-stay prediction task. + + Length of stay prediction aims at predicting the length of stay (in days) of the + current hospital visit based on the clinical information from the visit + (e.g., conditions and procedures). + + Args: + patient: a Patient object. + + Returns: + samples: a list of samples, each sample is a dict with patient_id, visit_id, + and other task-specific attributes as key. + + Note that we define the task as a multi-class classification task. + + Examples: + >>> from pyhealth.datasets import OMOPDataset + >>> omop_base = OMOPDataset( + ... root="https://storage.googleapis.com/pyhealth/synpuf1k_omop_cdm_5.2.2", + ... tables=["condition_occurrence", "procedure_occurrence"], + ... code_mapping={}, + ... ) + >>> from pyhealth.tasks import length_of_stay_prediction_omop_fn + >>> omop_sample = omop_base.set_task(length_of_stay_prediction_eicu_fn) + >>> omop_sample.samples[0] + [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 7}] + """ + samples = [] + + for visit in patient: + + conditions = visit.get_code_list(table="condition_occurrence") + procedures = visit.get_code_list(table="procedure_occurrence") + drugs = visit.get_code_list(table="drug_exposure") + # exclude: visits without condition, procedure, or drug code + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + los_days = (visit.discharge_time - visit.encounter_time).days + los_category = categorize_los(los_days) + mortality_label = int(visit.discharge_status) + + # TODO: should also exclude visit with age < 18 + samples.append( + { + "visit_id": visit.visit_id, + "patient_id": patient.patient_id, + "conditions": [conditions], + "procedures": [procedures], + "drugs": [drugs], + "label": los_category, + "outcome_label": mortality_label, + } + ) + # no cohort selection + return samples + + +if __name__ == "__main__": + from pyhealth.datasets import MIMIC3Dataset + + base_dataset = MIMIC3Dataset( + root="/srv/local/data/physionet.org/files/mimiciii/1.4", + tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], + dev=True, + code_mapping={"ICD9CM": "CCSCM", "NDC": "ATC"}, + refresh_cache=False, + ) + sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_mimic3_fn) + sample_dataset.stat() + print(sample_dataset.available_keys) + + from pyhealth.datasets import MIMIC4Dataset + + base_dataset = MIMIC4Dataset( + root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + dev=True, + code_mapping={"NDC": "ATC"}, + refresh_cache=False, + ) + sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_mimic4_fn) + sample_dataset.stat() + print(sample_dataset.available_keys) + + from pyhealth.datasets import eICUDataset + + base_dataset = eICUDataset( + root="/srv/local/data/physionet.org/files/eicu-crd/2.0", + tables=["diagnosis", "medication", "physicalExam"], + dev=True, + refresh_cache=False, + ) + sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_eicu_fn) + sample_dataset.stat() + print(sample_dataset.available_keys) + + from pyhealth.datasets import OMOPDataset + + base_dataset = OMOPDataset( + root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2", + tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"], + dev=True, + refresh_cache=False, + ) + sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_omop_fn) + sample_dataset.stat() + print(sample_dataset.available_keys) diff --git a/requirements.txt b/requirements.txt index d20966aa..c1fba519 100755 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,6 @@ networkx>=2.6.3 pandas>=1.3.2,<2 pandarallel>=1.5.3 mne>=1.0.3 +urllib3<=1.26.15 numpy tqdm \ No newline at end of file