Conversation

@ruokun-niu ruokun-niu commented Nov 30, 2025

UIUC ID: ruokunn2

Summary

This PR adds the LabTOP (Lab Test Outcome Prediction) model to PyHealth, enabling continuous numerical prediction of laboratory test values.

LabTOP uses digit-wise tokenization to represent numerical values as sequences of individual digits (e.g., 123.45 → ['1','2','3','.','4','5']), preserving exact precision while maintaining a compact vocabulary of ~20-50 tokens.
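
For example, a round trip through the tokenizer added in this PR (mirroring the DigitWiseTokenizer docstring below):

tokenizer = DigitWiseTokenizer(precision=2)
tokens = tokenizer.number_to_tokens(123.45)   # ['1', '2', '3', '.', '4', '5']
value = tokenizer.tokens_to_number(tokens)    # 123.45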

Paper Reference

  • Im et al., "LabTOP: A Unified Model for Lab Test Outcome Prediction on Electronic Health Records," CHIL 2025 (Best Paper Award). https://arxiv.org/abs/2502.14259

Implementation Details

  • Architecture: GPT-2 transformer (12 layers, 768 dim, ~53M parameters)
  • Classes Added:
    • DigitWiseTokenizer: Converts numbers ↔ digit sequences
    • LabTOPVocabulary: Manages complete vocabulary (special tokens + digits + lab codes)
    • LabTOP: Main model class inheriting from BaseModel

Files Modified

  • pyhealth/models/labtop.py (new file, ~600 lines)

Performance (from paper)

  • MAE: 0.064, SMAPE: 14.80%, NMAE: 0.042 on MIMIC-IV (44 lab types)

Signed-off-by: ruokun-niu <[email protected]>
@LogicFan LogicFan (Collaborator) commented Dec 7, 2025

Could you add some test cases for this model? Thanks.

@LogicFan LogicFan requested a review from Copilot December 7, 2025 16:17
@LogicFan LogicFan self-requested a review December 7, 2025 16:18
@LogicFan LogicFan added the model Contribute a new model to PyHealth label Dec 7, 2025
Copilot AI (Contributor) left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Comment on lines +1 to +536
"""
LabTOP: Lab Test Outcome Prediction using GPT-2 with Digit-Wise Tokenization

Paper: Im et al. "LabTOP: A Unified Model for Lab Test Outcome Prediction
on Electronic Health Records" CHIL 2025 (Best Paper Award)
https://arxiv.org/abs/2502.14259

This implementation follows the PyHealth BaseModel structure.
"""

from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel

try:
from pyhealth.models import BaseModel
from pyhealth.datasets import SampleDataset
except ImportError:
# Fallback for standalone use
class BaseModel(nn.Module):
def __init__(self, dataset=None, feature_keys=None, label_key=None, mode="regression"):
super().__init__()
self.dataset = dataset
self.feature_keys = feature_keys
self.label_key = label_key
self.mode = mode


class DigitWiseTokenizer:
    """
    Tokenizer that converts numerical values to digit sequences.

    This is LabTOP's key innovation for preserving exact numerical precision
    while maintaining a compact vocabulary.

    Example:
        >>> tokenizer = DigitWiseTokenizer(precision=2)
        >>> tokens = tokenizer.number_to_tokens(123.45)
        >>> # Returns: ['1', '2', '3', '.', '4', '5']

    Args:
        precision: Number of decimal places to keep (default: 2)
    """

    def __init__(self, precision: int = 2):
        self.precision = precision

        # Special tokens
        self.special_tokens = {
            'PAD': '<|pad|>',
            'EOS': '<|endoftext|>',
            'SEP': '|endofevent|',
            'LAB': '<|lab|>',
            'AGE': '<|age|>',
            'GENDER_M': '<|gender_m|>',
            'GENDER_F': '<|gender_f|>',
        }

        # Digit tokens (0-9, '.', '-')
        self.digit_tokens = [str(i) for i in range(10)] + ['.', '-']

        # Build vocabulary
        self.vocab = {}
        self.id_to_token = {}

        # Add special tokens
        idx = 0
        for token in self.special_tokens.values():
            self.vocab[token] = idx
            self.id_to_token[idx] = token
            idx += 1

        # Add digit tokens
        for token in self.digit_tokens:
            self.vocab[token] = idx
            self.id_to_token[idx] = token
            idx += 1

    def number_to_tokens(self, number: float) -> List[str]:
        """Convert a number to a list of digit tokens."""
        number = round(float(number), self.precision)
        num_str = f"{number:.{self.precision}f}"
        return list(num_str)

    def tokens_to_number(self, tokens: List[str]) -> Optional[float]:
        """Convert a list of digit tokens back to a number."""
        num_str = ''.join(tokens)
        try:
            return float(num_str)
        except ValueError:
            return None

    def encode(self, tokens: List[str]) -> List[int]:
        """Convert tokens to IDs."""
        return [self.vocab.get(token, self.vocab[self.special_tokens['PAD']])
                for token in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        """Convert token IDs back to tokens."""
        return [self.id_to_token.get(id, self.special_tokens['PAD'])
                for id in ids]

    def __len__(self) -> int:
        return len(self.vocab)


class LabTOPVocabulary:
    """
    Complete vocabulary for LabTOP including special tokens, digit tokens,
    and lab item codes.

    Args:
        lab_items: List of unique lab item IDs
        digit_tokenizer: DigitWiseTokenizer instance
    """

    def __init__(self, lab_items: List[int], digit_tokenizer: DigitWiseTokenizer):
        self.digit_tokenizer = digit_tokenizer

        # Start with digit tokenizer vocab
        self.vocab = dict(digit_tokenizer.vocab)
        self.id_to_token = dict(digit_tokenizer.id_to_token)

        # Add lab item codes
        idx = len(self.vocab)
        for lab_id in sorted(lab_items):
            token = f"<|lab_{lab_id}|>"
            self.vocab[token] = idx
            self.id_to_token[idx] = token
            idx += 1

        # Store special token IDs
        self.pad_token_id = self.vocab[digit_tokenizer.special_tokens['PAD']]
        self.eos_token_id = self.vocab[digit_tokenizer.special_tokens['EOS']]
        self.sep_token_id = self.vocab[digit_tokenizer.special_tokens['SEP']]
        self.lab_token_id = self.vocab[digit_tokenizer.special_tokens['LAB']]

    def encode_event(self, event: Dict) -> List[int]:
        """
        Encode a lab event into token IDs.

        Event format: <|lab|> <|lab_50912|> 1 . 2 3 |endofevent|

        Args:
            event: Dict with 'code' (itemid) and 'value' (number)

        Returns:
            List of token IDs
        """
        tokens = []

        # Lab type marker
        tokens.append(self.digit_tokenizer.special_tokens['LAB'])

        # Lab item code
        lab_token = f"<|lab_{event['code']}|>"
        tokens.append(lab_token)

        # Lab value (digit-wise)
        value_tokens = self.digit_tokenizer.number_to_tokens(event['value'])
        tokens.extend(value_tokens)

        # Separator
        tokens.append(self.digit_tokenizer.special_tokens['SEP'])

        # Convert to IDs
        return [self.vocab[t] for t in tokens]

    def encode_demographics(self, age: Optional[int], gender: Optional[str]) -> List[int]:
        """
        Encode patient demographics.

        Format: <|age|> 6 5 <|gender_m|>
        """
        tokens = []

        # Age
        if age is not None:
            tokens.append(self.digit_tokenizer.special_tokens['AGE'])
            age_tokens = self.digit_tokenizer.number_to_tokens(int(age))
            tokens.extend(age_tokens)

        # Gender
        if gender == 'M':
            tokens.append(self.digit_tokenizer.special_tokens['GENDER_M'])
        elif gender == 'F':
            tokens.append(self.digit_tokenizer.special_tokens['GENDER_F'])

        # Convert to IDs
        return [self.vocab[t] for t in tokens]

    def __len__(self) -> int:
        return len(self.vocab)


class LabTOP(BaseModel):
    """
    LabTOP: Lab Test Outcome Prediction Model

    A GPT-2 based transformer that predicts lab test outcomes using
    digit-wise tokenization for continuous numerical predictions.

    Paper: Im et al. "LabTOP: A Unified Model for Lab Test Outcome Prediction
    on Electronic Health Records" CHIL 2025 (Best Paper Award)
    https://arxiv.org/abs/2502.14259

    Key Innovation:
        - Digit-wise tokenization: Represents numerical values as sequences
          of individual digits (e.g., 123.45 → ['1','2','3','.','4','5'])
        - Preserves exact numerical precision
        - Small vocabulary (only ~20-50 tokens total)
        - Unified model for all lab test types

    Args:
        dataset: PyHealth dataset object
        feature_keys: List of input feature names
        label_key: Target lab test name
        mode: Prediction mode (default: "regression")
        n_layers: Number of transformer layers (default: 12)
        n_heads: Number of attention heads (default: 12)
        embedding_dim: Embedding dimension (default: 768)
        max_seq_length: Maximum sequence length (default: 1024)
        digit_precision: Decimal precision for values (default: 2)
        dropout: Dropout rate (default: 0.1)
        **kwargs: Additional arguments

    Examples:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> from pyhealth.models import LabTOP
        >>>
        >>> # Load dataset
        >>> dataset = MIMIC4Dataset(root="/data/mimic4")
        >>>
        >>> # Initialize model
        >>> model = LabTOP(
        ...     dataset=dataset,
        ...     feature_keys=["demographics", "lab_history"],
        ...     label_key="lab_value",
        ...     mode="regression",
        ...     embedding_dim=768,
        ...     n_layers=12
        ... )
        >>>
        >>> # Forward pass
        >>> outputs = model(**batch)
        >>> loss = outputs["loss"]

    References:
        - Paper: https://arxiv.org/abs/2502.14259
        - Code: https://github.com/sujeongim/LabTOP
    """

    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str = "regression",
        n_layers: int = 12,
        n_heads: int = 12,
        embedding_dim: int = 768,
        max_seq_length: int = 1024,
        digit_precision: int = 2,
        dropout: float = 0.1,
        **kwargs
    ):
        super(LabTOP, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )

        self.n_layers = n_layers
        self.n_heads = n_heads
        self.embedding_dim = embedding_dim
        self.max_seq_length = max_seq_length
        self.digit_precision = digit_precision

        # Initialize digit-wise tokenizer
        self.digit_tokenizer = DigitWiseTokenizer(precision=digit_precision)

        # Build vocabulary (will be updated with lab items from dataset)
        # For now, use a placeholder - should be set with build_vocabulary()
        self.vocabulary = None
        self.vocab_size = len(self.digit_tokenizer)  # Base vocab size

        # GPT-2 configuration
        self.gpt2_config = GPT2Config(
            vocab_size=self.vocab_size,  # Will be updated after vocab built
            n_positions=max_seq_length,
            n_embd=embedding_dim,
            n_layer=n_layers,
            n_head=n_heads,
            n_inner=embedding_dim * 4,
            activation_function='gelu_new',
            resid_pdrop=dropout,
            embd_pdrop=dropout,
            attn_pdrop=dropout,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
        )

        # Initialize GPT-2 model (will be rebuilt after vocabulary is set)
        self.model = None

    def build_vocabulary(self, lab_items: List[int]) -> None:
        """
        Build complete vocabulary including lab item codes.

        This should be called after determining unique lab items from dataset.

        Args:
            lab_items: List of unique lab item IDs
        """
        self.vocabulary = LabTOPVocabulary(lab_items, self.digit_tokenizer)
        self.vocab_size = len(self.vocabulary)

        # Update GPT-2 config with actual vocab size
        self.gpt2_config.vocab_size = self.vocab_size
        self.gpt2_config.bos_token_id = self.vocabulary.eos_token_id
        self.gpt2_config.eos_token_id = self.vocabulary.eos_token_id
        self.gpt2_config.pad_token_id = self.vocabulary.pad_token_id

        # Build GPT-2 model
        self.model = GPT2LMHeadModel(self.gpt2_config)

    def prepare_input(
        self,
        demographics: Dict,
        lab_history: List[Dict],
        max_length: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare input sequence from patient data.

        Args:
            demographics: Dict with 'age' and 'gender'
            lab_history: List of events with 'code', 'value', 'timestamp'
            max_length: Maximum sequence length (uses self.max_seq_length if None)

        Returns:
            Dict with 'input_ids', 'attention_mask'
        """
        if self.vocabulary is None:
            raise ValueError("Vocabulary not built. Call build_vocabulary() first.")

        max_length = max_length or self.max_seq_length

        # Start with demographics
        token_ids = self.vocabulary.encode_demographics(
            demographics.get('age'),
            demographics.get('gender')
        )

        # Add lab events (sorted by timestamp)
        for event in lab_history:
            event_ids = self.vocabulary.encode_event(event)
            token_ids.extend(event_ids)

            if len(token_ids) >= max_length - 1:
                break

        # Add EOS token
        token_ids.append(self.vocabulary.eos_token_id)

        # Truncate if needed
        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]

        # Create attention mask
        attention_mask = [1] * len(token_ids)

        # Pad to max_length
        while len(token_ids) < max_length:
            token_ids.append(self.vocabulary.pad_token_id)
            attention_mask.append(0)

        return {
            'input_ids': torch.tensor([token_ids], dtype=torch.long),
            'attention_mask': torch.tensor([attention_mask], dtype=torch.long)
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through LabTOP model.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            labels: Target token IDs for training [batch_size, seq_len]

        Returns:
            Dict with 'logits', 'loss' (if labels provided), 'y_prob', 'y_true'
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_vocabulary() first.")

        # Forward through GPT-2
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

        result = {
            'logits': outputs.logits,  # [batch_size, seq_len, vocab_size]
        }

        if labels is not None:
            result['loss'] = outputs.loss
            # For PyHealth compatibility
            result['y_true'] = labels
            result['y_prob'] = F.softmax(outputs.logits, dim=-1)

        return result

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 20,
        temperature: float = 1.0,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate lab value prediction autoregressively.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated token IDs
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_vocabulary() first.")

        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=self.vocabulary.pad_token_id,
            eos_token_id=self.vocabulary.eos_token_id,
            **kwargs
        )

    def decode_prediction(self, token_ids: List[int]) -> Optional[float]:
        """
        Decode generated token IDs back to numerical value.

        Args:
            token_ids: List of generated token IDs

        Returns:
            Predicted numerical value or None if invalid
        """
        if self.vocabulary is None:
            raise ValueError("Vocabulary not built.")

        # Find digit tokens between lab marker and separator
        digit_tokens = []
        in_value = False

        for token_id in token_ids:
            token = self.vocabulary.id_to_token.get(token_id, '')

            if token == self.digit_tokenizer.special_tokens['LAB']:
                in_value = True
                continue

            if token == self.digit_tokenizer.special_tokens['SEP']:
                break

            if in_value and token in self.digit_tokenizer.digit_tokens:
                digit_tokens.append(token)

        # Convert to number
        return self.digit_tokenizer.tokens_to_number(digit_tokens)


# For backward compatibility and standalone testing
if __name__ == "__main__":
    print("LabTOP Model")
    print("=" * 70)
    print("A GPT-2 based model for lab test outcome prediction")
    print("with digit-wise tokenization.")
    print()
    print("Paper: Im et al. CHIL 2025 (Best Paper Award)")
    print("https://arxiv.org/abs/2502.14259")
    print("=" * 70)

    # Example usage
    print("\nExample: Building vocabulary and model")

    # Mock dataset
    class MockDataset:
        pass

    # Initialize model
    model = LabTOP(
        dataset=MockDataset(),
        feature_keys=["demographics", "lab_history"],
        label_key="lab_value",
        mode="regression",
        n_layers=12,
        embedding_dim=768
    )

    # Build vocabulary with example lab items
    lab_items = [50912, 50931, 50971]  # Creatinine, Glucose, Potassium
    model.build_vocabulary(lab_items)

    print(f"✅ Model built with {len(model.vocabulary)} vocab tokens")
    print(f"   Parameters: {sum(p.numel() for p in model.model.parameters()):,}")

    # Test tokenization
    print("\nExample: Digit-wise tokenization")
    tokenizer = DigitWiseTokenizer(precision=2)
    value = 123.45
    tokens = tokenizer.number_to_tokens(value)
    print(f"  Value: {value}")
    print(f"  Tokens: {tokens}")
    print(f"  Reconstructed: {tokenizer.tokens_to_number(tokens)}")
Copilot AI Dec 7, 2025


The LabTOP model should be exported in pyhealth/models/__init__.py to make it accessible via from pyhealth.models import LabTOP. Currently, all other models are exported there (lines 1-29), but LabTOP is missing.

Add to __init__.py:

from .labtop import LabTOP, DigitWiseTokenizer, LabTOPVocabulary

        )

        # Add lab events (sorted by timestamp)
        for event in lab_history:
Copilot AI Dec 7, 2025


The loop doesn't sort lab events by timestamp before encoding, despite the comment "Add lab events (sorted by timestamp)" on line 358. This could lead to incorrect temporal ordering if the input lab_history is not pre-sorted.

Either:

  1. Add sorting: for event in sorted(lab_history, key=lambda e: e.get('timestamp', 0)):, or
  2. Document that lab_history must be pre-sorted by the caller
Suggested change
-        for event in lab_history:
+        for event in sorted(lab_history, key=lambda e: e.get('timestamp', 0)):

Comment on lines +255 to +274
    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str = "regression",
        n_layers: int = 12,
        n_heads: int = 12,
        embedding_dim: int = 768,
        max_seq_length: int = 1024,
        digit_precision: int = 2,
        dropout: float = 0.1,
        **kwargs
    ):
        super(LabTOP, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
Copilot AI Dec 7, 2025


The model signature deviates from PyHealth's BaseModel pattern. According to PyHealth conventions (see base_model.py:30-34 and rnn.py:156-158), feature_keys and label_key should not be required parameters - they are automatically derived from dataset.input_schema and dataset.output_schema in the BaseModel constructor.

Recommendation: Remove feature_keys and label_key from the constructor parameters and rely on the automatic schema extraction from BaseModel. If you need to access them, use self.feature_keys and self.label_keys after calling super().__init__().
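
A minimal sketch of that convention (hypothetical; assumes BaseModel derives the keys from dataset.input_schema and dataset.output_schema as described above):

class LabTOP(BaseModel):
    def __init__(
        self,
        dataset: SampleDataset,
        n_layers: int = 12,
        n_heads: int = 12,
        embedding_dim: int = 768,
        **kwargs,
    ):
        # BaseModel is assumed to populate self.feature_keys and
        # self.label_keys from the dataset schemas; no explicit arguments
        super().__init__(dataset=dataset)
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.embedding_dim = embedding_dim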

Comment on lines +21 to +30
    # Fallback for standalone use
    class BaseModel(nn.Module):
        def __init__(self, dataset=None, feature_keys=None, label_key=None, mode="regression"):
            super().__init__()
            self.dataset = dataset
            self.feature_keys = feature_keys
            self.label_key = label_key
            self.mode = mode


Copilot AI Dec 7, 2025


The fallback BaseModel implementation (lines 22-28) is incomplete and may lead to runtime errors. It's missing critical methods like get_output_size(), get_loss_function(), prepare_y_prob(), and the device property that are used by PyHealth's training pipeline.

Since this is meant for PyHealth integration, the fallback should either be removed entirely (raising an ImportError if PyHealth is not available), or it should include a more complete stub implementation. The current approach creates a false sense of compatibility.

Suggested change
-    # Fallback for standalone use
-    class BaseModel(nn.Module):
-        def __init__(self, dataset=None, feature_keys=None, label_key=None, mode="regression"):
-            super().__init__()
-            self.dataset = dataset
-            self.feature_keys = feature_keys
-            self.label_key = label_key
-            self.mode = mode
+    # PyHealth is required for LabTOP. Please install pyhealth.
+    raise ImportError("PyHealth is required for LabTOP. Please install pyhealth (pip install pyhealth) to use this model.")

Comment on lines +306 to +307
        # Initialize GPT-2 model (will be rebuilt after vocabulary is set)
        self.model = None
Copilot AI Dec 7, 2025


The model is initialized with self.model = None and will raise errors if forward() or generate() are called before build_vocabulary(). This is a fragile API design that violates PyHealth's pattern where models are fully functional after __init__().

The two-phase initialization (construct then build_vocabulary()) is problematic because:

  1. Users could easily forget to call build_vocabulary()
  2. The model reports as initialized but is actually non-functional
  3. This breaks PyHealth's trainer expectations

Consider either:

  1. Accepting lab_items as a constructor parameter and building the model immediately (sketched after this list), or
  2. Extracting lab items from the dataset automatically in __init__(), or
  3. Documenting this requirement very clearly and adding validation that raises a helpful error early
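
A rough sketch of option 1, reusing this PR's build_vocabulary(); the lab_items parameter is hypothetical:

    def __init__(
        self,
        dataset,
        lab_items: Optional[List[int]] = None,  # hypothetical eager-build parameter
        **kwargs,
    ):
        super().__init__(dataset=dataset, **kwargs)
        self.digit_tokenizer = DigitWiseTokenizer(precision=2)
        self.vocabulary = None
        self.model = None
        if lab_items is not None:
            # Build vocabulary and GPT-2 weights eagerly so the model
            # is fully functional as soon as __init__ returns
            self.build_vocabulary(lab_items)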

            result['loss'] = outputs.loss
            # For PyHealth compatibility
            result['y_true'] = labels
            result['y_prob'] = F.softmax(outputs.logits, dim=-1)
Copilot AI Dec 7, 2025


For regression tasks, returning y_prob = F.softmax(outputs.logits, dim=-1) is incorrect. The softmax converts logits to a probability distribution over vocabulary tokens, which is appropriate for the language modeling objective but not meaningful as a regression prediction.

For PyHealth compatibility with regression mode, y_prob should contain the actual predicted numerical values (decoded from the generated token sequence), not token probabilities. Consider removing this line or replacing it with decoded predictions:

# For regression, y_prob should be decoded predictions
result['y_prob'] = self.decode_predictions_batch(outputs.logits)

Suggested change
-            result['y_prob'] = F.softmax(outputs.logits, dim=-1)
+            # For regression, y_prob should be decoded predictions
+            pred_token_ids = torch.argmax(outputs.logits, dim=-1)  # [batch_size, seq_len]
+            result['y_prob'] = [self.decode_prediction(token_ids.tolist()) for token_ids in pred_token_ids]

Comment on lines +198 to +492
Copilot AI Dec 7, 2025


The LabTOP model implementation lacks test coverage. Other PyHealth models have comprehensive test files in tests/core/ (e.g., test_mlp.py, test_rnn.py, test_transformer.py) that verify:

  1. Model initialization
  2. Forward pass with correct output structure
  3. Backward pass (gradient computation)
  4. Custom hyperparameters

Please add a tests/core/test_labtop.py file with similar test cases to ensure the model integrates correctly with PyHealth's infrastructure, especially given the custom tokenization and two-phase initialization pattern.
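
A possible starting point (a rough sketch, assuming the constructor and build_vocabulary() API from this PR; the tiny hyperparameters just keep the test fast):

import torch
from pyhealth.models.labtop import LabTOP, DigitWiseTokenizer

def test_digit_tokenizer_round_trip():
    tokenizer = DigitWiseTokenizer(precision=2)
    tokens = tokenizer.number_to_tokens(123.45)
    assert tokens == ['1', '2', '3', '.', '4', '5']
    assert tokenizer.tokens_to_number(tokens) == 123.45

def test_forward_and_backward():
    model = LabTOP(
        dataset=None,  # assumes a dataset-free unit test is acceptable here
        feature_keys=["demographics", "lab_history"],
        label_key="lab_value",
        n_layers=2, n_heads=2, embedding_dim=64, max_seq_length=32,
    )
    model.build_vocabulary(lab_items=[50912, 50931])
    input_ids = torch.randint(0, model.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    assert outputs["logits"].shape == (2, 16, model.vocab_size)
    outputs["loss"].backward()  # gradients must flow through GPT-2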

Comment on lines +330 to +346
    def prepare_input(
        self,
        demographics: Dict,
        lab_history: List[Dict],
        max_length: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare input sequence from patient data.

        Args:
            demographics: Dict with 'age' and 'gender'
            lab_history: List of events with 'code', 'value', 'timestamp'
            max_length: Maximum sequence length (uses self.max_seq_length if None)

        Returns:
            Dict with 'input_ids', 'attention_mask'
        """
Copilot AI Dec 7, 2025


The docstring for prepare_input() describes a method that appears to be designed for standalone use with raw patient data, but the actual forward() method expects pre-tokenized input_ids. This creates confusion about the intended API usage.

Either:

  1. Update forward() to use prepare_input() internally, or
  2. Clarify in the docstring that prepare_input() is a utility method for manual data preparation, not part of the standard training flow (usage illustrated below)
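
If the second option is taken, the intended standalone flow could be documented with something like this (illustrative only; assumes build_vocabulary() has already been called on the model):

batch = model.prepare_input(
    demographics={'age': 65, 'gender': 'M'},
    lab_history=[{'code': 50912, 'value': 1.23, 'timestamp': 0}],
)
generated = model.generate(**batch, max_new_tokens=10)
predicted_value = model.decode_prediction(generated[0].tolist())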

Comment on lines +95 to +98
    def encode(self, tokens: List[str]) -> List[int]:
        """Convert tokens to IDs."""
        return [self.vocab.get(token, self.vocab[self.special_tokens['PAD']])
                for token in tokens]
Copilot AI Dec 7, 2025


The encode() method silently falls back to PAD token for unknown tokens. This could mask bugs where invalid tokens are being encoded (e.g., special characters, out-of-vocabulary codes).

Consider either:

  1. Raising an error for unknown tokens (strict mode), or
  2. Adding a warning when unknown tokens are encountered, or
  3. Documenting this fallback behavior clearly in the docstring

Example:

def encode(self, tokens: List[str]) -> List[int]:
    """Convert tokens to IDs. Unknown tokens are mapped to PAD."""
    result = []
    for token in tokens:
        if token not in self.vocab:
            # Optionally log warning
            result.append(self.vocab[self.special_tokens['PAD']])
        else:
            result.append(self.vocab[token])
    return result

Comment on lines +382 to +383
            'input_ids': torch.tensor([token_ids], dtype=torch.long),
            'attention_mask': torch.tensor([attention_mask], dtype=torch.long)
Copilot AI Dec 7, 2025


The tensors created in prepare_input() are not moved to the model's device. When the model is on GPU, these CPU tensors will cause device mismatch errors during forward pass.

Fix by moving tensors to the model's device:

return {
    'input_ids': torch.tensor([token_ids], dtype=torch.long).to(self.device),
    'attention_mask': torch.tensor([attention_mask], dtype=torch.long).to(self.device)
}

Note: BaseModel provides a self.device property via the _dummy_param (see base_model.py:44, 75-82), so you can use it directly.

Suggested change
-            'input_ids': torch.tensor([token_ids], dtype=torch.long),
-            'attention_mask': torch.tensor([attention_mask], dtype=torch.long)
+            'input_ids': torch.tensor([token_ids], dtype=torch.long).to(self.device),
+            'attention_mask': torch.tensor([attention_mask], dtype=torch.long).to(self.device)
