
# VRL-LGDNet

VRL-LGDNet is a deep-learning framework for classifying vitreoretinal lymphoma (VRL) from OCT images. It uses a multimodal architecture that combines OCT B-scan features with teacher-guided attention maps and patient metadata to distinguish VRL from control and uveitis cases.

## Overview

VRL is a rare but vision-threatening intraocular lymphoma that is easily confused with uveitis and other retinal diseases on OCT. VRL-LGDNet implements a Lesion-Guided Distillation Network in which:

1. **Teacher model:** a pre-trained Faster R-CNN lesion detector generates class activation maps (attention maps) highlighting VRL-associated lesion regions.
2. **Student model:** a 4-channel (RGB + attention map) `nf_resnet50` backbone extracts joint lesion-aware features.
3. **Tabular branch:** patient demographics (sex, age, diseased eyes) are fused via a shallow MLP.
4. **Distillation:** the student is trained to reproduce the teacher's detection softmax outputs, reinforcing lesion-aware representation learning.
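The distillation objective in step 4 can be sketched as a hard-label classification loss plus a weighted divergence toward the teacher's softmax. This is a minimal pure-Python sketch under assumed loss forms (the `alpha` weighting and the cross-entropy + KL combination are illustrative, not the repository's exact implementation):

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) between two discrete probability distributions."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def distillation_loss(student_logits, teacher_probs, target_idx, alpha=1.1114):
    """Cross-entropy on the ground-truth label plus alpha-weighted KL
    between the teacher's softmax and the student's softmax."""
    student_probs = softmax(student_logits)
    ce = -math.log(student_probs[target_idx] + 1e-12)
    kd = kl_divergence(teacher_probs, student_probs)
    return ce + alpha * kd
```

When the student already matches the teacher's distribution, the KL term vanishes and the loss reduces to plain cross-entropy, which is the intuition behind "reproducing the teacher's detection softmax outputs."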

## Ablation study

| Experiment | Image input | Teacher / Distillation | Tabular |
|---|---|---|---|
| full (proposed) | 4-ch RGB+attn | ✓ | ✓ |
| ablation3 | 3-ch RGB | | ✓ |
| ablation2 | 4-ch RGB+attn | ✓ | |
| ablation1 | 3-ch RGB | | |

## Installation

```bash
git clone https://github.com/<your-org>/VRL-LGDNet.git
cd VRL-LGDNet
pip install -r requirements.txt
```

- Python ≥ 3.10 is required (the code uses the `X | Y` union type-hint syntax).
- A GPU with ≥ 8 GB of VRAM is recommended for training.

## Data preparation

`data_processing.py` handles image cropping and train/val/test split assignment. It expects a folder-level index CSV that you build from your own data:

```bash
python data_processing.py \
    --data-csv   /path/to/data.csv \
    --raw-root   /path/to/raw_images \
    --output-dir /path/to/processed_images \
    --output-csv /path/to/data_processed.csv \
    --stats-csv  /path/to/stats.csv
```

The key output is a processed data CSV with the following columns:

| Column | Description |
|---|---|
| `Group` | `VRL` or `Control` (or `Uveitis` for the uveitis experiment) |
| `subject` | Patient identifier |
| `laterity` | `OD` / `OS` / `unknown` |
| `folderpath` | Absolute path to the processed image folder |
| `image_count` | Number of B-scan images |
| `device` | `Heidelberg` or `Shiwei` |
| `age` | Patient age (integer) |
| `sex` | 1 = male, 2 = female |
| `diseased_eyes` | 1 = unilateral, 2 = bilateral |
| `split` | `train` / `val` / `test` |
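Downstream scripts consume this CSV by split. A minimal standard-library sketch of that pattern (the column names come from the table above; the helper itself is illustrative, not part of the repository):

```python
import csv

def load_split(csv_path, split):
    """Return the rows of the processed data CSV belonging to one split."""
    with open(csv_path, newline="") as f:
        return [row for row in csv.DictReader(f) if row["split"] == split]

# Example: count VRL subjects in the training split.
# train_rows = load_split("data_processed.csv", "train")
# n_vrl = sum(row["Group"] == "VRL" for row in train_rows)
```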

## Training

### Teacher model (lesion detector)

Train the Faster R-CNN teacher before running the student experiments. The teacher requires a bounding-box annotation CSV with columns `image_path`, `label_name`, `x`, `y`, `width`, `height`.

```bash
python train_teacher.py \
    --labels-csv /path/to/roi_labels.csv \
    --output-dir ./teacher_output
```

The saved checkpoint (`teacher_output/checkpoints/best_teacher.pth`) is passed as `--teacher-path` to the student training scripts.

Label names expected in the annotation CSV (`LABEL_TO_IDX` in `train_teacher.py`):

| `label_name` | Class | Lesion |
|---|---|---|
| `confluent_rpe_detachment` | 1 | Confluent RPE detachment |
| `intraretinal_abnormal_reflectivity` | 2 | Intraretinal abnormal reflectivity |
| `preretinal_deposits` | 3 | Preretinal deposits |
| `poor_quality` | – | Excluded (quality filter) |

Rename the keys in `LABEL_TO_IDX` to match your annotation tool's export format.
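A plausible reconstruction of the mapping described above, with `poor_quality` annotations filtered out before training (the exact dictionary in `train_teacher.py` may differ; the `keep_annotation` helper is illustrative):

```python
# Hypothetical reconstruction; class index 0 is left unused here because
# torchvision's Faster R-CNN reserves it for the background class.
LABEL_TO_IDX = {
    "confluent_rpe_detachment": 1,
    "intraretinal_abnormal_reflectivity": 2,
    "preretinal_deposits": 3,
}

def keep_annotation(row):
    """Drop poor-quality frames and any label not in the mapping."""
    return row["label_name"] in LABEL_TO_IDX
```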

### VRL vs Control (primary experiment)

```bash
python train_vrl_vs_control.py \
    --data-csv   /path/to/data_processed.csv \
    --teacher-path /path/to/teacher_model.pth \
    --output-dir ./output_vrl_vs_control
```

Run only specific experiments:

```bash
python train_vrl_vs_control.py \
    --data-csv /path/to/data_processed.csv \
    --output-dir ./output \
    --experiments full ablation2
```

### VRL vs Uveitis

```bash
python train_vrl_vs_uveitis.py \
    --data-csv   /path/to/data_uveitis_processed.csv \
    --teacher-path /path/to/teacher_model.pth \
    --output-dir ./output_vrl_vs_uveitis
```

Select checkpoint by eye-level F1:

```bash
python train_vrl_vs_uveitis.py \
    --data-csv /path/to/data_uveitis_processed.csv \
    --output-dir ./output_uveitis \
    --selection-metric eye_f1 \
    --tag run_eyef1
```

### Key arguments

| Argument | Default | Description |
|---|---|---|
| `--data-csv` | required | Path to processed data CSV |
| `--teacher-path` | `None` | Path to teacher Faster R-CNN `.pth`. If omitted, attention maps are zeroed (ablation1/ablation3 unaffected) |
| `--output-dir` | `./output_...` | Root output directory |
| `--experiments` | all 4 | Space-separated subset of `full ablation1 ablation2 ablation3` |
| `--seed` | `42` | Global random seed |
| `--selection-metric` | `image_f1` | `image_f1` or `eye_f1` (uveitis script only) |
| `--tag` | `""` | Run tag for parallel experiments (uveitis script only) |
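The table above corresponds to an `argparse` interface along these lines. This is a sketch consistent with the documented defaults, not the scripts' actual parser:

```python
import argparse

EXPERIMENTS = ["full", "ablation1", "ablation2", "ablation3"]

def build_parser():
    """Sketch of a student-training CLI matching the key-arguments table."""
    p = argparse.ArgumentParser(description="VRL-LGDNet student training (sketch)")
    p.add_argument("--data-csv", required=True, help="Path to processed data CSV")
    p.add_argument("--teacher-path", default=None,
                   help="Teacher .pth; attention maps are zeroed if omitted")
    p.add_argument("--output-dir", default="./output")
    p.add_argument("--experiments", nargs="+", default=EXPERIMENTS,
                   choices=EXPERIMENTS)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--selection-metric", default="image_f1",
                   choices=["image_f1", "eye_f1"])
    p.add_argument("--tag", default="")
    return p
```

For example, `--experiments full ablation2` parses into the list `["full", "ablation2"]`, which is how a subset of the four experiments would be selected.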

## Output structure

```text
output_dir/
  checkpoints/
    full/best.pth
    ablation1/best.pth
    ablation2/best.pth
    ablation3/best.pth
  teacher_cache/
    teacher_train.csv
    teacher_val.csv
    teacher_test.csv
    attn_maps/
      <image_stem>_attn.npy
  history/
    history_full.csv
    history_ablation1.csv
    ...
  report.md
```

## Visualization

Generate ROC curves and Grad-CAM++ heatmaps after training:

```bash
python generate_visualizations.py \
    --data-csv          /path/to/data_processed.csv \
    --teacher-cache-dir ./output_vrl_vs_control/teacher_cache \
    --ckpt-dir          ./output_vrl_vs_control/checkpoints \
    --output-dir        ./visualizations
```

Skip heatmaps and generate only ROC curves:

```bash
python generate_visualizations.py \
    --data-csv /path/to/data_processed.csv \
    --teacher-cache-dir ./output/teacher_cache \
    --ckpt-dir ./output/checkpoints \
    --output-dir ./vis \
    --skip-gradcam
```

## Model architecture

The `models/lgdnet.py` module exports:

- `FullModel` — full proposed model (4-ch + tabular + distillation)
- `MODEL_REGISTRY` — dict mapping `{name: class}`
- `OCTDataset` — unified dataset for all four modes
- `run_epoch`, `eval_eye_level`, `compute_metrics` — training utilities
- `load_teacher`, `cache_teacher_outputs` — teacher inference helpers
- `set_seed`, `safe_float` — utility functions

### Backbone

`nf_resnet50` from the timm library (a normalization-free ResNet-50). Input size: 512×512 pixels.

### Teacher model

Faster R-CNN with a ResNet-50 FPN backbone (`torchvision.models.detection.fasterrcnn_resnet50_fpn`), loaded via `load_teacher(path, n_classes=4)`, where `n_classes=4` covers the three lesion classes plus background (torchvision's detection models count background in the class total).

The teacher produces per-image class activation outputs (the softmax of the detection-head scores), which are cached as `.npy` attention maps and serve as the distillation target during student training.
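Stacking a cached attention map onto an RGB B-scan to form the student's 4-channel input might look like the following NumPy sketch (the array shapes, the `.npy` layout, and the `make_four_channel` helper are assumptions for illustration):

```python
import numpy as np

def make_four_channel(rgb, attn):
    """Concatenate an HxWx3 RGB image with an HxW attention map into HxWx4."""
    if attn.shape != rgb.shape[:2]:
        raise ValueError("attention map must match image height/width")
    return np.concatenate([rgb, attn[..., None]], axis=-1)

# rgb  = B-scan loaded as a float array of shape (512, 512, 3)
# attn = np.load("attn_maps/<image_stem>_attn.npy")  # assumed shape (512, 512)
# x = make_four_channel(rgb, attn)                   # shape (512, 512, 4)
```

When `--teacher-path` is omitted, zeroing the attention channel (an all-zero `attn`) reduces the fourth channel to a constant, matching the behavior described for the ablations.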

## Hyperparameters

| Hyperparameter | Value |
|---|---|
| Learning rate | 0.0001 |
| Weight decay | 0.0064 |
| Distillation α | 1.1114 |
| Batch size | 8 |
| Image size | 512×512 |
| Max epochs | 50 |
| Early stop patience | 6 |
| Seed | 42 |
| Teacher threshold | 0.5 |
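The max-epochs/patience pair implies early stopping on a validation metric: training runs for at most 50 epochs but halts once the metric fails to improve for 6 consecutive epochs. A minimal sketch of that loop control (a hypothetical helper, not the repository's code):

```python
class EarlyStopper:
    """Stop when the monitored metric fails to improve for `patience` epochs."""

    def __init__(self, patience=6):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's validation metric; return True to stop training."""
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

A higher-is-better metric such as image-level F1 (the default `--selection-metric`) fits this formulation directly; the epoch that sets `self.best` is the one whose weights would be saved as `best.pth`.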

## Citation

If you use VRL-LGDNet in your research, please cite:

```bibtex
@article{vrl_lgdnet_2025,
  title   = {Artificial intelligence-assisted screening of vitreoretinal lymphoma based on optical coherence tomography},
  author  = {Zhou, Suowang and others},
  journal = {To be submitted},
  year    = {2026},
}
```

## License

This project is released under the MIT License.
