100 changes: 100 additions & 0 deletions TarDiff/README.md
@@ -0,0 +1,100 @@
# TarDiff: Target-Oriented Diffusion Guidance for Synthetic Electronic Health Record Time Series Generation



## Introduction
![TarDiff OverView](./images/overview.png)

Synthetic Electronic Health Record (EHR) time-series generation is crucial for advancing clinical machine learning models, as it helps address data scarcity by providing more training data. However, most existing approaches focus primarily on replicating the statistical distributions and temporal dependencies of real-world data. We argue that fidelity to observed data alone does not guarantee better model performance, as common patterns may dominate, limiting the representation of rare but important conditions. This highlights the need to generate synthetic samples that improve the performance of specific clinical models on their target outcomes. To address this, we propose TarDiff, a novel target-oriented diffusion framework that integrates task-specific influence guidance into the synthetic data generation process. Unlike conventional approaches that mimic training data distributions, TarDiff optimizes synthetic samples by quantifying their expected contribution to improving downstream model performance through influence functions. Specifically, we measure the reduction in task-specific loss induced by synthetic samples and embed this influence gradient into the reverse diffusion process, thereby steering the generation towards utility-optimized data. Evaluated on six publicly available EHR datasets, TarDiff achieves state-of-the-art performance, outperforming existing methods by up to 20.4% in AUPRC and 18.4% in AUROC. Our results demonstrate that TarDiff not only preserves temporal fidelity but also enhances downstream model performance, offering a robust solution to data scarcity and class imbalance in healthcare analytics.


## 1 · Environment

Prepare TarDiff's environment.
```bash
conda env create -f environment.yaml
conda activate tardiff
```

Preparing the time-series (TS) downstream-task environment depends on the repository you use for the specific task.

## 2 · Data Pre-processing

You can access the raw datasets at the following links:

- [eICU Collaborative Research Database](https://eicu-crd.mit.edu/)
- [MIMIC-III Clinical Database](https://physionet.org/content/mimiciii/1.4/)

> **Note:** Both datasets require prior approval and credentialing before download.

We focus exclusively on the multivariate time-series recordings available in these datasets.
To assist with preprocessing, we provide high-level extraction scripts under **`data_preprocess/`**.
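
For orientation, downstream scripts such as `classifier/classifier_train.py` load each split with `pd.read_pickle` as a `(data, labels)` tuple and transpose the data from `(N, F, T)` to `(N, T, F)`. Below is a minimal sketch of writing such a tuple, using hypothetical placeholder arrays in place of the extracted recordings:

```python
import numpy as np
import pandas as pd

# Placeholder arrays standing in for extracted ICU recordings:
# N stays, F features, T time steps, stored as (N, F, T).
X = np.random.randn(1000, 7, 24).astype(np.float32)
y = np.random.randint(0, 2, size=1000).astype(np.int64)

pd.to_pickle((X, y), "data/mimic_icustay/train_tuple.pkl")
```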



## 3 · Stage 1 · Train the *Base* Diffusion Model

```bash
bash train.sh # trains TarDiff on MIMIC-III ICU-stay data
```

> **Edit tip:** open the example YAML in `configs/base/` and replace any placeholder data paths with your own before running.

This step produces an unconditional diffusion model checkpoint—no guidance yet.

---

## 4 · Stage 2 · Train a Downstream Task Model (Guidance Source)

An example RNN classifier is supplied in **`classifier/`**.

```bash
cd classifier
bash train.sh # saves weights to classifier/checkpoint/
cd ..
```

Feel free to swap in any architecture that suits your task.
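
As a reference point, the bundled `RNNClassifier` (see `classifier/model.py`) can also be used standalone; here is a minimal sketch with the GRU settings from `classifier/train.sh`:

```python
import torch
from model import RNNClassifier  # classifier/model.py

# Settings mirror classifier/train.sh: GRU, hidden_dim 256, binary output.
model = RNNClassifier(input_dim=7, hidden_dim=256, rnn_type="gru", num_classes=1)
x = torch.randn(32, 24, 7)   # (batch, seq_len, n_features)
logits = model(x)            # shape (32,): binary logits for BCEWithLogitsLoss
```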

---

## 5 · Stage 3 · Target-Guided Generation

With **both** checkpoints ready—the diffusion backbone and the task model—start guided sampling:

```bash
bash generation.sh # remember to update paths to both weights
```

The script creates a synthetic dataset tailored to the guidance task.
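
Conceptually, guided sampling adds the influence gradient of the downstream task loss to each reverse-diffusion step. The sketch below is illustrative only; names such as `p_mean_variance`, `task_loss`, and `guidance_scale` are hypothetical stand-ins, not the repository's actual API:

```python
import torch

def guided_reverse_step(x_t, t, diffusion_model, task_model, guidance_scale=1.0):
    """One influence-guided reverse-diffusion step (conceptual sketch only)."""
    x_t = x_t.detach().requires_grad_(True)
    # Unguided reverse step from the trained diffusion backbone
    # (illustrative call; the real sampler lives in the repo's code).
    mean, sigma = diffusion_model.p_mean_variance(x_t, t)
    # Influence signal: gradient of the task-specific loss w.r.t. the sample.
    loss = task_model.task_loss(x_t)
    grad = torch.autograd.grad(loss, x_t)[0]
    # Shift the step toward samples expected to reduce the downstream loss.
    guided_mean = mean - guidance_scale * sigma**2 * grad
    return guided_mean + sigma * torch.randn_like(x_t)
```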

---

## 6 · Stage 4 · Utility Evaluation — *TSTR* and *TSRTR*

After generation, you can assess the utility of the synthetic data for the **target task** using two complementary protocols:

| Protocol | Training Set | Test Set | Question Answered |
|:---------|:-------------|:---------|:------------------|
| **TSTR** (Train-Synthetic, Test-Real) | **Synthetic only** | **Real** | “If I train a model purely on synthetic EHRs, how well does it generalize to real patients?” |
| **TSRTR** (Train-Synthetic-Real, Test-Real) | **Synthetic + Real** (α ∈ {0.2, …, 1.0}) | **Real** | “If I augment the real training set with α× synthetic samples, does it improve model performance?” |

---

### How to run TSTR and TSRTR evaluations

You can directly reuse the training script under `classifier/` to run both evaluations:

```bash
cd classifier
bash train.sh # Edit the training data path to point to either synthetic-only or mixed (real + synthetic) data
```

- For **TSTR**, set the training set path to the synthetic dataset.
- For **TSRTR**, combine the real training data with synthetic samples according to your desired α ratio, and update the path accordingly (a minimal sketch follows below).

The downstream model will be trained and evaluated automatically on the real validation and test sets.
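
For example, here is a minimal sketch of building a TSRTR training tuple; the synthetic-data path is illustrative, and α is the synthetic-to-real ratio from the table above:

```python
import numpy as np
import pandas as pd

alpha = 0.5  # synthetic samples added, as a fraction of the real set size

X_real, y_real = pd.read_pickle("data/mimic_icustay/train_tuple.pkl")
X_syn, y_syn = pd.read_pickle("data/mimic_icustay/syn_tuple.pkl")  # illustrative path

n_syn = int(alpha * len(X_real))
X_mix = np.concatenate([X_real, X_syn[:n_syn]], axis=0)
y_mix = np.concatenate([y_real, y_syn[:n_syn]], axis=0)

pd.to_pickle((X_mix, y_mix), "data/mimic_icustay/tsrtr_alpha0.5_tuple.pkl")
```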

---

Enjoy exploring target-oriented diffusion for healthcare ML! For questions or bug reports, please open a GitHub issue; pull requests are welcome.
2 changes: 2 additions & 0 deletions TarDiff/classifier/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
213 changes: 213 additions & 0 deletions TarDiff/classifier/classifier_train.py
@@ -0,0 +1,213 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


from __future__ import annotations

import argparse
from pathlib import Path
from model import RNNClassifier
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import os
import pandas as pd
from typing import Optional, Tuple


class TimeSeriesDataset(Dataset):

def __init__(
self,
data,
labels,
normalize=True,
stats=None,
eps=1e-8,
):

if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if isinstance(labels, np.ndarray):
labels = torch.from_numpy(labels)
assert data.ndim == 3, "data must be (N, seq_len, n_features)"
assert len(data) == len(labels)
self.data = data.float()
self.labels = labels.long()
self.normalize = normalize
self.eps = eps

if self.normalize:
if stats is None:
# compute mean/std over all time‑steps *per feature*
mean = self.data.mean(dim=(0, 1), keepdim=True) # (1,1,F)
std = self.data.std(dim=(0, 1), keepdim=True)
else:
mean, std = stats
if isinstance(mean, np.ndarray):
mean = torch.from_numpy(mean)
if isinstance(std, np.ndarray):
std = torch.from_numpy(std)
mean, std = mean.float(), std.float()
self.register_buffer("_mean",
mean) # cached on device when .to(...)
self.register_buffer("_std", std.clamp_min(self.eps))

# tiny helper so buffers exist even on CPU tensors
def register_buffer(self, name: str, tensor: torch.Tensor):
object.__setattr__(self, name, tensor)

# expose stats to reuse on other splits
@property
def stats(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
if not self.normalize:
return None
return self._mean, self._std

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
x = self.data[idx]
if self.normalize:
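            # mean/std have shape (1, 1, F); broadcasting promotes x from
            # (T, F) to (1, T, F), so squeeze(0) below restores (T, F)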
x = (x - self._mean) / self._std
x = x.squeeze(0)
return x, self.labels[idx]


# -----------------------------------------------------------------------------
# Train / Eval helpers
# -----------------------------------------------------------------------------


def _accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
if logits.dim() == 1: # binary with BCEWithLogits
preds = (torch.sigmoid(logits) > 0.5).long()
else: # multi‑class with CE
preds = logits.argmax(dim=1)
return (preds == labels).float().mean().item()


def _run_epoch(model, loader, criterion, optimizer, device, train=True):
    if train:
        model.train()
    else:
        model.eval()
    total_loss = 0.0
    total_acc = 0.0
    # Disable autograd during evaluation to save memory and compute.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if train:
                optimizer.zero_grad()
            logits = model(x)
            # BCEWithLogitsLoss expects float targets; CrossEntropyLoss, long.
            loss = criterion(logits, y.float() if logits.dim() == 1 else y)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * x.size(0)
            total_acc += _accuracy(logits, y) * x.size(0)
    n = len(loader.dataset)
    return total_loss / n, total_acc / n


# -----------------------------------------------------------------------------
# Main – quick demo on synthetic data
# -----------------------------------------------------------------------------


def main(args):
rng = np.random.default_rng(args.seed)
if os.path.exists(args.train_data) and os.path.exists(args.val_data):
X_train, y_train = pd.read_pickle(args.train_data)
X_train = X_train.transpose(0, 2, 1)
X_val, y_val = pd.read_pickle(args.val_data)
X_val = X_val.transpose(0, 2, 1)
else:
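        # Fall back to random demo data so the script runs without real files.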
X_train = rng.standard_normal(size=(20000, 24, 7), dtype=np.float32)
y_train = rng.integers(0,
args.num_classes,
size=(20000, ),
dtype=np.int64)
X_val = rng.standard_normal(size=(5000, 24, 7), dtype=np.float32)
y_val = rng.integers(0,
args.num_classes,
size=(5000, ),
dtype=np.int64)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
train_set = TimeSeriesDataset(X_train, y_train)
val_set = TimeSeriesDataset(X_val, y_val)

train_loader = DataLoader(train_set,
batch_size=args.batch_size,
shuffle=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNClassifier(
input_dim=7,
hidden_dim=args.hidden_dim,
num_layers=args.num_layers,
rnn_type=args.rnn_type,
num_classes=args.num_classes,
dropout=args.dropout,
).to(device)

criterion = (nn.BCEWithLogitsLoss()
if args.num_classes == 1 else nn.CrossEntropyLoss())
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

best_val = 0.0
Path(args.ckpt_dir).mkdir(parents=True, exist_ok=True)

for epoch in tqdm(range(1, args.epochs + 1)):
tr_loss, tr_acc = _run_epoch(model,
train_loader,
criterion,
optimizer,
device,
train=True)
va_loss, va_acc = _run_epoch(model,
val_loader,
criterion,
optimizer,
device,
train=False)
print(
f"Epoch {epoch:02d} | train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f}"
)
if va_acc > best_val:
best_val = va_acc
print(f"New best val acc: {best_val:.4f} -> saving model")
torch.save({"model_state": model.state_dict()},
Path(args.ckpt_dir) / "best_model.pt")
print(f"Train Finished. Best val acc: {best_val:.4f}")


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------

if __name__ == "__main__":
p = argparse.ArgumentParser(
description="Bidirectional LSTM/GRU time‑series classifier")
p.add_argument("--hidden_dim", type=int, default=128)
p.add_argument("--num_layers", type=int, default=2)
p.add_argument("--rnn_type", choices=["lstm", "gru"], default="lstm")
p.add_argument("--num_classes",
type=int,
default=1,
help="1 for binary, >1 for multi‑class")
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--epochs", type=int, default=40)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--dropout", type=float, default=0.2)
p.add_argument("--train_data", type=str, default="data/train_data.npy")
p.add_argument("--val_data", type=str, default="data/val_data.npy")
p.add_argument("--ckpt_dir", type=str, default="checkpoints")
p.add_argument("--seed", type=int, default=42)

args = p.parse_args()
main(args)
38 changes: 38 additions & 0 deletions TarDiff/classifier/model.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


from __future__ import annotations
import torch
from torch import nn


class RNNClassifier(nn.Module):
"""Bidirectional LSTM/GRU classifier for fixed‑length sequences."""

def __init__(
self,
input_dim: int,
hidden_dim: int = 128,
num_layers: int = 2,
rnn_type: str = "lstm",
num_classes: int = 2,
dropout: float = 0.2,
) -> None:
super().__init__()
rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type.lower()]
self.rnn = rnn_cls(
input_size=input_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.fc = nn.Linear(hidden_dim * 2, num_classes)

def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, F)
rnn_out, _ = self.rnn(x) # (B, T, 2*H)
last_hidden = rnn_out[:, -1, :] # final time‑step representation
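        # NOTE: for the backward direction, the last time step has seen only
        # one input; concatenating both directions' final hidden states from
        # h_n is a common alternative for bidirectional RNNs.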
logits = self.fc(last_hidden) # (B, C) or (B, 1)
return logits.squeeze(-1) # binary → (B,) ; multi‑class stays (B, C)
6 changes: 6 additions & 0 deletions TarDiff/classifier/train.sh
@@ -0,0 +1,6 @@
# Prepare the guidance model's training data as train_tuple.pkl: a (data, label) tuple

python classifier_train.py --num_classes 1 --rnn_type gru --hidden_dim 256 --train_data data/mimic_icustay/train_tuple.pkl --val_data data/mimic_icustay/val_tuple.pkl


