Skip to content

Commit dfd6a23

Browse files
authored
ECG Autoencoder PheWAS in the Model Zoo (#514)
* Add categorical and continuous composite TensorMaps, bump version, and add the ECG PheWAS to the model zoo
1 parent a0bbf53 commit dfd6a23

31 files changed

+17253
-8657
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ notebooks/**.ipynb filter=nbstripout
66
*.linux filter=lfs diff=lfs merge=lfs -text
77
*.osx filter=lfs diff=lfs merge=lfs -text
88
*.genes filter=lfs diff=lfs merge=lfs -text
9+
model_zoo/ECG_PheWAS/*.h5 filter=lfs diff=lfs merge=lfs -text

.github/workflows/docker-publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ on:
1616
# - v*
1717

1818
# Run tests for any PRs.
19-
pull_request:
19+
# pull_request:
2020

2121
env:
2222
IMAGE_NAME: ml4h_terra

docker/vm_boot_images/config/tensorflow-requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ plotnine
3434
vega
3535
ipycanvas>=0.7.0
3636
ipyannotations>=0.2.1
37-
torch
37+
torch==1.12.1
3838
opencv-python
3939
blosc
4040
boto3
41-
ml4ht==0.0.9
41+
ml4ht==0.0.10

ml4h/defines.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __str__(self):
5858
'anterior_papillary': 9, 'LV_cavity': 10, 'LA_cavity': 11, 'body': 12,
5959
}
6060
LAX_2CH_HEART_LABELS = {
61-
'aortic_arch': 1, 'left_pulmonary_artery_wall': 2, 'left_pulmonary_artery': 3,
6261
'LA_appendage': 4, 'LA_free_wall': 5, 'LV_posterior_wall': 6, 'LV_anterior_wall': 7, 'posterior_papillary': 8,
6362
'anterior_papillary': 9, 'LV_cavity': 10, 'LA_cavity': 11,
6463
}
@@ -88,10 +87,9 @@ def __str__(self):
8887
'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10,
8988
'left_atrium': 12, 'right_atrium': 13, 'aorta': 14, 'pulmonary_artery': 15,
9089
}
91-
MRI_SAX_SEGMENTED_CHANNEL_MAP = {
92-
'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
93-
'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10,
94-
'left_atrium': 12, 'right_atrium': 13, 'aorta': 14, 'pulmonary_artery': 15,
90+
SAX_HEART_LABELS = {
91+
'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
92+
'RV_cavity': 5, 'left_atrium': 12, 'right_atrium': 13,
9593
}
9694
MRI_AO_SEGMENTED_CHANNEL_MAP = {
9795
'ao_background': 0, 'ao_superior_vena_cava': 1, 'ao_pulmonary_artery': 2, 'ao_ascending_aortic_wall': 3, 'ao_ascending_aorta': 4,

ml4h/explorations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,13 @@ def explore(args):
10451045

10461046
def latent_space_dataframe(infer_hidden_tsv, explore_csv):
10471047
df = pd.read_csv(explore_csv)
1048-
df['sample_id'] = pd.to_numeric(df['sample_id'], errors='coerce')
1048+
if 'sample_id' in df.columns:
1049+
id_col = 'sample_id'
1050+
elif 'fpath' in df.columns:
1051+
id_col = 'fpath'
1052+
else:
1053+
raise ValueError(f'Could not find a sample ID column in explore CSV:{explore_csv}')
1054+
df['sample_id'] = pd.to_numeric(df[id_col], errors='coerce')
10491055
df2 = pd.read_csv(infer_hidden_tsv, sep='\t', engine='python')
10501056
df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce')
10511057
latent_df = pd.merge(df, df2, on='sample_id', how='inner')

ml4h/models/pretrained_blocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def __init__(
8080
*,
8181
tensor_map: TensorMap,
8282
pretrain_trainable: bool,
83-
base_model = "https://tfhub.dev/jeongukjae/roberta_en_cased_L-24_H-1024_A-16/1",
84-
preprocess_model="https://tfhub.dev/jeongukjae/roberta_en_cased_preprocess/1",
83+
base_model="https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",
84+
preprocess_model="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
8585
**kwargs,
8686
):
8787
self.tensor_map = tensor_map

ml4h/plots.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,9 @@ def plot_survivorship(
10321032
:param title: Title for the plot
10331033
:param prefix: Path prefix where plot will be saved
10341034
:param days_window: Maximum days of follow up
1035+
:param dpi: Dots per inch of the figure
1036+
:param width: Width in inches of the figure
1037+
:param height: Height in inches of the figure
10351038
"""
10361039
plt.figure(figsize=(width, height), dpi=dpi)
10371040
days_sorted_index = np.argsort(days_follow_up)
@@ -1100,6 +1103,9 @@ def plot_survival(
11001103
:param title: Title for the plot
11011104
:param days_window: Maximum days of follow up
11021105
:param prefix: Path prefix where plot will be saved
1106+
:param dpi: Dots per inch of the figure
1107+
:param width: Width in inches of the figure
1108+
:param height: Height in inches of the figure
11031109
11041110
:return: Dictionary mapping metric names to their floating point values
11051111
"""
@@ -1109,7 +1115,7 @@ def plot_survival(
11091115
plt.figure(figsize=(width, height), dpi=dpi)
11101116

11111117
cumulative_sick = np.cumsum(np.sum(truth[:, intervals:], axis=0))
1112-
cumulative_censored = (truth.shape[0]-np.sum(truth[:, :intervals], axis=0))-cumulative_sick
1118+
cumulative_censored = (truth.shape[0]-np.sum(truth[:, :intervals], axis=0)) - cumulative_sick
11131119
alive_per_step = np.sum(truth[:, :intervals], axis=0)
11141120
sick_per_step = np.sum(truth[:, intervals:], axis=0)
11151121
survivorship = np.cumprod(1 - (sick_per_step / alive_per_step))

ml4h/tensorize/dataflow/bigquery_ukb_queries.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,14 @@ def write_tensor_from_sql(sampleid_to_rows, output_path, tensor_type):
7777
hd5.create_dataset('icd', (1,), data=JOIN_CHAR.join(icds), dtype=h5py.special_dtype(vlen=str))
7878
elif tensor_type == 'categorical':
7979
for row in rows:
80-
hd5_dataset_name = dataset_name_from_meaning('categorical', [row['field'], row['meaning'], str(row['instance']), str(row['array_idx'])])
80+
fields = [str(row['fieldid']), row['field'], row['meaning'],
81+
str(row['instance']), str(row['array_idx'])]
82+
hd5_dataset_name = dataset_name_from_meaning('categorical', fields)
8183
_write_float_or_warn(sample_id, row, hd5_dataset_name, hd5)
8284
elif tensor_type == 'continuous':
8385
for row in rows:
84-
hd5_dataset_name = dataset_name_from_meaning('continuous', [str(row['fieldid']), row['field'], str(row['instance']), str(row['array_idx'])])
86+
fields = [str(row['fieldid']), row['field'], str(row['instance']), str(row['array_idx'])]
87+
hd5_dataset_name = dataset_name_from_meaning('continuous', fields)
8588
_write_float_or_warn(sample_id, row, hd5_dataset_name, hd5)
8689
elif tensor_type in ['disease', 'phecode_disease']:
8790
for row in rows:

ml4h/tensorize/tensor_writer_ukbb.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -786,11 +786,9 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
786786
for ecg in ecgs:
787787
root = et.parse(ecg).getroot()
788788
date = datetime.datetime.strptime(_date_str_from_ecg(root), '%Y-%m-%d')
789-
write_to_hd5 = partial(create_tensor_in_hd5, hd5=hd5, path_prefix='ukb_ecg_bike', stats=stats, date=date)
790-
logging.info('Got ECG for sample:{} XML field:{}'.format(sample_id, xml_field))
791-
792789
instance = ecg.split(JOIN_CHAR)[-2]
793-
write_to_hd5(storage_type=StorageType.STRING, name='instance', value=instance)
790+
write_to_hd5 = partial(create_tensor_in_hd5, hd5=hd5, path_prefix='ukb_ecg_bike', instance=instance, stats=stats, date=date)
791+
logging.info(f'Got ECG for sample:{sample_id} XML field:{xml_field}')
794792

795793
protocol = root.findall('./Protocol/Phase')[0].find('ProtocolName').text
796794
write_to_hd5(storage_type=StorageType.STRING, name='protocol', value=protocol)
@@ -881,10 +879,13 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
881879
if field_val is False:
882880
continue
883881
trends[lead_field][i, lead_to_int[lead_num]] = field_val
884-
trends['time'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("EntryTime/Minute").text) + int(trend_entry.find("EntryTime/Second").text)
885-
trends['PhaseTime'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("PhaseTime/Minute").text) + int(trend_entry.find("PhaseTime/Second").text)
886-
trends['PhaseName'][i] = phase_to_int[trend_entry.find('PhaseName').text]
887-
trends['Artifact'][i] = float(trend_entry.find('Artifact').text.strip('%')) / 100 # Artifact is reported as a percentage
882+
try:
883+
trends['time'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("EntryTime/Minute").text) + int(trend_entry.find("EntryTime/Second").text)
884+
trends['PhaseTime'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("PhaseTime/Minute").text) + int(trend_entry.find("PhaseTime/Second").text)
885+
trends['PhaseName'][i] = phase_to_int[trend_entry.find('PhaseName').text]
886+
trends['Artifact'][i] = float(trend_entry.find('Artifact').text.strip('%')) / 100 # Artifact is reported as a percentage
887+
except AttributeError as e:
888+
stats['AttributeError on Trend Data'] += 1
888889

889890
for field, trend_list in trends.items():
890891
write_to_hd5(name=f'trend_{str.lower(field)}', value=trend_list)
@@ -900,12 +901,15 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
900901
write_to_hd5(name=f'{str.lower(phase_name)}_duration', value=[phase_duration])
901902

902903
# HR stats
903-
max_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxHeartRate')
904-
resting_hr = _xml_path_to_float(root, './ExerciseMeasurements/RestingStats/RestHR')
905-
max_pred_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxPredictedHR')
906-
write_to_hd5(name='max_hr', value=[max_hr])
907-
write_to_hd5(name='resting_hr', value=[resting_hr])
908-
write_to_hd5(name='max_pred_hr', value=[max_pred_hr])
904+
try:
905+
max_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxHeartRate')
906+
write_to_hd5(name='max_hr', value=[max_hr])
907+
resting_hr = _xml_path_to_float(root, './ExerciseMeasurements/RestingStats/RestHR')
908+
write_to_hd5(name='resting_hr', value=[resting_hr])
909+
max_pred_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxPredictedHR')
910+
write_to_hd5(name='max_pred_hr', value=[max_pred_hr])
911+
except AttributeError as e:
912+
stats['AttributeError on HR Stats'] += 1
909913

910914

911915
def _write_tensors_from_niftis(folder: str, hd5: h5py.File, field_id: str, stats: Counter):

0 commit comments

Comments
 (0)