# Provenance: Classifiers/survival_analysis.py, introduced by commit
# ea9686c050cf5ee6d60470aa13a723ce6bbf58b0 ("init commit", Shi Hu, 2020-04-24).
"""Survival analysis of mortality outcomes (univariate fits and Cox PH).

To run this script, the data-preparation code must first be modified as
follows (updated Apr 23, 2020):

* In the select_variables function of covid19_ICU_util.py, add this line:
      variables_to_include += ['hospital']
  below
      variables_to_include += ['Record Id']

* Comment out the 3 lines below in select_x_y of covid19_ICU_util.py:
      outcome_name = 'Combined outcome'
      y = pd.concat([y1, y2, y3], axis=1)
      return x, y, outcome_name
  and replace them with:
      outcome_name = 'y2'
      y = pd.concat([y2, is_at_icu], axis=1)
      return x, y, outcome_name

* Add these 2 lines after the call to prepare_for_learning in
  covid19_ICU_admission.py:
      x.to_excel('features.xlsx')
      y.to_excel('outcomes.xlsx')

* Run covid19_ICU_admission.py to generate the above 2 xlsx files.
"""

import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lifelines import (
    CoxPHFitter,
    ExponentialFitter,
    GeneralizedGammaFitter,
    KaplanMeierFitter,
    LogLogisticFitter,
    LogNormalFitter,
    PiecewiseExponentialFitter,
    SplineFitter,
    WeibullFitter,
)

# Directory (relative to the working directory) that receives every figure.
PLOTS_DIR = 'plots'

# Column names shared by the univariate fits and the Cox model.
DURATION_COL = 'duration_mortality'
EVENT_COL = 'event_mortality'


def _save_figure(fig, filename):
    """Write *fig* to PLOTS_DIR/filename, creating the directory, then close it."""
    os.makedirs(PLOTS_DIR, exist_ok=True)
    fig.savefig(os.path.join(PLOTS_DIR, filename))
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def _fit_cox(train_set):
    """Return a CoxPHFitter fitted on *train_set* with this script's settings."""
    cph = CoxPHFitter()
    # NOTE(review): `step_size` is passed straight to fit(); newer lifelines
    # releases appear to expect solver options via `fit_options` instead --
    # confirm against the installed version before upgrading.
    cph.fit(
        train_set,
        duration_col=DURATION_COL,
        event_col=EVENT_COL,
        show_progress=True,
        step_size=0.1,
    )
    return cph


def simple_estimates(outcomes):
    """Fit eight univariate survival models and plot their survival curves.

    Rows containing any missing value, and rows whose ``duration_mortality``
    is not strictly positive, are discarded before fitting.  Each fitted
    model is drawn in its own panel of a 3x3 grid and the figure is written
    to ``plots/simple_estimate.png``.

    Args:
        outcomes: DataFrame with at least the columns ``duration_mortality``
            and ``event_mortality``.

    Side effects:
        Sets the global matplotlib font size to 16, creates the ``plots``
        directory if needed, and saves then closes the figure.
    """
    # Set the font size first so it also applies to text that is created
    # together with the axes.
    matplotlib.rcParams.update({'font.size': 16})
    fig, axes = plt.subplots(3, 3, figsize=(20, 15))

    observed = outcomes.dropna()
    observed = observed[observed[DURATION_COL] > 0]
    durations = observed[DURATION_COL]
    events = observed[EVENT_COL]

    # Listed in panel order (row by row).
    fitted_models = [
        WeibullFitter().fit(durations, events, label='WeibullFitter'),
        ExponentialFitter().fit(durations, events, label='ExponentialFitter'),
        LogNormalFitter().fit(durations, events, label='LogNormalFitter'),
        KaplanMeierFitter().fit(durations, events, label='KaplanMeierFitter'),
        LogLogisticFitter().fit(durations, events, label='LogLogisticFitter'),
        PiecewiseExponentialFitter([40, 60]).fit(
            durations, events, label='PiecewiseExponentialFitter'
        ),
        GeneralizedGammaFitter().fit(durations, events, label='GeneralizedGammaFitter'),
        SplineFitter([6, 20, 40, 75]).fit(durations, events, label='SplineFitter'),
    ]

    panels = axes.ravel()
    for model, panel in zip(fitted_models, panels):
        model.plot_survival_function(ax=panel)
    # Nine panels, eight models: hide the panel that would otherwise be empty.
    for panel in panels[len(fitted_models):]:
        panel.set_visible(False)

    _save_figure(fig, 'simple_estimate.png')


