This is a PyTorch implementation of the paper "Online Multi-Modal Spatio-Temporal Prediction: A Reinforcement Learning and Dynamic Contrastive Framework" (ROMST).
The code is tested with Python 3.10+ and PyTorch (CUDA recommended).
Install the dependencies:

```bash
pip install -r requirements.txt
```

Logging is automatically saved to `checkpoint/log/` with timestamped filenames.
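Before launching training, you can confirm that the installed PyTorch build actually sees a CUDA device with a quick check:

```python
import torch

# Prints the PyTorch version and whether a CUDA device is visible.
print(torch.__version__, "| CUDA available:", torch.cuda.is_available())
```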
The script supports multiple datasets. Prepare the following directories and files (a shape sanity-check sketch follows the download links below).
- BjTT
  - Structure: `mydatasets/BjTT/data/{1,2,3}/*.npy` and `mydatasets/BjTT/text/{1,2,3}/*.txt`
  - Prior matrix: `mydatasets/prior_matrix/BjTT_matrix_prior.npy`
- Terra
  - Time series: `mydatasets/Terra/time_series/wind_daily.npy`
  - Images: `mydatasets/Terra/image/relief_{lat}{N|S}_{lon}{E|W}.png`
  - Texts: `mydatasets/Terra/texts/meta_{lat}{N|S}_{lon}{E|W}.txt`
  - Prior matrix: `mydatasets/prior_matrix/Terra_matrix_prior.npy`
- PEMS04
  - Time series: `mydatasets/PEMS04/data.npy`
  - Prior matrix: `mydatasets/prior_matrix/PEMS04_matrix_prior.npy`
- GreenEarthNet
  - Temporal: `mydatasets/GreenEarthNet/time_series.npy`
  - Image RGB sequence: `mydatasets/GreenEarthNet/image_rgb.npy`
  - Prior matrix: `mydatasets/prior_matrix/GreenEarthNet_matrix_prior.npy`
Create folders as needed:
```bash
mkdir -p mydatasets/prior_matrix
mkdir -p mydatasets/BjTT
mkdir -p mydatasets/Terra/{time_series,image,texts}
mkdir -p mydatasets/PEMS04
mkdir -p mydatasets/GreenEarthNet
```

Note: because some of the datasets are very large, only partial data is provided in this repository. The complete datasets are available from the official sources:
- BjTT: https://github.com/ChyaZhang/BjTT
- Terra: https://github.com/CityMind-Lab/NeurIPS24-Terra
- PEMS04: https://github.com/Davidham3/ASTGCN
- GreenEarthNet: https://github.com/vitusbenson/greenearthnet
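After downloading, a quick way to verify the layout is to check that each prior matrix exists and matches the node count used in the training commands below. A minimal sketch, assuming the prior matrices are square `num_nodes x num_nodes` arrays (this shape is an assumption, not stated above):

```python
import os
import numpy as np

# Node counts taken from the training commands below; adjust if your data differs.
EXPECTED_NODES = {"BjTT": 1260, "Terra": 100, "PEMS04": 307, "GreenEarthNet": 1024}

for name, n in EXPECTED_NODES.items():
    path = os.path.join("mydatasets", "prior_matrix", f"{name}_matrix_prior.npy")
    if not os.path.exists(path):
        print(f"[missing] {path}")
        continue
    mat = np.load(path)
    status = "ok" if mat.shape == (n, n) else "unexpected shape"
    print(f"[{status}] {path}: {mat.shape}, expected ({n}, {n})")
```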
Run the two-stage pipeline with `train.py`. Key flags are shown below; adjust the paths if your data is stored elsewhere.
- BjTT

  ```bash
  python train.py \
    --dataset BjTT \
    --data_dir mydatasets/BjTT \
    --prior_matrix_path mydatasets/prior_matrix \
    --seq_len 12 --pred_len 12 \
    --num_nodes 1260 \
    --use_text True --use_image False \
    --pruning_ratio 0.3 --contrastive_weight 0.0001 \
    --batch_size 32 --epochs 100 \
    --device cuda --save_dir ./checkpoint
  ```

- Terra

  ```bash
  python train.py \
    --dataset Terra \
    --data_dir mydatasets/Terra \
    --time_series_path mydatasets/Terra/time_series/wind_daily.npy \
    --image_dir mydatasets/Terra/image \
    --text_dir mydatasets/Terra/texts \
    --prior_matrix_path mydatasets/prior_matrix \
    --seq_len 12 --pred_len 12 \
    --num_nodes 100 \
    --use_text True --use_image True \
    --pruning_ratio 0.3 --contrastive_weight 0.6 \
    --batch_size 16 --epochs 100 \
    --device cuda --save_dir ./checkpoint
  ```

- PEMS04

  ```bash
  python train.py \
    --dataset PEMS04 \
    --data_dir mydatasets/PEMS04 \
    --prior_matrix_path mydatasets/prior_matrix \
    --seq_len 12 --pred_len 12 \
    --num_nodes 307 \
    --use_image False --use_text False \
    --contrastive_weight 1 \
    --batch_size 32 --epochs 100 \
    --device cuda --save_dir ./checkpoint
  ```

- GreenEarthNet

  ```bash
  python train.py \
    --dataset GreenEarthNet \
    --data_dir mydatasets/GreenEarthNet \
    --prior_matrix_path mydatasets/prior_matrix \
    --seq_len 12 --pred_len 12 \
    --num_nodes 1024 \
    --use_image True --use_text False \
    --contrastive_weight 0.001 \
    --batch_size 32 --epochs 100 \
    --device cuda --save_dir ./checkpoint
  ```
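To sweep the pruning ratio and contrastive weight (e.g., for ablations), the same command can be launched programmatically. A sketch for BjTT with illustrative grid values; these are examples, not settings recommended by the authors:

```python
import subprocess

# Illustrative hyperparameter grid; values are examples only.
for p in (0.1, 0.3, 0.5):
    for c in (1e-4, 1e-3):
        subprocess.run(
            ["python", "train.py",
             "--dataset", "BjTT",
             "--data_dir", "mydatasets/BjTT",
             "--prior_matrix_path", "mydatasets/prior_matrix",
             "--seq_len", "12", "--pred_len", "12",
             "--num_nodes", "1260",
             "--use_text", "True", "--use_image", "False",
             "--pruning_ratio", str(p), "--contrastive_weight", str(c),
             "--batch_size", "32", "--epochs", "100",
             "--device", "cuda", "--save_dir", "./checkpoint"],
            check=True,
        )
```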
Notes

- The script prunes the text encoder once at startup when `--use_text` is enabled; `--pruning_ratio` accepts values in [0.0, 1.0] (a generic pruning sketch follows this list).
- Continual learning segments are controlled internally; you can change the segment index via `--index` and rerun if needed.
- Logs: `checkpoint/log/training_log_{dataset}_p{pruning_ratio}_c{contrastive_weight}_{YYYYMMDD_HHMMSS}.log`
- Best checkpoints: `checkpoint/{dataset}/best_spatiotemporal_model_*.pth` and `checkpoint/{dataset}/best_multimodal_model_*.pth`
- Segment summary: `checkpoint/results/two_stage_segments_{dataset}_aug_{contrastive_weight}_prune_{pruning_ratio}.csv`
- Optional visualizations (enable with `--enable_reward_viz`):
  - Rewards: `checkpoint/reward/reward_trend_{dataset}_*.csv|.png`
  - Modal weights: `checkpoint/weights/weights_trend_{dataset}_*.csv|.png|.txt`
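For intuition on what the pruning ratio controls, the sketch below applies magnitude-based pruning to a toy text encoder with `torch.nn.utils.prune`. This is a generic illustration with an assumed module layout, not necessarily the exact mechanism `train.py` uses:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy stand-in for a text encoder; the real encoder in this repo will differ.
encoder = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 64))

pruning_ratio = 0.3  # fraction of weights zeroed out, analogous to --pruning_ratio
for module in encoder.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=pruning_ratio)
        prune.remove(module, "weight")  # make the zeroed weights permanent

total = sum(p.numel() for p in encoder.parameters() if p.dim() > 1)
zeros = sum((p == 0).sum().item() for p in encoder.parameters() if p.dim() > 1)
print(f"weight sparsity: {zeros / total:.2%}")
```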
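To locate the newest log and the segment summaries for a finished run, the filename patterns above can be globbed directly; a minimal sketch (the dataset name is an example):

```python
import glob
import os

dataset = "PEMS04"  # example; use the dataset you trained

# Newest training log, based on the naming scheme listed above.
logs = sorted(glob.glob(f"checkpoint/log/training_log_{dataset}_*.log"), key=os.path.getmtime)
print("latest log:", logs[-1] if logs else "none found")

# All segment summary CSVs for this dataset.
for path in glob.glob(f"checkpoint/results/two_stage_segments_{dataset}_*.csv"):
    print("segment summary:", path)
```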