# FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction

A PyTorch implementation for the paper: [FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction]<br />

[Zhonghang Li](https://scholar.google.com/citations?user=__9uvQkAAAAJ), [Lianghao Xia](https://akaxlh.github.io/), [Yong Xu](https://scholar.google.com/citations?user=1hx5iwEAAAAJ), [Chao Huang](https://sites.google.com/view/chaoh)* (*Correspondence)<br />

**[Data Intelligence Lab](https://sites.google.com/view/chaoh/home)@[University of Hong Kong](https://www.hku.hk/)**, [South China University of Technology](https://www.scut.edu.cn/en/), PAZHOU LAB

-----------

## Introduction

<p style="text-align: justify">
In this work, we introduce a simple and universal spatio-temporal prompt-tuning framework that addresses the significant challenge posed by distribution shift in spatio-temporal prediction.
To this end, we present FlashST, a framework that adapts pretrained models to the specific characteristics of diverse downstream datasets, thereby improving generalization across various prediction scenarios.
We begin by utilizing a lightweight spatio-temporal prompt network for in-context learning, capturing spatio-temporal invariant knowledge and facilitating effective adaptation to diverse scenarios. Additionally, we incorporate a distribution mapping mechanism that aligns the data distributions of the pre-training and downstream data, enabling effective knowledge transfer in spatio-temporal forecasting. Empirical evaluations demonstrate the effectiveness of FlashST across a range of spatio-temporal prediction tasks.

</p>

![The detailed framework of the proposed FlashST.](https://github.com/LZH-YS1998/GPT-ST_img/blob/main/FlashST.png)
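
The following is a minimal PyTorch sketch of the prompt-tuning idea described above: a lightweight prompt network produces spatio-temporal prompt embeddings from the input window, and only this network is trained while the pretrained backbone stays frozen. All class and parameter names (`SimplePromptNet`, `prompt_dim`, the concatenation interface) are illustrative assumptions, not the repository's actual `PromptNet.py`/`FlashST.py` API.

```python
import torch
import torch.nn as nn

class SimplePromptNet(nn.Module):
    """Illustrative lightweight prompt network: combines a projection of the
    input window with learnable per-node embeddings."""
    def __init__(self, in_dim, prompt_dim, num_nodes):
        super().__init__()
        self.node_emb = nn.Parameter(torch.randn(num_nodes, prompt_dim) * 0.01)
        self.proj = nn.Linear(in_dim, prompt_dim)

    def forward(self, x):
        # x: (batch, time, num_nodes, in_dim)
        ctx = self.proj(x)          # temporal context, (B, T, N, prompt_dim)
        return ctx + self.node_emb  # broadcast spatial prompts over batch/time

class PromptTunedModel(nn.Module):
    """Frozen pretrained backbone + trainable prompt network."""
    def __init__(self, backbone, prompt_net):
        super().__init__()
        self.backbone = backbone
        for p in self.backbone.parameters():
            p.requires_grad = False  # only the prompt network is tuned
        self.prompt_net = prompt_net

    def forward(self, x):
        prompt = self.prompt_net(x)
        # Feed the input together with the prompts to the backbone; the
        # backbone is assumed to accept the widened feature dimension.
        return self.backbone(torch.cat([x, prompt], dim=-1))
```

During prompt-tuning, gradients flow only into the prompt network, which is what keeps adaptation to a new downstream dataset lightweight.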

-----------
<span id='Usage'/>



## Getting Started

<span id='all_catelogue'/>

### Table of Contents:
* <a href='#Code Structure'>1. Code Structure</a>
* <a href='#Environment'>2. Environment </a>
* <a href='#Run the codes'>3. Run the codes </a>

****


<span id='Code Structure'/>

### 1. Code Structure <a href='#all_catelogue'>[Back to Top]</a>


* **conf**: This folder includes the parameter settings for FlashST (`config.conf`) as well as for all the baseline models.
* **data**: This folder contains all the datasets used in our work, along with the prefabricated files and the corresponding file-generation code required by certain baselines.
* **lib**: Contains a series of initialization methods for data processing, as follows:
  * `data_process.py`: Data loading, splitting, and generation; normalization; slicing; etc.
  * `logger.py`: Output printing.
  * `metrics.py`: Methods for computing the evaluation metrics.
  * `predifineGraph.py`: Predefined graph generation.
  * `TrainInits.py`: Training initialization, including the optimizer, device, random seed, and other settings.
* **model**: Includes the implementation of FlashST and all baseline models, along with the code needed to run the framework. The `args.py` script generates the prefabricated data and parameter configurations required by the different baselines.
* **SAVE**: This folder serves as the storage location for the trained models, covering the pretrain, eval, and ori modes.


```
│ README.md
│ requirements.txt
├─conf
│ ├─AGCRN
│ ├─ASTGCN
│ ├─FlashST
│ │ │ config.conf
│ │ │ Params_pretrain.py
│ ├─GWN
│ ├─MSDR
│ ├─MTGNN
│ ├─PDFormer
│ ├─ST-WA
│ ├─STFGNN
│ ├─STGCN
│ ├─STSGCN
│ └─TGCN
├─data
│ ├─CA_District5
│ ├─chengdu_didi
│ ├─NYC_BIKE
│ ├─PEMS03
│ ├─PEMS04
│ ├─PEMS07
│ ├─PEMS07M
│ ├─PEMS08
│ ├─PDFormer
│ ├─STFGNN
│ └─STGODE
├─lib
│ │ data_process.py
│ │ logger.py
│ │ metrics.py
│ │ predifineGraph.py
│ │ TrainInits.py
├─model
│ │ FlashST.py
│ │ PromptNet.py
│ │ Run.py
│ │ Trainer.py
│ │
│ ├─AGCRN
│ ├─ASTGCN
│ ├─DMSTGCN
│ ├─GWN
│ ├─MSDR
│ ├─MTGNN
│ ├─PDFormer
│ ├─STFGNN
│ ├─STGCN
│ ├─STGODE
│ ├─STSGCN
│ ├─ST_WA
│ └─TGCN
└─SAVE
└─pretrain
├─GWN
│ GWN_P8437.pth
├─MTGNN
│ MTGNN_P8437.pth
├─PDFormer
│ PDFormer_P8437.pth
└─STGCN
STGCN_P8437.pth
```

---------

<span id='Environment'/>

### 2. Environment <a href='#all_catelogue'>[Back to Top]</a>
The code can be run in the following environment; other versions of the required packages may also work.
* python==3.9.12
* numpy==1.23.1
* pytorch==1.9.0
* cudatoolkit==11.1.1

Alternatively, you can set up the required environment by running the following commands:
```
# create new environment
conda create -n FlashST python=3.9.12
# activate environment
conda activate FlashST
# Torch with CUDA 11.1
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
# Install required libraries
pip install -r requirements.txt
```

---------

<span id='Run the codes'/>

### 3. Run the codes <a href='#all_catelogue'>[Back to Top]</a>

* First, download the `data` folder from [Hugging Face (data.zip)](https://huggingface.co/datasets/bjdwh/FlashST-DATA/tree/main) or [wisemodel (data.zip)](https://wisemodel.cn/datasets/BJDWH/FlashST-data/file), place it in the FlashST-main directory, unzip it, and then enter the `model` folder:
```
cd model
```
* To test different models in various modes, you can execute `Run.py`. Here are some examples:
```
# Evaluate the performance of MTGNN enhanced by FlashST on the PEMS07M dataset
python Run.py -dataset_test PEMS07M -mode eval -model MTGNN
# Evaluate the performance of STGCN enhanced by FlashST on the CA_District5 dataset
python Run.py -dataset_test CA_District5 -mode eval -model STGCN
# Evaluate the original performance of STGCN on the chengdu_didi dataset
python Run.py -dataset_test chengdu_didi -mode ori -model STGCN
# Pretrain from scratch with the MTGNN model; the checkpoint will be saved in FlashST-main/SAVE/pretrain/MTGNN (model name)/xxx.pth
python Run.py -mode pretrain -model MTGNN
```

* Parameter setting instructions. The parameter settings consist of two parts: the FlashST framework and the baseline model. To avoid confusion from potentially overlapping parameter names, we use a single hyphen (-) to specify FlashST parameters and a double hyphen (--) to specify baseline parameters. Here is an example, followed by a sketch of how such two-pass parsing can be implemented:
```
# Set first_layer_embedding_size and out_layer_dim to 32 in STFGNN
python Run.py -model STFGNN -mode ori -dataset_test PEMS08 --first_layer_embedding_size 32 --out_layer_dim 32
```
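
One way to realize this single-hyphen/double-hyphen split is two `argparse` passes: a FlashST parser consumes its own options and hands the leftovers to a per-baseline parser. This is a hedged sketch of the mechanism, not necessarily how `model/args.py` implements it; the option names and default values below are illustrative.

```python
import argparse

def parse_all(argv=None):
    # Pass 1: FlashST-level options use a single hyphen.
    flash = argparse.ArgumentParser(add_help=False)
    flash.add_argument('-model', default='STGCN')
    flash.add_argument('-mode', default='eval', choices=['pretrain', 'eval', 'ori'])
    flash.add_argument('-dataset_test', default='PEMS08')
    flash_args, remaining = flash.parse_known_args(argv)

    # Pass 2: baseline-level options use a double hyphen; in practice the
    # option set would be chosen per baseline model.
    baseline = argparse.ArgumentParser(add_help=False)
    baseline.add_argument('--first_layer_embedding_size', type=int, default=64)
    baseline.add_argument('--out_layer_dim', type=int, default=64)
    baseline_args, _ = baseline.parse_known_args(remaining)
    return flash_args, baseline_args

if __name__ == '__main__':
    f, b = parse_all(['-model', 'STFGNN', '-mode', 'ori',
                      '-dataset_test', 'PEMS08',
                      '--first_layer_embedding_size', '32',
                      '--out_layer_dim', '32'])
    print(f.model, f.mode, f.dataset_test, b.first_layer_embedding_size)
```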

---------



## Acknowledgements
We developed our code framework drawing inspiration from [AGCRN](https://github.com/LeiBAI/AGCRN) and [GPT-ST](https://github.com/HKUDS/GPT-ST). Furthermore, the implementation of the baselines primarily relies on a combination of the code released by the original authors and the code from [LibCity](https://github.com/LibCity/Bigscity-LibCity). We extend our heartfelt gratitude for their remarkable contributions.
---------

## `lib/TrainInits.py`
```python
import torch
import random
import numpy as np

def init_seed(seed, seed_mode):
    '''
    Fix all random seeds and disable cuDNN to maximize reproducibility.
    '''
    if seed_mode:
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.deterministic = True
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)

def init_device(opt):
    '''
    Select the compute device; opt.device is expected to be of the
    form 'cuda:X' with a single-digit GPU index.
    '''
    if torch.cuda.is_available():
        opt.cuda = True
        torch.cuda.set_device(int(opt.device[5]))  # character 5 is the GPU index
    else:
        opt.cuda = False
        opt.device = 'cpu'
    return opt

def init_optim(model, opt):
    '''
    Initialize the Adam optimizer.
    '''
    return torch.optim.Adam(params=model.parameters(), lr=opt.lr_init)

def init_lr_scheduler(optim, opt):
    '''
    Initialize the multi-step learning rate scheduler.
    '''
    return torch.optim.lr_scheduler.MultiStepLR(optimizer=optim,
                                                milestones=opt.lr_decay_steps,
                                                gamma=opt.lr_scheduler_rate)

def print_model_parameters(model, only_num=True):
    '''
    Print the total parameter count and, optionally, every named parameter.
    '''
    print('*****************Model Parameter*****************')
    if not only_num:
        for name, param in model.named_parameters():
            print(name, param.shape, param.requires_grad)
    total_num = sum(param.nelement() for param in model.parameters())
    print('Total params num: {}'.format(total_num))
    print('*****************Finish Parameter****************')

def get_memory_usage(device):
    '''
    Report allocated and reserved GPU memory in MB.
    '''
    # torch.cuda.memory_cached was deprecated in favor of memory_reserved
    allocated_memory = torch.cuda.memory_allocated(device) / (1024 * 1024.)
    cached_memory = torch.cuda.memory_reserved(device) / (1024 * 1024.)
    print('Allocated Memory: {:.2f} MB, Cached Memory: {:.2f} MB'.format(allocated_memory, cached_memory))
    return allocated_memory, cached_memory
```
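
A minimal usage sketch of these helpers: the `opt` fields (`device`, `lr_init`, `lr_decay_steps`, `lr_scheduler_rate`) mirror what the function bodies above expect, while the concrete values are illustrative.

```python
from argparse import Namespace
import torch.nn as nn
from lib.TrainInits import (init_seed, init_device, init_optim,
                            init_lr_scheduler, print_model_parameters)

# Hypothetical option values; field names follow the function bodies above.
opt = Namespace(device='cuda:0', lr_init=1e-3,
                lr_decay_steps=[20, 40], lr_scheduler_rate=0.5)

init_seed(seed=2024, seed_mode=True)   # deterministic, reproducible run
opt = init_device(opt)                 # falls back to CPU if no GPU is found

model = nn.Linear(8, 1).to(opt.device)
optimizer = init_optim(model, opt)
scheduler = init_lr_scheduler(optimizer, opt)
print_model_parameters(model)
```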