def cox_ph(features, outcomes, use_all=True, test_hospital='MUMC'):
    """Fit a Cox proportional-hazards model on the joined features/outcomes.

    ``features`` and ``outcomes`` are concatenated column-wise, rows with any
    missing value are dropped, and the ``microbiology_worker`` and
    ``days_at_icu`` columns are removed.

    * ``use_all=True``: fit on every remaining row (without ``hospital`` and
      ``aids_hiv``), print the model summary and save the coefficient plot to
      ``plots/coef.png``.
    * ``use_all=False``: hold out all rows of ``test_hospital``, fit on the
      other hospitals and print the concordance index on the held-out rows.

    Args:
        features: DataFrame of covariates; must contain ``hospital``,
            ``microbiology_worker`` and ``days_at_icu`` between it and
            ``outcomes`` (and ``aids_hiv`` when ``use_all`` is true).
        outcomes: DataFrame providing ``duration_mortality`` and
            ``event_mortality``.
        use_all: Select the fit-on-everything mode instead of the
            leave-one-hospital-out evaluation.
        test_hospital: Value of the ``hospital`` column to hold out when
            ``use_all`` is false (the original script mentions MUMC,
            Zuyderland and "AUMC - AMC").

    Returns:
        The fitted ``CoxPHFitter``.

    Raises:
        ValueError: If ``use_all`` is false and either the training or the
            held-out set is empty after filtering.
    """
    data = pd.concat([features, outcomes], axis=1)
    # NOTE(review): dropna() runs before the two columns below are removed,
    # so rows missing only in those columns are discarded as well.
    data = data.dropna()
    # days_at_icu would allow selecting who went to the ICU and who did not;
    # that selection is currently not used.
    data = data.drop(columns=['microbiology_worker', 'days_at_icu'])

    if use_all:
        # aids_hiv is an outlier that distorts the coefficient plot.
        train_set = data.drop(columns=['hospital', 'aids_hiv'])
        cph = _fit_cox(train_set)
        cph.print_summary()

        fig, ax = plt.subplots(figsize=(40, 30))
        cph.plot(ax=ax)
        _save_figure(fig, 'coef.png')
        return cph

    held_out = data['hospital'] == test_hospital
    train_set = data[~held_out].drop(columns=['hospital'])
    test_set = data[held_out].drop(columns=['hospital'])
    if test_set.empty:
        raise ValueError(
            'no complete rows found for test hospital {!r}'.format(test_hospital)
        )
    if train_set.empty:
        raise ValueError(
            'no complete rows left for training after holding out {!r}'.format(
                test_hospital
            )
        )

    cph = _fit_cox(train_set)

    print('with and without ICU')
    print('test hospital:', test_hospital)
    print('test c-index', cph.score(test_set, scoring_method="concordance_index"))
    return cph


def main():
    """Load the exported spreadsheets and run the hold-out Cox evaluation."""
    features = pd.read_excel('../features.xlsx', index_col=0)
    outcomes = pd.read_excel('../outcomes.xlsx', index_col=0)

    # Optional: also produce the univariate survival-curve figure.
    # simple_estimates(outcomes)

    cox_ph(features, outcomes, use_all=False)


if __name__ == '__main__':
    main()