100 changes: 100 additions & 0 deletions OATS/README.md
@@ -0,0 +1,100 @@
# OATS: Online Data Augmentation for Time Series Foundation Models

This repository contains the code for the paper "OATS: Online Data Augmentation for Time Series Foundation Models".

## Overview

OATS introduces a novel online data augmentation framework designed to enhance the training of time series foundation models (TSFMs). Unlike traditional offline augmentation methods that pre-generate synthetic data, OATS generates synthetic data during training, using training samples with high data attribution scores as guiding signals.

OATS consists of three key components:
- Time-series Influence Scores (TSIS) integrate data attribution with time series-specific knowledge to dynamically assess the quality of each training sample, creating a generation guiding signal.
- High-quality Guided Data Augmentation leverages the guiding signal to condition a diffusion model trained on a small subset of the TSFM training data for synthetic data generation.
- Explore-Exploit Mechanism reduces computational overhead by balancing the reuse of previously calculated scores against the scoring of new samples. The influence scores are stochastically re-evaluated to incorporate model training dynamics ("explore") while preserving previously identified high-quality data ("exploit").
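
The explore-exploit idea described in the list above can be illustrated with a minimal sketch. This is not the repository's implementation: the function names, the `explore_prob` parameter, and the top-fraction selection rule are all hypothetical and only show how cached scores might be reused while a random share is recomputed.

```python
import random


def update_scores(cached_scores, batch_ids, compute_score, explore_prob, rng):
    """Return a score per sample, recomputing some of them at random.

    cached_scores is mutated in place. compute_score is a hypothetical
    callable standing in for an influence-score computation.
    """
    scores = {}
    for sample_id in batch_ids:
        is_new = sample_id not in cached_scores
        if is_new or rng.random() < explore_prob:
            # Explore: (re)compute the score against the current model state.
            cached_scores[sample_id] = compute_score(sample_id)
        # Exploit: otherwise the previously cached score is reused.
        scores[sample_id] = cached_scores[sample_id]
    return scores


def select_guides(scores, keep_ratio):
    """Keep the top keep_ratio fraction of samples, highest score first."""
    k = max(1, int(len(scores) * keep_ratio))
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:k]
```

With `explore_prob=0.0` every known sample reuses its cached score, and with `explore_prob=1.0` every sample is rescored on each call; values in between trade compute for freshness.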


![Method](assets/method.png)

## Environment and Dataset

### Dataset preparation

#### TSFM pretrain dataset
Download the TSFM pretraining dataset from [here](https://huggingface.co/datasets/Qingren/TSFM-ScalingLaws-Dataset). The directory structure is as follows:

```
- dataset_train
  |- Lotsa16B
  |- Lotsa1B
  |- Lotsa100M
  |- Lotsa10M
- dataset_test
  |- Lotsa16B
  |- Lotsa1B
  |- Lotsa100M
  |- Lotsa10M
  |- LSF
  |- Monash
```
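
As a convenience, the layout above can be checked after downloading. The following is a minimal sketch, not part of the repository; the `missing_dirs` helper and the `EXPECTED_LAYOUT` table are hypothetical names that simply mirror the tree shown above.

```python
from pathlib import Path

# Mirrors the directory tree listed in this README.
EXPECTED_LAYOUT = {
    "dataset_train": ["Lotsa16B", "Lotsa1B", "Lotsa100M", "Lotsa10M"],
    "dataset_test": ["Lotsa16B", "Lotsa1B", "Lotsa100M", "Lotsa10M", "LSF", "Monash"],
}


def missing_dirs(root):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    return [
        f"{top}/{sub}"
        for top, subs in EXPECTED_LAYOUT.items()
        for sub in subs
        if not (root / top / sub).is_dir()
    ]
```

Calling `missing_dirs(".")` from the directory that holds `dataset_train` and `dataset_test` returns an empty list when everything is in place. If only some of the dataset scales were downloaded, the remaining entries will be reported as missing.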

#### Generation model training data
Extract the training data for the diffusion model. The data is extracted from the Lotsa100M dataset at a 5% sampling rate across 20 selected subdatasets.

```bash
python extract_data_generation.py -cp cli/conf/pretrain \
  -cn default_ddp_val_enc \
  model=encoder_10M \
  model.enable_influence_scoring=true \
  data=lotsa100M_weighted \
  trainer.max_epochs=0 \
  model.num_warmup_steps=0
```
The extracted files are organized as follows:
```
extracted_label_patches_australian_electricity_demand.npy
extracted_label_patches_azure_vm_traces_2017.npy
extracted_label_patches_buildings_900k.npy
extracted_label_patches_CloudOpsTSF_dataset.npy
extracted_label_patches_CMIP6_dataset.npy
...
```
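
To iterate over the extracted files by sub-dataset, something like the following could be used. It is a minimal sketch based only on the file naming pattern shown above; the `index_extracted_files` helper is a hypothetical name and is not part of the repository.

```python
from pathlib import Path

# Common filename prefix taken from the listing above.
PREFIX = "extracted_label_patches_"


def index_extracted_files(directory):
    """Map each sub-dataset name to the path of its extracted .npy file."""
    index = {}
    for path in sorted(Path(directory).glob(f"{PREFIX}*.npy")):
        # Strip the shared prefix; the .npy suffix is already excluded by .stem.
        name = path.stem[len(PREFIX):]
        index[name] = path
    return index
```

Each returned path can then be opened with `numpy.load`. The array shapes are not documented here, so inspect them before feeding the data to the generation model.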

### Environment setting

```bash
# Clone the repository
git clone https://github.com/microsoft/TimeCraft.git
cd TimeCraft/OATS

# Create and activate conda environment
conda env create -f environment.yml
conda activate oats
```


## Quick Start
Step 1. Train a time series generation model with the extracted sampled data.
```bash
cd models/gen_model

python main_train.py --base configs/multi_domain_timedp_local.yaml --gpus 0, --logdir ./logs/ -sl 320 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
```

Step 2. Train the time series foundation model.
```bash
python -m cli.train_val \
  -cp conf/pretrain \
  -cn default_ddp_val_enc \
  model=encoder \
  model.enable_influence_scoring=true \
  data=lotsa100M_weighted \
  val_data=all \
  trainer.logger.project=TSFM_PRETRAIN \
  run_name=encoder10M_etth1_develop \
  model.generate_after_epoch=0 \
  model.influence_filter_ratio=1.0 \
  model.select_from_generated=false
```

Outputs: Results are logged to wandb and saved under `./outputs/pretrain/`.


Binary file added OATS/assets/method.png
174 changes: 174 additions & 0 deletions OATS/environment.yml
@@ -0,0 +1,174 @@
name: oats
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=3.0.0=pyhd8ed1ab_1
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2025.8.3=hbd8a1cb_0
- comm=0.2.3=pyhe01879c_0
- debugpy=1.8.11=py310h6a678d5_0
- decorator=5.2.1=pyhd8ed1ab_0
- exceptiongroup=1.3.0=pyhd8ed1ab_0
- executing=2.2.1=pyhd8ed1ab_0
- expat=2.7.1=h6a678d5_0
- importlib-metadata=8.7.0=pyhe01879c_1
- ipykernel=6.30.1=pyh82676e8_0
- ipython=8.37.0=pyh8f84b5b_0
- jedi=0.19.2=pyhd8ed1ab_1
- jupyter_client=8.6.3=pyhd8ed1ab_1
- jupyter_core=5.8.1=pyh31011fe_0
- ld_impl_linux-64=2.40=h12ee557_0
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- libxcb=1.17.0=h9b100fa_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
- ncurses=6.5=h7934f7d_0
- nest-asyncio=1.6.0=pyhd8ed1ab_1
- openssl=3.0.17=h5eee18b_0
- packaging=25.0=pyh29332c3_1
- parso=0.8.5=pyhcf101f3_0
- pexpect=4.9.0=pyhd8ed1ab_1
- pickleshare=0.7.5=pyhd8ed1ab_1004
- pip=25.1=pyhc872135_2
- prompt-toolkit=3.0.52=pyha770c72_0
- psutil=5.9.0=py310h5eee18b_1
- pthread-stubs=0.3=h0ce48e5_1
- ptyprocess=0.7.0=pyhd8ed1ab_1
- pure_eval=0.2.3=pyhd8ed1ab_1
- pygments=2.19.2=pyhd8ed1ab_0
- python=3.10.18=h1a3bd86_0
- python-dateutil=2.9.0.post0=pyhe01879c_2
- pyzmq=26.2.0=py310h6a678d5_0
- readline=8.3=hc2a1206_0
- setuptools=78.1.1=py310h06a4308_0
- six=1.17.0=pyhe01879c_1
- sqlite=3.50.2=hb25bd0a_1
- stack_data=0.6.3=pyhd8ed1ab_1
- tk=8.6.15=h54e0aa7_0
- tornado=6.5.1=py310h5eee18b_0
- traitlets=5.14.3=pyhd8ed1ab_1
- typing_extensions=4.15.0=pyhcf101f3_0
- wcwidth=0.2.13=pyhd8ed1ab_1
- wheel=0.45.1=py310h06a4308_0
- xorg-libx11=1.8.12=h9b100fa_1
- xorg-libxau=1.0.12=h9b100fa_0
- xorg-libxdmcp=1.1.5=h9b100fa_0
- xorg-xorgproto=2024.1=h5eee18b_1
- xz=5.6.4=h5eee18b_1
- zeromq=4.3.5=h6a678d5_0
- zipp=3.23.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_1
- pip:
  - absl-py==2.3.1
  - aiohappyeyeballs==2.6.1
  - aiohttp==3.12.15
  - aiosignal==1.4.0
  - annotated-types==0.7.0
  - antlr4-python3-runtime==4.9.3
  - async-timeout==5.0.1
  - attrs==25.3.0
  - certifi==2025.8.3
  - charset-normalizer==3.4.3
  - click==8.2.1
  - contourpy==1.3.2
  - cycler==0.12.1
  - datasets==2.17.1
  - dattri==0.2.0
  - dill==0.3.8
  - einops==0.7.0
  - fast-jl==0.1.3
  - filelock==3.19.1
  - fonttools==4.59.2
  - frozenlist==1.7.0
  - fsspec==2023.10.0
  - gitdb==4.0.12
  - gitpython==3.1.45
  - gluonts==0.14.4
  - grpcio==1.74.0
  - hf-xet==1.1.8
  - huggingface-hub==0.34.4
  - hydra-core==1.3.0
  - idna==3.10
  - jax==0.6.1
  - jaxlib==0.6.1
  - jaxtyping==0.2.38
  - jinja2==3.1.6
  - joblib==1.5.2
  - kiwisolver==1.4.9
  - lightning==2.5.3
  - lightning-utilities==0.15.2
  - markdown==3.8.2
  - markupsafe==3.0.2
  - matplotlib==3.10.5
  - mido==1.3.3
  - ml-dtypes==0.5.3
  - mpmath==1.3.0
  - multidict==6.6.4
  - multiprocess==0.70.16
  - networkx==3.4.2
  - numpy==1.26.4
  - nvidia-cublas-cu12==12.1.3.1
  - nvidia-cuda-cupti-cu12==12.1.105
  - nvidia-cuda-nvrtc-cu12==12.1.105
  - nvidia-cuda-runtime-cu12==12.1.105
  - nvidia-cudnn-cu12==9.1.0.70
  - nvidia-cufft-cu12==11.0.2.54
  - nvidia-curand-cu12==10.3.2.106
  - nvidia-cusolver-cu12==11.4.5.107
  - nvidia-cusparse-cu12==12.1.0.106
  - nvidia-nccl-cu12==2.20.5
  - nvidia-nvjitlink-cu12==12.9.86
  - nvidia-nvtx-cu12==12.1.105
  - omegaconf==2.3.0
  - opt-einsum==3.4.0
  - orjson==3.11.2
  - pandas==2.1.4
  - patsy==1.0.1
  - pillow==11.3.0
  - platformdirs==4.3.8
  - pretty-midi==0.2.10
  - propcache==0.3.2
  - protobuf==6.32.0
  - pyarrow==21.0.0
  - pyarrow-hotfix==0.7
  - pydantic==2.11.7
  - pydantic-core==2.33.2
  - pyparsing==3.2.3
  - python-dotenv==1.0.0
  - pytorch-lightning==2.5.3
  - pytz==2025.2
  - pyyaml==6.0.2
  - requests==2.32.5
  - safetensors==0.6.2
  - scikit-learn==1.7.1
  - scipy==1.11.4
  - sentry-sdk==2.35.0
  - smmap==5.0.2
  - statsmodels==0.14.5
  - sympy==1.14.0
  - tensorboard==2.20.0
  - tensorboard-data-server==0.7.2
  - threadpoolctl==3.6.0
  - toolz==0.12.1
  - torch==2.4.1
  - torchaudio==2.4.1
  - torchmetrics==1.8.1
  - torchvision==0.19.1
  - tqdm==4.67.1
  - triton==3.0.0
  - typing-extensions==4.14.1
  - typing-inspection==0.4.1
  - tzdata==2025.2
  - urllib3==2.5.0
  - wadler-lindig==0.1.7
  - wandb==0.21.1
  - werkzeug==3.1.3
  - xxhash==3.5.0
  - yarl==1.20.1
4 changes: 4 additions & 0 deletions OATS/models/.env
@@ -0,0 +1,4 @@
LOTSA_16B_PATH=PATH/TO/LOTSA_16B
LOTSA_1B_PATH=PATH/TO/LOTSA_1B
LOTSA_100M_PATH=./dataset_train/Lotsa100M/
LOTSA_10M_PATH=PATH/TO/LOTSA_10M
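
Three of the four entries above are still `PATH/TO/...` placeholders. A minimal sketch for spotting unset paths before launching a run is shown below. It uses only the standard library, and both helper names are hypothetical. The environment pins `python-dotenv`, but how the project actually loads this file is not shown in this diff, so treat the parsing here as an assumption about a plain `KEY=VALUE` format.

```python
def load_env_file(path):
    """Parse simple KEY=VALUE lines, skipping blanks and '#' comments."""
    values = {}
    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            values[key.strip()] = value.strip()
    return values


def unresolved_placeholders(values):
    """Return the keys whose value is still a PATH/TO/... placeholder."""
    return [key for key, value in values.items() if value.startswith("PATH/TO/")]
```

For the file as committed, `unresolved_placeholders(load_env_file("models/.env"))` would list every key except `LOTSA_100M_PATH`.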
4 changes: 4 additions & 0 deletions OATS/models/cli/conf/eval/data/lsf_test.yaml
@@ -0,0 +1,4 @@
_target_: tsfm.eval_util.data.get_lsf_test_dataset
dataset_name: ???
prediction_length: ???
mode: S
4 changes: 4 additions & 0 deletions OATS/models/cli/conf/eval/data/monash.yaml
@@ -0,0 +1,4 @@
_target_: tsfm.eval_util.data.get_gluonts_test_dataset
dataset_name: ???
prediction_length: null
mode: S
24 changes: 24 additions & 0 deletions OATS/models/cli/conf/eval/default.yaml
@@ -0,0 +1,24 @@
hydra:
  run:
    dir: outputs/${hydra:job.name}/${hydra:runtime.choices.data}/${data.dataset_name}/${data.mode}/prediction_length=${data.prediction_length}/${run_name}
defaults:
  - model: ???
  - data: ???
  - _self_
run_name: ???
metrics:
  - _target_: gluonts.ev.metrics.MSE
  - _target_: tsfm.eval_util.metrics.MedianMSE
  - _target_: gluonts.ev.metrics.MAE
  - _target_: gluonts.ev.metrics.MASE
  - _target_: gluonts.ev.metrics.MAPE
  - _target_: gluonts.ev.metrics.SMAPE
  - _target_: gluonts.ev.metrics.MSIS
  - _target_: gluonts.ev.metrics.RMSE
  - _target_: gluonts.ev.metrics.NRMSE
  - _target_: gluonts.ev.metrics.ND
  - _target_: gluonts.ev.metrics.MeanWeightedSumQuantileLoss
    quantile_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
batch_size: 512
min_batch_size: 1
device: auto
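
For readers unfamiliar with the last metric in the list above, the following is a minimal pure-Python sketch of a mean weighted quantile loss over a set of quantile levels. It assumes the usual definition (pinball loss scaled by 2, summed over time steps, normalized by the sum of absolute targets, then averaged over quantile levels) and is meant as an illustration, not a copy of the GluonTS implementation. The function names are hypothetical, and the targets are assumed to have a non-zero absolute sum.

```python
def quantile_loss(target, forecast, q):
    """Pinball loss (scaled by 2) for a single target and quantile forecast."""
    indicator = 1.0 if target <= forecast else 0.0
    return 2.0 * abs((forecast - target) * (indicator - q))


def mean_weighted_sum_quantile_loss(targets, quantile_forecasts):
    """Average the normalized quantile loss over all quantile levels.

    quantile_forecasts maps a quantile level q to a list of forecasts,
    one per entry in targets.
    """
    abs_target_sum = sum(abs(y) for y in targets)  # assumed non-zero
    per_level = []
    for q, forecasts in quantile_forecasts.items():
        total = sum(quantile_loss(y, f, q) for y, f in zip(targets, forecasts))
        per_level.append(total / abs_target_sum)
    return sum(per_level) / len(per_level)
```

In the config above, the nine `quantile_levels` from 0.1 to 0.9 play the role of the keys in `quantile_forecasts`.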
5 changes: 5 additions & 0 deletions OATS/models/cli/conf/eval/model/decoder_lightning_ckpt.yaml
@@ -0,0 +1,5 @@
_target_: tsfm.model.decoder.TransformerDecoderForecast.load_from_checkpoint
checkpoint_path: ...
num_samples: 100
patch_size: 32
context_length: ???
5 changes: 5 additions & 0 deletions OATS/models/cli/conf/eval/model/encoder_lightning_ckpt.yaml
@@ -0,0 +1,5 @@
_target_: tsfm.model.encoder.TransformerEncoderForecast.load_from_checkpoint
checkpoint_path: ...
num_samples: 100
patch_size: 32
context_length: ???