159 changes: 152 additions & 7 deletions pyhealth/datasets/shhs.py
@@ -4,24 +4,28 @@

from pyhealth.datasets import BaseSignalDataset

from pyhealth.datasets.utils import read_edf_data, save_to_npz
from tqdm import tqdm

class SHHSDataset(BaseSignalDataset):
"""Base EEG dataset for Sleep Heart Health Study (SHHS)
"""EEG and ECG dataset for Sleep Heart Health Study (SHHS)

Dataset is available at https://sleepdata.org/datasets/shhs

The Sleep Heart Health Study (SHHS) is a multi-center cohort study implemented by the National Heart, Lung, and Blood Institute to determine the cardiovascular and other consequences of sleep-disordered breathing. It tests whether sleep-disordered breathing is associated with an increased risk of coronary heart disease, stroke, all-cause mortality, and hypertension. In all, 6,441 men and women aged 40 years and older were enrolled between November 1, 1995 and January 31, 1998 to take part in SHHS Visit 1. During exam cycle 3 (January 2001 - June 2003), a second polysomnogram (SHHS Visit 2) was obtained in 3,295 of the participants. CVD outcomes were monitored and adjudicated by parent cohorts between baseline and 2011. More than 130 manuscripts have been published investigating predictors and outcomes of sleep disorders.

This dataset supports both EEG and ECG signal processing.

Args:
dataset_name: name of the dataset.
root: root directory of the raw data (should contain many csv files).
root: root directory of the raw data (should contain EDF files and annotations).
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.

Attributes:
task: Optional[str], name of the task (e.g., "sleep staging").
task: Optional[str], name of the task (e.g., "sleep staging", "ecg analysis").
Default is None.
samples: Optional[List[Dict]], a list of samples, each sample is a dict with
patient_id, record_id, and other task-specific attributes as key.
@@ -30,16 +34,26 @@ class SHHSDataset(BaseSignalDataset):
a list of sample indices. Default is None.
visit_to_index: Optional[Dict[str, List[int]]], a dict mapping visit_id to a
list of sample indices. Default is None.
patients: Dict[str, List[Dict]], processed patient data with EEG/ECG file paths.

Examples:
>>> from pyhealth.datasets import SHHSDataset
>>> dataset = SHHSDataset(
... root="/srv/local/data/SHHS/",
... )
>>> dataset.stat()
>>> dataset.info()
>>> # Process ECG data
>>> dataset.process_ECG_data(out_dir="./ecg_output")
"""

def __init__(self, root, dev=False, refresh_cache=False, **kwargs):
"""Initialize SHHS Dataset"""
super().__init__()
self.root = root
self.dev = dev
self.refresh_cache = refresh_cache
self.filepath = os.path.join(os.path.expanduser("~"), ".cache", "pyhealth_shhs")
self.patients = self.process_EEG_data()

def parse_patient_id(self, file_name):
"""
Args:
@@ -103,13 +117,144 @@ def process_EEG_data(self):
)
return patients

def process_ECG_data(self, out_dir, target_fs=None, select_chs=None, require_annotations=False):
"""
Extract SHHS ECG signals (plus sleep-stage labels when available) and save them as .npz files.

Args:
out_dir: destination directory for the generated .npz files.
target_fs: Optional[int], target sampling rate in Hz (e.g., 100).
select_chs: list of channel names to extract. Default is ["ECG"].
require_annotations: if True, skip files without annotations;
if False, process unannotated files as signals only.

Expected SHHS directory structure:
root/
edfs/shhs1/*.edf
edfs/shhs2/*.edf
annotations-events-profusion/shhs1/*.xml
annotations-events-profusion/label/*.xml (for shhs2)
"""
# Avoid a mutable default argument for the channel list.
if select_chs is None:
select_chs = ["ECG"]

shhs_configs = [
{
"data_dir": os.path.join(self.root, "edfs", "shhs1"),
"annotation_dir": os.path.join(self.root, "annotations-events-profusion", "shhs1"),
"label": "shhs1"
},
{
"data_dir": os.path.join(self.root, "edfs", "shhs2"),
"annotation_dir": os.path.join(self.root, "annotations-events-profusion", "label"),
"label": "shhs2"
}
]

os.makedirs(out_dir, exist_ok=True)
processed_count = 0
skipped_count = 0

for config in shhs_configs:
data_dir = config["data_dir"]
annotation_dir = config["annotation_dir"]
dir_label = config["label"]

if not os.path.exists(data_dir):
print(f"Directory missing: {data_dir}")
continue

files = [f for f in os.listdir(data_dir) if f.endswith(".edf")]
print(f"Processing ECG for {dir_label}: {len(files)} EDF files found")

if not files:
continue

for file in tqdm(files, desc=f"Processing {dir_label}"):
sid = self.parse_patient_id(file)
data_path = os.path.join(data_dir, file)

# Annotation filenames follow the pattern <visit>-<sid>-profusion.xml
annotation_filename = f"{dir_label}-{sid}-profusion.xml"

label_path = os.path.join(annotation_dir, annotation_filename)

# Check if annotation exists
has_annotation = os.path.exists(label_path)

if require_annotations and not has_annotation:
print(f"Skipping {sid}: missing annotation {label_path}")
skipped_count += 1
continue

try:
if has_annotation:
# Process with annotations
data, fs, stages = read_edf_data(
data_path=data_path,
label_path=label_path,
dataset="SHHS",
select_chs=select_chs,
target_fs=target_fs,
)
outfile = os.path.join(out_dir, f"{dir_label}-{sid}.npz")
save_to_npz(outfile, data, stages, fs)
print(f"✓ Processed {sid} with annotations")
else:
# No annotation available: read the signals directly with MNE,
# imported lazily so it remains an optional dependency.
try:
import mne
raw = mne.io.read_raw_edf(data_path, preload=True, verbose=False)

# Select channels
if select_chs:
available_chs = [ch for ch in select_chs if ch in raw.ch_names]
if not available_chs:
print(f"⚠ No requested channels found in {sid}. Available: {raw.ch_names}")
skipped_count += 1
continue
raw = raw.pick_channels(available_chs)

# Get data and sampling frequency
data = raw.get_data()
fs = raw.info['sfreq']

# Resample if needed
if target_fs and target_fs != fs:
raw = raw.resample(target_fs)
data = raw.get_data()
fs = target_fs

outfile = os.path.join(out_dir, f"{dir_label}-{sid}_no_labels.npz")
save_to_npz(outfile, data, None, fs)
print(f"⚠ Processed {sid} without annotations (signals only)")

except Exception as edf_error:
print(f"❌ Error reading EDF file for {sid}: {edf_error}")
skipped_count += 1
continue

processed_count += 1

except Exception as e:
print(f"❌ Error processing patient {sid}: {e}")
skipped_count += 1

print(f"\nECG extraction completed:")
print(f" ✓ Successfully processed: {processed_count} files")
print(f" ⚠ Skipped/failed: {skipped_count} files")

return processed_count > 0

if __name__ == "__main__":
dataset = SHHSDataset(
root="/srv/local/data/SHHS/polysomnography",
dev=True,
refresh_cache=True,
)
dataset.stat()
dataset.info()
print(f"Dataset loaded with {len(dataset.patients)} patients")
print(list(dataset.patients.items())[0])

# Example ECG processing
# dataset.process_ECG_data(out_dir="./ecg_output")
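# Hedged follow-up sketch (illustrative, not part of this module): inspect
# one generated file; the subject id and numpy import are assumptions.
# import numpy as np
# arr = np.load("./ecg_output/shhs1-200001.npz")
# print(arr["data"].shape, int(arr["fs"]), arr["stages"].shape)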
86 changes: 86 additions & 0 deletions pyhealth/datasets/utils.py
@@ -12,6 +12,11 @@
from pyhealth import BASE_CACHE_PATH
from pyhealth.utils import create_directory

import numpy as np
import pyedflib
from scipy.signal import resample
from xml.etree import ElementTree as ET

MODULE_CACHE_PATH = os.path.join(BASE_CACHE_PATH, "datasets")
create_directory(MODULE_CACHE_PATH)

@@ -436,6 +441,87 @@ def load_processors(processor_dir: str) -> Tuple[Dict, Dict]:

return input_processors, output_processors

def read_edf_data(data_path, label_path, dataset, select_chs, target_fs=None):
"""
Lightweight EDF reader + sleep stage extractor without MNE (uses pyedflib).

Args:
data_path: path to the EDF file.
label_path: path to the SHHS/MESA XML annotation file.
dataset: "SHHS" or "MESA".
select_chs: list of channel names to extract.
target_fs: optional target sampling frequency for downsampling.

Returns:
data: (T, C) array of the extracted channel signals.
fs: sampling frequency after optional downsampling.
stages: per-sample sleep stage array aligned with the signal.
"""

# Dataset-specific channel exclusions (currently informational only;
# channel selection below is driven entirely by select_chs)
if dataset == "SHHS":
exclude_chs = ["SaO2", "H.R.", "SOUND", "AIRFLOW", "POSITION", "LIGHT"]
elif dataset == "MESA":
exclude_chs = ["EEG1", "EEG2", "Snore", "Thor", "Abdo", "Leg", "Therm", "Pos"]
else:
raise ValueError("Unsupported dataset. Use 'SHHS' or 'MESA'.")

# ---- Read EDF ----
f = pyedflib.EdfReader(data_path)
channel_labels = f.getSignalLabels()
original_fs = f.getSampleFrequencies()

# Determine which EDF channels to load
selected_idxs = []
for ch in select_chs:
if ch in channel_labels:
selected_idxs.append(channel_labels.index(ch))
else:
raise ValueError(f"Channel {ch} not found in EDF")

# Read signals into (T, C)
signals = []
for idx in selected_idxs:
sig = f.readSignal(idx)
signals.append(sig)
f.close()

# Stack the per-channel 1-D signals into shape (T, C)
data = np.stack(signals, axis=-1)

# Use the first selected channel's fs (selected channels usually share the same fs)
fs = original_fs[selected_idxs[0]]

# ---- Downsample ----
if target_fs and fs > target_fs:
factor = target_fs / fs
data = resample(data, int(len(data) * factor), axis=0)
fs = target_fs

# ---- Parse sleep stage annotations from XML ----
tree = ET.parse(label_path)
root = tree.getroot()
stages_elem = root.find("SleepStages")
if stages_elem is None:
raise ValueError(f"No <SleepStages> element found in {label_path}")
stages = np.array([int(s.text) for s in stages_elem.findall("SleepStage")], dtype=np.int8)

# Relabel to the common 5-class convention:
# merge stage 4 into stage 3 (N3/N4 -> N3), then map REM (5) to 4.
stages[stages == 4] = 3
stages[stages == 5] = 4

# Expand 30-second epochs to per-sample labels
expanded = np.repeat(stages, int(30 * fs))

# Align lengths
min_len = min(len(data), len(expanded))
data = data[:min_len]
stages = expanded[:min_len]

return data, fs, stages
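# Usage sketch (illustrative paths, assuming the SHHS layout described in
# SHHSDataset.process_ECG_data; not a verified example):
#   data, fs, stages = read_edf_data(
#       data_path="/data/SHHS/edfs/shhs1/shhs1-200001.edf",
#       label_path="/data/SHHS/annotations-events-profusion/shhs1/shhs1-200001-profusion.xml",
#       dataset="SHHS",
#       select_chs=["ECG"],
#       target_fs=100,
#   )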


def save_to_npz(out_path, data, stages, fs):
"""Saves extracted signal data and optional sleep stages to NPZ."""
if stages is None:
# A bare None would be pickled; store an empty array instead.
stages = np.array([], dtype=np.int8)
np.savez(out_path, data=data, stages=stages, fs=fs)

if __name__ == "__main__":
print(list_nested_levels([1, 2, 3]))
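# Minimal round-trip sketch (illustrative, not from the original module):
# save a dummy signal with save_to_npz and read the stored fields back.
dummy_data = np.random.randn(3000, 1)  # 30 s of one channel at 100 Hz
dummy_stages = np.zeros(3000, dtype=np.int8)  # per-sample stage labels
save_to_npz("demo.npz", dummy_data, dummy_stages, fs=100)
loaded = np.load("demo.npz")
print(loaded["data"].shape, loaded["stages"].shape, int(loaded["fs"]))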