diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 58358241..8c0446e0 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -10,6 +10,7 @@ Datasets
datasets/pyhealth.datasets.SampleEHRDataset
datasets/pyhealth.datasets.SampleSignalDataset
datasets/pyhealth.datasets.MIMIC3Dataset
+ datasets/pyhealth.datasets.MIMICExtractDataset
datasets/pyhealth.datasets.MIMIC4Dataset
datasets/pyhealth.datasets.eICUDataset
datasets/pyhealth.datasets.OMOPDataset
diff --git a/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst
new file mode 100644
index 00000000..d38b9cfc
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst
@@ -0,0 +1,15 @@
+pyhealth.datasets.MIMICExtractDataset
+===================================
+
+The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer to `doc `_ for more information. We process this database into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis.
+
+.. autoclass:: pyhealth.datasets.MIMICExtractDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
index 1c140905..108f3001 100644
--- a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
+++ b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
@@ -1,7 +1,7 @@
pyhealth.datasets.OMOPDataset
===================================
-We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. We it into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis.
+We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. The raw data is processed into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis.
.. autoclass:: pyhealth.datasets.OMOPDataset
:members:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 310d0249..817faeb0 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -15,6 +15,7 @@ We implement the following models for supporting multiple healthcare predictive
models/pyhealth.models.GAMENet
models/pyhealth.models.MICRON
models/pyhealth.models.SafeDrug
+ models/pyhealth.models.MoleRec
models/pyhealth.models.Deepr
models/pyhealth.models.ContraWR
models/pyhealth.models.SparcNet
diff --git a/docs/api/models/pyhealth.models.MoleRec.rst b/docs/api/models/pyhealth.models.MoleRec.rst
new file mode 100644
index 00000000..541d315f
--- /dev/null
+++ b/docs/api/models/pyhealth.models.MoleRec.rst
@@ -0,0 +1,14 @@
+pyhealth.models.MoleRec
+===================================
+
+The separate callable MoleRecLayer and the complete MoleRec model.
+
+.. autoclass:: pyhealth.models.MoleRecLayer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: pyhealth.models.MoleRec
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/log.rst b/docs/log.rst
index 711f35fa..aec41122 100644
--- a/docs/log.rst
+++ b/docs/log.rst
@@ -4,66 +4,66 @@ We track the new development here:
**May 9, 2023**
-.. code-block:: bash
+.. code-block:: rst
- 1. add MIMIC-Extract dataset `#136 `_
+ 1. add MIMIC-Extract dataset `#136`
2. add new maintainer members for pyhealth: Junyi Gao and Benjamin Danek
**May 6, 2023**
-.. code-block:: bash
+.. code-block:: rst
- 1. add new parser functions (admissionDx, diagnosisStrings) and prediction tasks for eICU dataset `#148 `_
+ 1. add new parser functions (admissionDx, diagnosisStrings) and prediction tasks for eICU dataset `#148`
**Apr 27, 2023**
-.. code-block:: bash
+.. code-block:: rst
- 1. add MoleRec model (WWW'23) for drug recommendation `#122 `_
+ 1. add MoleRec model (WWW'23) for drug recommendation `#122`
**Apr 26, 2023**
-.. code-block:: bash
+.. code-block:: rst
- 1. fix bugs in GRASP model `#141 `_
- 2. add pandas install <2 constraints `#135 `_
- 3. add hcpcsevents table process in MIMIC4 dataset `#134 `_
+ 1. fix bugs in GRASP model `#141`
+ 2. add pandas install <2 constraints `#135`
+ 3. add hcpcsevents table process in MIMIC4 dataset `#134`
**Apr 10, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. fix Ambiguous datetime usage in eICU (https://github.com/sunlabuiuc/PyHealth/pull/132)
**Mar 26, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add the entire uncertainty quantification module (https://github.com/sunlabuiuc/PyHealth/pull/111)
**Feb 26, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add 6 EHR predictiom model: Adacare, Concare, Stagenet, TCN, Grasp, Agent
**Feb 24, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add unittest for omop dataset
- 2. add github action triggered manually, check #104
+ 2. add github action triggered manually, check `#104`
**Feb 19, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add unittest for eicu dataset
2. add ISRUC dataset (and task function) for signal learning
**Feb 12, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add unittest for mimiciii, mimiciv
2. add SHHS datasets for sleep staging task
@@ -71,7 +71,7 @@ We track the new development here:
**Feb 08, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. complete the biosignal data support, add ContraWR [1] model for general purpose biosignal classification task ([1] Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun.
"Self-supervised eeg representation learning for automatic sleep staging."
@@ -79,46 +79,46 @@ We track the new development here:
**Feb 07, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. Support signal dataset processing and split: add SampleSignalDataset, BaseSignalDataset. Use SleepEDFcassette dataset as the first signal dataset. Use example/sleep_staging_sleepEDF_contrawr.py
2. rename the dataset/ parts: previous BaseDataset becomes BaseEHRDataset and SampleDatast becomes SampleEHRDataset. Right now, BaseDataset will be inherited by BaseEHRDataset and BaseSignalDataset. SampleBaseDataset will be inherited by SampleEHRDataset and SampleSignalDataset.
**Feb 06, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. improve readme style
2. add the pyhealth live 06 and 07 link to pyhealth live
**Feb 01, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add unittest of PyHealth MedCode and Tokenizer
**Jan 26, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. accelerate MIMIC-IV, eICU and OMOP data loading by using multiprocessing (pandarallel)
**Jan 25, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. accelerate the MIMIC-III data loading process by using multiprocessing (pandarallel)
**Jan 24, 2023**
-.. code-block:: bash
+.. code-block:: rst
- 1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue #71.
+ 1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue `#71`.
2. update the pyhealth live schedule
**Jan 22, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. Fix the list of list of vector problem in RNN, Transformer, RETAIN, and CNN
2. Add initialization examples for RNN, Transformer, RETAIN, CNN, and Deepr
@@ -128,34 +128,34 @@ We track the new development here:
**Jan 21, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. Added a new model, Deepr (models.Deepr)
**Jan 20, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add the pyhealth live 05
2. add slack channel invitation in pyhealth live page
**Jan 13, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. add the pyhealth live 03 and 04 video link to the nagivation
2. add future pyhealth live schedule
**Jan 8, 2023**
-.. code-block:: bash
+.. code-block:: rst
1. Changed BaseModel.add_feature_transform_layer in models/base_model.py so that it accepts special_tokens if necessary
2. fix an int/float bug in dataset checking (transform int to float and then process them uniformly)
**Dec 26, 2022**
-.. code-block:: bash
+.. code-block:: rst
1. add examples to pyhealth.data, pyhealth.datasets
2. improve jupyter notebook tutorials 0, 1, 2
@@ -163,7 +163,7 @@ We track the new development here:
**Dec 21, 2022**
-.. code-block:: bash
+.. code-block:: rst
1. add the development logs to the navigation
2. add the pyhealth live schedule to the nagivation
diff --git a/docs/requirements.txt b/docs/requirements.txt
index c7016fe4..959c5f4c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,5 +1,5 @@
Sphinx==5.2.3
-sphinx-automodapi>
+sphinx-automodapi
sphinx-autodoc-annotation
sphinx_last_updated_by_git
sphinxcontrib-spelling
diff --git a/pyhealth/metrics/early_prediction_score.py b/pyhealth/metrics/early_prediction_score.py
new file mode 100644
index 00000000..dc5e5364
--- /dev/null
+++ b/pyhealth/metrics/early_prediction_score.py
@@ -0,0 +1,113 @@
+import numpy as np
+from typing import Dict, Optional
+
+
+def calculate_confusion_matrix_value_result(outcome_pred, outcome_true):
+ outcome_pred = 1 if outcome_pred > 0.5 else 0
+ if outcome_pred == 1 and outcome_true == 1:
+ return "tp"
+ elif outcome_pred == 0 and outcome_true == 0:
+ return "tn"
+ elif outcome_pred == 1 and outcome_true == 0:
+ return "fp"
+ elif outcome_pred == 0 and outcome_true == 1:
+ return "fn"
+ else:
+ raise ValueError("Unknown value occurred")
+
+def calculate_es(los_true, threshold, penalty, case="tp"):
+ metric = 0.0
+ if case == "tp":
+ if los_true >= threshold: # predict correct in early stage
+ metric = 1
+ else:
+ metric = los_true / threshold
+ elif case == "fn":
+ if los_true >= threshold: # predict wrong in early stage
+ metric = 0
+ else:
+ metric = los_true / threshold - 1
+ elif case == "tn":
+ metric = 0.0
+ elif case == "fp":
+ metric = penalty # penalty term
+ return metric
+
+
+def early_prediction_score(
+ y_true_outcome: np.ndarray,
+ y_true_los: np.ndarray,
+ y_prob: np.ndarray,
+ late_threshold: Optional[float] = None,
+ fp_penalty: Optional[float] = -0.1
+) -> Dict[str, float]:
+ """Computes early prediction score for binary classification.
+
+ Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling
+ Using Electronic Health Records in Intensive Care: Choosing the Best Model for
+ COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023.
+
+ Args:
+ y_true_outcome: True target outcome of shape (n_samples,).
+ y_true_los: Time to true target outcome of shape (n_samples,).
+ y_prob: Predicted probabilities of shape (n_samples,).
+ late_threshold: Threshold gamma for late prediction penalties. Default is 0.5 *
+ mean(y_true_los).
+ fp_penalty: Penalty term for false positive predictions. Default is -0.1.
+
+ Returns:
+ Dictionary of metrics whose keys are the metric names and values are
+ the metric values.
+
+ Examples:
+ >>> from pyhealth.metrics import early_prediction_score
+ >>> y_true_outcome = np.array([0, 0, 1, 1])
+ >>> y_true_los = np.array([5, 3, 8, 1])
+ >>> y_prob = np.array([0.1, 0.4, 0.7, 0.8])
+ >>> early_prediction_score(y_true_outcome, y_true_los, y_prob)
+ {'score': 0.5952380952380952, 'late_threshold': 2.125, 'fp_penalty': 0.1}
+ """
+ metric = []
+ metric_optimal = []
+ num_records = len(y_prob)
+
+ if late_threshold is None:
+ late_threshold = 0.5 * np.mean(y_true_los)
+
+ for i in range(num_records):
+ cur_outcome_pred = y_prob[i]
+ cur_outcome_true = y_true_outcome[i]
+ cur_los_true = y_true_los[i]
+ prediction_result = calculate_confusion_matrix_value_result(cur_outcome_pred, cur_outcome_true)
+ prediction_result_optimal = calculate_confusion_matrix_value_result(cur_outcome_true, cur_outcome_true)
+ metric.append(
+ calculate_es(
+ cur_los_true,
+ late_threshold,
+ penalty=fp_penalty,
+ case=prediction_result,
+ )
+ )
+ metric_optimal.append(
+ calculate_es(
+ cur_los_true,
+ late_threshold,
+ penalty=fp_penalty,
+ case=prediction_result_optimal,
+ )
+ )
+ metric = np.array(metric)
+ metric_optimal = np.array(metric_optimal)
+ result = 0.0
+ if metric_optimal.sum() > 0.0:
+ result = metric.sum() / metric_optimal.sum()
+ result = max(result, -1.0)
+ if isinstance(result, np.float64):
+ result = result.item()
+ return {"score": result, 'late_threshold': late_threshold, 'fp_penalty': fp_penalty}
+
+if __name__ == "__main__":
+ y_true_outcome = np.array([0, 1, 1, 1])
+ y_true_los = np.array([5, 3, 8, 1])
+ y_prob = np.array([0.1, 0.4, 0.7, 0.8])
+ print(early_prediction_score(y_true_outcome, y_true_los, y_prob))
\ No newline at end of file
diff --git a/pyhealth/metrics/osmae.py b/pyhealth/metrics/osmae.py
new file mode 100644
index 00000000..23375a7d
--- /dev/null
+++ b/pyhealth/metrics/osmae.py
@@ -0,0 +1,103 @@
+
+import numpy as np
+from typing import Dict, Optional
+
+
+def calculate_outcome_prediction_result(outcome_pred, outcome_true):
+ outcome_pred = 1 if outcome_pred > 0.5 else 0
+ return "true" if outcome_pred == outcome_true else "false"
+
+
+def calculate_epsilon(los_true, threshold, large_los):
+ """
+ epsilon is the decay term
+ """
+ if los_true <= threshold:
+ return 1
+ else:
+ return max(0, (los_true - large_los) / (threshold - large_los))
+
+
+def calculate_osmae(los_pred, los_true, large_los, threshold, case="true"):
+ if case == "true":
+ epsilon = calculate_epsilon(los_true, threshold, large_los)
+ return epsilon * np.abs(los_pred - los_true)
+ elif case == "false":
+ epsilon = calculate_epsilon(los_true, threshold, large_los)
+ return epsilon * (max(0, large_los - los_pred) + max(0, large_los - los_true))
+ else:
+ raise ValueError("case must be 'true' or 'false'")
+
+
+def osmae_score(
+ y_true_outcome: np.ndarray,
+ y_true_los: np.ndarray,
+ y_prob_outcome: np.ndarray,
+ y_pred_los: np.ndarray,
+ large_los: Optional[float] = None,
+ threshold: Optional[float] = None
+) -> Dict[str, float]:
+ """Computes outcome-specific mean absolute error for length-of-stay prediction.
+
+ Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling
+ Using Electronic Health Records in Intensive Care: Choosing the Best Model for
+ COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023.
+
+ Args:
+ y_true_outcome: True target outcome of shape (n_samples,).
+ y_true_los: Time to true target outcome of shape (n_samples,).
+ y_prob: Predicted probabilities of shape (n_samples,).
+ y_prob_outcome: Predicted outcome probabilities of shape (n_samples,).
+ y_pred_los: Predicted length-of-stay of shape (n_samples,).
+ large_los: Largest length-of-stay E, default is 95% percentile of maximum of total
+ length-of-stay.
+ threshold: Threshold gamma for late prediction penalties. Default is 0.5 *
+ mean(y_true_los).
+
+ Returns:
+ Dictionary of metrics whose keys are the metric names and values are
+ the metric values.
+
+ Examples:
+ >>> from pyhealth.metrics import osmae_score
+ >>> y_true_outcome = np.array([0, 0, 1, 1])
+ >>> y_true_los = np.array([5, 3, 8, 1])
+ >>> y_prob_outcome = np.array([0.1, 0.4, 0.7, 0.2])
+ >>> y_pred_los = np.array([10, 5, 7, 3])
+ >>> osmae_score(y_true_outcome, y_true_los, y_prob_outcome, y_pred_los)
+ {'osmae': 4.0638297872340425, 'large_los': 8, 'threshold': 2.125}
+ """
+ if large_los is None:
+ large_los = np.sort(y_true_los)[int(0.95 * len(y_true_los))]
+
+ if threshold is None:
+ threshold = 0.5 * np.mean(y_true_los)
+
+ metric = []
+ num_records = len(y_prob_outcome)
+ for i in range(num_records):
+ cur_outcome_pred = y_prob_outcome[i]
+ cur_los_pred = y_pred_los[i]
+ cur_outcome_true = y_true_outcome[i]
+ cur_los_true = y_true_los[i]
+ prediction_result = calculate_outcome_prediction_result(
+ cur_outcome_pred, cur_outcome_true
+ )
+ metric.append(
+ calculate_osmae(
+ cur_los_pred,
+ cur_los_true,
+ large_los,
+ threshold,
+ case=prediction_result,
+ )
+ )
+ result = np.array(metric)
+ return {"osmae": result.mean(axis=0).item(), "large_los": large_los, "threshold": threshold}
+
+if __name__ == "__main__":
+ y_true_outcome = np.array([0, 0, 1, 1])
+ y_true_los = np.array([5, 3, 8, 1])
+ y_prob_outcome = np.array([0.1, 0.4, 0.7, 0.2])
+ y_pred_los = np.array([10, 5, 7, 3])
+ print(osmae_score(y_true_outcome, y_true_los, y_prob_outcome, y_pred_los))
\ No newline at end of file
diff --git a/pyhealth/metrics/regression.py b/pyhealth/metrics/regression.py
new file mode 100644
index 00000000..9db6266c
--- /dev/null
+++ b/pyhealth/metrics/regression.py
@@ -0,0 +1,101 @@
+from typing import Dict, List, Optional
+
+import numpy as np
+import sklearn.metrics as sklearn_metrics
+
+def calculate_ccc(y_true, y_pred):
+ """
+ This function calculates the concordance correlation coefficient (CCC) between two vectors
+ :param y_true: real data
+ :param y_pred: estimated data
+ :return: CCC
+ :rtype: float
+ """
+ cor = np.corrcoef(y_true, y_pred)[0][1]
+ mean_true = np.mean(y_true)
+ mean_pred = np.mean(y_pred)
+ var_true = np.var(y_true)
+ var_pred = np.var(y_pred)
+ sd_true = np.std(y_true)
+ sd_pred = np.std(y_pred)
+ numerator = 2 * cor * sd_true * sd_pred
+ denominator = var_true + var_pred + (mean_true - mean_pred) ** 2
+ return numerator / denominator
+
+def regression_metrics_fn(
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+ metrics: Optional[List[str]] = None
+) -> Dict[str, float]:
+ """Computes metrics for regression.
+
+ User can specify which metrics to compute by passing a list of metric names.
+ The accepted metric names are:
+ - mse: mean squared error
+ - mae: mean absolute error
+ - mape: mean absolute percentage error
+ - rmse: root mean squared error
+ - ccc: concordance correlation coefficient
+ - r2: R^2 score
+ If no metrics are specified, mse, mae and r2 are computed by default.
+
+ This function calls sklearn.metrics functions to compute the metrics. For
+ more information on the metrics, please refer to the documentation of the
+ corresponding sklearn.metrics functions.
+
+ Args:
+ y_true: True target values of shape (n_samples,).
+ y_pred: Predicted values of shape (n_samples,).
+ metrics: List of metrics to compute. Default is ["mse", "mae", "r2"].
+
+ Returns:
+ Dictionary of metrics whose keys are the metric names and values are
+ the metric values.
+
+ Examples:
+ >>> from pyhealth.metrics import binary_metrics_fn
+ >>> y_true = np.array([1, 3, 2, 4])
+ >>> y_pred = np.array([1.1, 2.4, 1.35, 2.8])
+ >>> binary_metrics_fn(y_true, y_prob, metrics=["mse", "mae", "r2"])
+ {'mse': 0.1475, 'mae': 0.275, 'r2': 0.6923076923076923}
+ """
+ if metrics is None:
+ metrics = ["mse", "mae", "r2"]
+
+ output = {}
+ for metric in metrics:
+ if metric == "mse":
+ mse = sklearn_metrics.mean_squared_error(y_true, y_pred)
+ output["mse"] = mse
+ elif metric == "mae":
+ mae = sklearn_metrics.mean_absolute_error(y_true, y_pred)
+ output["mae"] = mae
+ elif metric == "mape":
+ mape = sklearn_metrics.mean_absolute_percentage_error(y_true, y_pred)
+ output["mape"] = mape
+ elif metric == "rmse":
+ rmse = np.sqrt(sklearn_metrics.mean_squared_error(y_true, y_pred))
+ output["rmse"] = rmse
+ elif metric == "ccc":
+ ccc = calculate_ccc(y_true, y_pred)
+ output["ccc"] = ccc
+ elif metric == "r2":
+ r2 = sklearn_metrics.r2_score(y_true, y_pred)
+ output["r2"] = r2
+ else:
+ raise ValueError(f"Unknown metric for regression: {metric}")
+ return output
+
+
+if __name__ == "__main__":
+ all_metrics = [
+ "mse",
+ "mae",
+ "r2",
+ "ccc",
+ "rmse",
+ "mape"
+ ]
+ y_true = np.random.randint(0,10, size=100000)
+ y_pred = np.random.random(size=100000)*10
+ print(regression_metrics_fn(y_true, y_pred, metrics=all_metrics))
diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py
index 378f4498..b081eb5a 100644
--- a/pyhealth/models/adacare.py
+++ b/pyhealth/models/adacare.py
@@ -348,6 +348,8 @@ class AdaCare(BaseModel):
>>> print(ret)
{
'loss': tensor(0.7167, grad_fn=),
+ 'feature_importance: [tesnor of shape (batch_size, time_step, feature_dim), ...],
+ 'conv_feature_importance: [tesnor of shape (batch_size, time_step, 3*kernel_size), ...],
'y_prob': tensor([[0.5009], [0.4779]], grad_fn=),
'y_true': tensor([[0.], [1.]]),
'logit': tensor([[ 0.0036], [-0.0886]], grad_fn=)
@@ -438,7 +440,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
loss: a scalar tensor representing the loss.
feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, features)
representing the feature importance.
- conv_feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, 3*kernal_size)
+ conv_feature_importance: a list of tensors with shape (feature_type, batch_size, time_step, 3*kernel_size)
representing the convolutional feature importance.
y_prob: a tensor representing the predicted probabilities.
y_true: a tensor representing the true labels.
@@ -518,6 +520,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
y_prob = self.prepare_y_prob(logits)
results = {
"loss": loss,
+ "feature_importance": feature_importance,
+ "conv_feature_importance": conv_feature_importance,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
diff --git a/pyhealth/models/agent.py b/pyhealth/models/agent.py
index 187fedb9..24b8b164 100644
--- a/pyhealth/models/agent.py
+++ b/pyhealth/models/agent.py
@@ -368,6 +368,8 @@ class Agent(BaseModel):
>>> print(ret)
{
'loss': tensor(1.4059, grad_fn=),
+ 'loss_task': tensor(0.6931, grad_fn=),
+ 'loss_RL': tensor(0.7128, grad_fn=),
'y_prob': tensor([[0.4861], [0.5348]], grad_fn=),
'y_true': tensor([[0.], [1.]]),
'logit': tensor([[-0.0556], [0.1392]], grad_fn=)
@@ -660,6 +662,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
y_prob = self.prepare_y_prob(logits)
results = {
"loss": loss,
+ 'loss_task': loss_task,
+ 'loss_RL': loss_rl,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py
index a6d844ff..8655aeee 100644
--- a/pyhealth/models/base_model.py
+++ b/pyhealth/models/base_model.py
@@ -220,18 +220,21 @@ def get_output_size(self, label_tokenizer: Tokenizer) -> int:
output_size: the output size of the model.
"""
output_size = label_tokenizer.get_vocabulary_size()
- if self.mode == "binary":
+ if self.mode == "binary" or self.mode == "early_mortality":
assert output_size == 2
output_size = 1
return output_size
- def get_loss_function(self) -> Callable:
+ def get_loss_function(self, **args) -> Callable:
"""Gets the default loss function using `self.mode`.
The default loss functions are:
- binary: `F.binary_cross_entropy_with_logits`
- multiclass: `F.cross_entropy`
- multilabel: `F.binary_cross_entropy_with_logits`
+ - early_mortality: TimeAwareLoss
+ - regression: `F.mse_loss`
+ - multi_target: a list of loss functions with the same format as above
Returns:
The default loss function.
@@ -242,32 +245,44 @@ def get_loss_function(self) -> Callable:
return F.cross_entropy
elif self.mode == "multilabel":
return F.binary_cross_entropy_with_logits
+ elif self.mode == "regression":
+ return F.mse_loss
+ elif self.mode == "multi_target":
+ assert args["targets"] is type(list)
+ loss_list = []
+ for target in args["targets"]:
+ loss_list.append(self.get_loss_function(target))
+ elif self.mode == "early_mortality":
+ assert "outcome_pred" in args.keys() and "outcome_true" in args.keys() and "los_true" in args.keys()
+ return TimeAwareLoss()
else:
raise ValueError("Invalid mode: {}".format(self.mode))
def prepare_labels(
self,
- labels: Union[List[str], List[List[str]]],
+ labels: Union[List[str], List[List[str]], List[List[List[str]]]],
label_tokenizer: Tokenizer,
) -> torch.Tensor:
"""Prepares the labels for model training and evaluation.
This function converts the labels to different formats depending on the
mode. The default formats are:
- - binary: a tensor of shape (batch_size, 1)
+ - binary, early_mortality: a tensor of shape (batch_size, 1)
- multiclass: a tensor of shape (batch_size,)
- multilabel: a tensor of shape (batch_size, num_labels)
+ - regression: a tensor of shape (batch_size, 1)
+ - multi_target: a list of tensors with the same format as above
Args:
labels: the raw labels from the samples. It should be
- - a list of str for binary and multiclass classificationa
+ - a list of str for binary and multiclass classification
- a list of list of str for multilabel classification
label_tokenizer: the label tokenizer.
Returns:
labels: the processed labels.
"""
- if self.mode in ["binary"]:
+ if self.mode in ["binary"] or self.mode in ["early_mortality"]:
labels = label_tokenizer.convert_tokens_to_indices(labels)
labels = torch.FloatTensor(labels).unsqueeze(-1)
elif self.mode in ["multiclass"]:
@@ -281,9 +296,14 @@ def prepare_labels(
# convert to multihot
num_labels = label_tokenizer.get_vocabulary_size()
labels = batch_to_multihot(labels_index, num_labels)
+ elif self.mode in ["regression"]:
+ labels = torch.FloatTensor(labels).unsqueeze(-1)
+ elif self.mode in ["multi_target"]:
+ labels = [self.prepare_labels(label, label_tokenizer) for label in labels]
else:
raise NotImplementedError
- labels = labels.to(self.device)
+ if self.mode not in ["multi_target"]:
+ labels = labels.to(self.device)
return labels
def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
@@ -291,13 +311,15 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
This function converts the predicted logits to predicted probabilities
depending on the mode. The default formats are:
- - binary: a tensor of shape (batch_size, 1) with values in [0, 1],
+ - binary, early_mortality: a tensor of shape (batch_size, 1) with values in [0, 1],
which is obtained with `torch.sigmoid()`
- multiclass: a tensor of shape (batch_size, num_classes) with
values in [0, 1] and sum to 1, which is obtained with
`torch.softmax()`
- multilabel: a tensor of shape (batch_size, num_labels) with values
in [0, 1], which is obtained with `torch.sigmoid()`
+ - regression: a tensor of shape (batch_size, 1)
+ - multi_target: a list of tensors with the same format as above
Args:
logits: the predicted logit tensor.
@@ -305,12 +327,64 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
Returns:
y_prob: the predicted probability tensor.
"""
- if self.mode in ["binary"]:
+ if self.mode in ["binary"] or self.mode in ["early_mortality"]:
y_prob = torch.sigmoid(logits)
elif self.mode in ["multiclass"]:
y_prob = F.softmax(logits, dim=-1)
elif self.mode in ["multilabel"]:
y_prob = torch.sigmoid(logits)
+ elif self.mode in ["regression"]:
+ y_prob = logits
+ elif self.mode in ["multi_target"]:
+ y_prob = [self.prepare_y_prob(logit) for logit in logits]
else:
raise NotImplementedError
return y_prob
+
+class TimeAwareLoss(nn.Module):
+ """Computes time-aware loss for early mortality prediction.
+
+ Paper: Junyi Gao, et al. A Comprehensive Benchmark for COVID-19 Predictive Modeling
+ Using Electronic Health Records in Intensive Care: Choosing the Best Model for
+ COVID-19 Prognosis. arXiv preprint arXiv:2209.07805, 2023.
+
+ Args:
+ decay_rate: decay rate for the los .
+ reward_factor: reward factor for the correct predictions.
+
+ Returns:
+ Dictionary of metrics whose keys are the metric names and values are
+ the metric values.
+
+ Examples:
+ >>> from pyhealth.metrics import early_prediction_score
+ >>> y_true_outcome = np.array([0, 0, 1, 1])
+ >>> y_true_los = np.array([5, 3, 8, 1])
+ >>> y_prob = np.array([0.1, 0.4, 0.7, 0.8])
+ >>> early_prediction_score(y_true_outcome, y_true_los, y_prob)
+ {'score': 0.5952380952380952, 'late_threshold': 2.125, 'fp_penalty': 0.1}
+ """
+ def __init__(self, decay_rate=0.1, reward_factor=0.1):
+ super(TimeAwareLoss, self).__init__()
+ self.bce = nn.BCELoss(reduction='none')
+ self.decay_rate = decay_rate
+ self.reward_factor = reward_factor
+
+ def forward(self, outcome_pred, outcome_true, los_true):
+ """Return the loss value of time-aware loss.
+
+ Args:
+ outcome_pred: the predicted outcome
+ outcome_true: the true outcome
+ los_true: the true length of stay at the prediction time
+
+ Returns:
+ y_prob: the predicted probability tensor.
+ """
+ los_weights = torch.exp(-self.decay_rate * los_true) # Exponential decay
+ loss_unreduced = self.bce(outcome_pred, outcome_true)
+
+ reward_term = (los_true * torch.abs(outcome_true - outcome_pred)).mean() # Reward term
+ loss = (loss_unreduced * los_weights).mean()-self.reward_factor * reward_term # Weighted loss
+
+ return torch.clamp(loss, min=0)
\ No newline at end of file
diff --git a/pyhealth/models/concare.py b/pyhealth/models/concare.py
index 62fb0ea2..b5297870 100644
--- a/pyhealth/models/concare.py
+++ b/pyhealth/models/concare.py
@@ -737,6 +737,8 @@ class ConCare(BaseModel):
>>> print(ret)
{
'loss': tensor(9.5541, grad_fn=),
+ 'loss_task': tensor(9.5541, grad_fn=),
+ 'loss_decov': tensor(0., grad_fn=),
'y_prob': tensor([[0.5323], [0.5363]], grad_fn=),
'y_true': tensor([[1.], [0.]]),
'logit': tensor([[0.1293], [0.1454]], grad_fn=)
@@ -925,6 +927,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
y_prob = self.prepare_y_prob(logits)
results = {
"loss": loss,
+ "loss_task": loss_task,
+ "loss_decov": decov_loss,
"y_prob": y_prob,
"y_true": y_true,
'logit': logits,
diff --git a/pyhealth/models/molerec.py b/pyhealth/models/molerec.py
index dfa92a07..0b5b4667 100644
--- a/pyhealth/models/molerec.py
+++ b/pyhealth/models/molerec.py
@@ -297,7 +297,6 @@ def __init__(
)
print(e)
-
self.hidden_size = hidden_size
self.coef, self.target_ddi = coef, target_ddi
GNN_para = {
diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py
index ea0fe84f..6da3d3b5 100644
--- a/pyhealth/models/stagenet.py
+++ b/pyhealth/models/stagenet.py
@@ -346,6 +346,7 @@ class StageNet(BaseModel):
>>> print(ret)
{
'loss': tensor(0.7111, grad_fn=),
+ 'distance': [tensor of shape (batch_size, time_step)], ...],
'y_prob': tensor([[0.4815],
[0.4991]], grad_fn=),
'y_true': tensor([[1.],
@@ -532,6 +533,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
y_prob = self.prepare_y_prob(logits)
results = {
"loss": loss,
+ 'distance': distance,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
diff --git a/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py b/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py
new file mode 100644
index 00000000..0b88a270
--- /dev/null
+++ b/pyhealth/tasks/outcome_specific_length_of_stay_prediction.py
@@ -0,0 +1,329 @@
+from pyhealth.data import Patient
+
+
+def categorize_los(days: int):
+ """Categorizes length of stay into 10 categories.
+
+ One for ICU stays shorter than a day, seven day-long categories for each day of
+ the first week, one for stays of over one week but less than two,
+ and one for stays of over two weeks.
+
+ Args:
+ days: int, length of stay in days
+
+ Returns:
+ category: int, category of length of stay
+ """
+ # ICU stays shorter than a day
+ if days < 1:
+ return 0
+ # each day of the first week
+ elif 1 <= days <= 7:
+ return days
+ # stays of over one week but less than two
+ elif 7 < days <= 14:
+ return 8
+ # stays of over two weeks
+ else:
+ return 9
+
+
+def length_of_stay_prediction_mimic3_fn(patient: Patient):
+ """Processes a single patient for the length-of-stay prediction task.
+
+ Length of stay prediction aims at predicting the length of stay (in days) of the
+ current hospital visit based on the clinical information from the visit
+ (e.g., conditions and procedures).
+
+ Args:
+ patient: a Patient object.
+
+ Returns:
+ samples: a list of samples, each sample is a dict with patient_id, visit_id,
+ and other task-specific attributes as key.
+
+ Note that we define the task as a multi-class classification task.
+
+ Examples:
+ >>> from pyhealth.datasets import MIMIC3Dataset
+ >>> mimic3_base = MIMIC3Dataset(
+ ... root="/srv/local/data/physionet.org/files/mimiciii/1.4",
+ ... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
+ ... code_mapping={"ICD9CM": "CCSCM"},
+ ... )
+ >>> from pyhealth.tasks import length_of_stay_prediction_mimic3_fn
+ >>> mimic3_sample = mimic3_base.set_task(length_of_stay_prediction_mimic3_fn)
+ >>> mimic3_sample.samples[0]
+ [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 4}]
+ """
+ samples = []
+
+ for visit in patient:
+
+ conditions = visit.get_code_list(table="DIAGNOSES_ICD")
+ procedures = visit.get_code_list(table="PROCEDURES_ICD")
+ drugs = visit.get_code_list(table="PRESCRIPTIONS")
+ # exclude: visits without condition, procedure, or drug code
+ if len(conditions) * len(procedures) * len(drugs) == 0:
+ continue
+
+ los_days = (visit.discharge_time - visit.encounter_time).days
+ los_category = categorize_los(los_days)
+
+ if visit.discharge_status not in [0, 1]:
+ mortality_label = 0
+ else:
+ mortality_label = int(visit.discharge_status)
+
+ # TODO: should also exclude visit with age < 18
+ samples.append(
+ {
+ "visit_id": visit.visit_id,
+ "patient_id": patient.patient_id,
+ "conditions": [conditions],
+ "procedures": [procedures],
+ "drugs": [drugs],
+ "label": los_category,
+ "outcome_label": mortality_label,
+ }
+ )
+ # no cohort selection
+ return samples
+
+
+def length_of_stay_prediction_mimic4_fn(patient: Patient):
+ """Processes a single patient for the length-of-stay prediction task.
+
+ Length of stay prediction aims at predicting the length of stay (in days) of the
+ current hospital visit based on the clinical information from the visit
+ (e.g., conditions and procedures).
+
+ Args:
+ patient: a Patient object.
+
+ Returns:
+ samples: a list of samples, each sample is a dict with patient_id, visit_id,
+ and other task-specific attributes as key.
+
+ Note that we define the task as a multi-class classification task.
+
+ Examples:
+ >>> from pyhealth.datasets import MIMIC4Dataset
+ >>> mimic4_base = MIMIC4Dataset(
+ ... root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
+ ... tables=["diagnoses_icd", "procedures_icd"],
+ ... code_mapping={"ICD10PROC": "CCSPROC"},
+ ... )
+ >>> from pyhealth.tasks import length_of_stay_prediction_mimic4_fn
+ >>> mimic4_sample = mimic4_base.set_task(length_of_stay_prediction_mimic4_fn)
+ >>> mimic4_sample.samples[0]
+ [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 2}]
+ """
+ samples = []
+
+ for visit in patient:
+
+ conditions = visit.get_code_list(table="diagnoses_icd")
+ procedures = visit.get_code_list(table="procedures_icd")
+ drugs = visit.get_code_list(table="prescriptions")
+ # exclude: visits without condition, procedure, or drug code
+ if len(conditions) * len(procedures) * len(drugs) == 0:
+ continue
+
+ los_days = (visit.discharge_time - visit.encounter_time).days
+ los_category = categorize_los(los_days)
+
+ if visit.discharge_status not in [0, 1]:
+ mortality_label = 0
+ else:
+ mortality_label = int(visit.discharge_status)
+
+ # TODO: should also exclude visit with age < 18
+ samples.append(
+ {
+ "visit_id": visit.visit_id,
+ "patient_id": patient.patient_id,
+ "conditions": [conditions],
+ "procedures": [procedures],
+ "drugs": [drugs],
+ "label": los_category,
+ "outcome_label": mortality_label,
+ }
+ )
+ # no cohort selection
+ return samples
+
+
+def length_of_stay_prediction_eicu_fn(patient: Patient):
+ """Processes a single patient for the length-of-stay prediction task.
+
+ Length of stay prediction aims at predicting the length of stay (in days) of the
+ current hospital visit based on the clinical information from the visit
+ (e.g., conditions and procedures).
+
+ Args:
+ patient: a Patient object.
+
+ Returns:
+ samples: a list of samples, each sample is a dict with patient_id, visit_id,
+ and other task-specific attributes as key.
+
+ Note that we define the task as a multi-class classification task.
+
+ Examples:
+ >>> from pyhealth.datasets import eICUDataset
+ >>> eicu_base = eICUDataset(
+ ... root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
+ ... tables=["diagnosis", "medication"],
+ ... code_mapping={},
+ ... dev=True
+ ... )
+ >>> from pyhealth.tasks import length_of_stay_prediction_eicu_fn
+ >>> eicu_sample = eicu_base.set_task(length_of_stay_prediction_eicu_fn)
+ >>> eicu_sample.samples[0]
+ [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 5}]
+ """
+ samples = []
+
+ for visit in patient:
+
+ conditions = visit.get_code_list(table="diagnosis")
+ procedures = visit.get_code_list(table="physicalExam")
+ drugs = visit.get_code_list(table="medication")
+ # exclude: visits without condition, procedure, or drug code
+ if len(conditions) * len(procedures) * len(drugs) == 0:
+ continue
+
+ los_days = (visit.discharge_time - visit.encounter_time).days
+ los_category = categorize_los(los_days)
+
+ if visit.discharge_status not in ["Alive", "Expired"]:
+ mortality_label = 0
+ else:
+ mortality_label = 0 if visit.discharge_status == "Alive" else 1
+
+ # TODO: should also exclude visit with age < 18
+ samples.append(
+ {
+ "visit_id": visit.visit_id,
+ "patient_id": patient.patient_id,
+ "conditions": [conditions],
+ "procedures": [procedures],
+ "drugs": [drugs],
+ "label": los_category,
+ "outcome_label": mortality_label,
+ }
+ )
+ # no cohort selection
+ return samples
+
+
+def length_of_stay_prediction_omop_fn(patient: Patient):
+ """Processes a single patient for the length-of-stay prediction task.
+
+ Length of stay prediction aims at predicting the length of stay (in days) of the
+ current hospital visit based on the clinical information from the visit
+ (e.g., conditions and procedures).
+
+ Args:
+ patient: a Patient object.
+
+ Returns:
+ samples: a list of samples, each sample is a dict with patient_id, visit_id,
+ and other task-specific attributes as key.
+
+ Note that we define the task as a multi-class classification task.
+
+ Examples:
+ >>> from pyhealth.datasets import OMOPDataset
+ >>> omop_base = OMOPDataset(
+ ... root="https://storage.googleapis.com/pyhealth/synpuf1k_omop_cdm_5.2.2",
+ ... tables=["condition_occurrence", "procedure_occurrence"],
+ ... code_mapping={},
+ ... )
+ >>> from pyhealth.tasks import length_of_stay_prediction_omop_fn
+ >>> omop_sample = omop_base.set_task(length_of_stay_prediction_eicu_fn)
+ >>> omop_sample.samples[0]
+ [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 7}]
+ """
+ samples = []
+
+ for visit in patient:
+
+ conditions = visit.get_code_list(table="condition_occurrence")
+ procedures = visit.get_code_list(table="procedure_occurrence")
+ drugs = visit.get_code_list(table="drug_exposure")
+ # exclude: visits without condition, procedure, or drug code
+ if len(conditions) * len(procedures) * len(drugs) == 0:
+ continue
+
+ los_days = (visit.discharge_time - visit.encounter_time).days
+ los_category = categorize_los(los_days)
+ mortality_label = int(visit.discharge_status)
+
+ # TODO: should also exclude visit with age < 18
+ samples.append(
+ {
+ "visit_id": visit.visit_id,
+ "patient_id": patient.patient_id,
+ "conditions": [conditions],
+ "procedures": [procedures],
+ "drugs": [drugs],
+ "label": los_category,
+ "outcome_label": mortality_label,
+ }
+ )
+ # no cohort selection
+ return samples
+
+
+if __name__ == "__main__":
+ from pyhealth.datasets import MIMIC3Dataset
+
+ base_dataset = MIMIC3Dataset(
+ root="/srv/local/data/physionet.org/files/mimiciii/1.4",
+ tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
+ dev=True,
+ code_mapping={"ICD9CM": "CCSCM", "NDC": "ATC"},
+ refresh_cache=False,
+ )
+ sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_mimic3_fn)
+ sample_dataset.stat()
+ print(sample_dataset.available_keys)
+
+ from pyhealth.datasets import MIMIC4Dataset
+
+ base_dataset = MIMIC4Dataset(
+ root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
+ tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
+ dev=True,
+ code_mapping={"NDC": "ATC"},
+ refresh_cache=False,
+ )
+ sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_mimic4_fn)
+ sample_dataset.stat()
+ print(sample_dataset.available_keys)
+
+ from pyhealth.datasets import eICUDataset
+
+ base_dataset = eICUDataset(
+ root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
+ tables=["diagnosis", "medication", "physicalExam"],
+ dev=True,
+ refresh_cache=False,
+ )
+ sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_eicu_fn)
+ sample_dataset.stat()
+ print(sample_dataset.available_keys)
+
+ from pyhealth.datasets import OMOPDataset
+
+ base_dataset = OMOPDataset(
+ root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
+ tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
+ dev=True,
+ refresh_cache=False,
+ )
+ sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_omop_fn)
+ sample_dataset.stat()
+ print(sample_dataset.available_keys)
diff --git a/requirements.txt b/requirements.txt
index d20966aa..c1fba519 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,5 +5,6 @@ networkx>=2.6.3
pandas>=1.3.2,<2
pandarallel>=1.5.3
mne>=1.0.3
+urllib3<=1.26.15
numpy
tqdm
\ No newline at end of file