diff --git a/OATS/README.md b/OATS/README.md
new file mode 100644
index 0000000..7d3e215
--- /dev/null
+++ b/OATS/README.md
@@ -0,0 +1,100 @@
+# OATS: Online Data Augmentation for Time Series Foundation Models
+
+This repository contains the code for the paper "OATS: Online Data Augmentation for Time Series Foundation Models".
+
+## Overview
+
+OATS introduces a novel online data augmentation framework specifically designed to enhance the training of time series foundation models (TSFMs). Unlike traditional offline augmentation methods that pre-generate synthetic data, OATS generates synthetic data during training, using training samples with high data attribution scores as guiding signals.
+
+OATS consists of three key components:
+- Time-series Influence Scores (TSIS) integrate data attribution with time series-specific knowledge to dynamically assess the quality of each training sample, producing a guiding signal for generation.
+- High-quality Guided Data Augmentation uses this guiding signal to condition a diffusion model, trained on a small subset of the TSFM training data, for synthetic data generation.
+- An Explore-Exploit Mechanism reduces computational overhead and balances leveraging previously computed scores against exploring new samples: influence scores are stochastically re-evaluated to incorporate model training dynamics ("explore") while previously identified high-quality data are preserved ("exploit").
+
+
+![Method](assets/method.png)
+
+## Environment and Dataset
+
+### Dataset preparation
+
+#### TSFM pretrain dataset
+Download the TSFM pretraining dataset from [here](https://huggingface.co/datasets/Qingren/TSFM-ScalingLaws-Dataset). The directory structure is as follows:
+
+```
+- dataset_train
+  |- Lotsa16B
+  |- Lotsa1B
+  |- Lotsa100M
+  |- Lotsa10M
+- dataset_test
+  |- Lotsa16B
+  |- Lotsa1B
+  |- Lotsa100M
+  |- Lotsa10M
+  |- LSF
+  |- Monash
+```
+
+#### Generation model training data
+Extract the training data for the diffusion model. The data are sampled from 20 selected subdatasets of Lotsa100M at a 5% sampling rate.
+
+```bash
+python extract_data_generation.py -cp cli/conf/pretrain\
+    -cn default_ddp_val_enc\
+    model=encoder_10M\
+    model.enable_influence_scoring=true\
+    data=lotsa100M_weighted\
+    trainer.max_epochs=0\
+    model.num_warmup_steps=0
+```
+The resulting files are organized as follows:
+```bash
+extracted_label_patches_australian_electricity_demand.npy
+extracted_label_patches_azure_vm_traces_2017.npy
+extracted_label_patches_buildings_900k.npy
+extracted_label_patches_CloudOpsTSF_dataset.npy
+extracted_label_patches_CMIP6_dataset.npy
+...
+```
+
+### Environment setup
+
+```bash
+# Clone the repository
+git clone https://github.com/microsoft/TimeCraft.git
+cd TimeCraft/OATS
+
+# Create and activate conda environment
+conda env create -f environment.yml
+conda activate oats
+```
+
+
+## Quick Start
+Step 1. Train a time series generation model on the extracted data.
+```bash
+cd models/gen_model
+
+python main_train.py --base configs/multi_domain_timedp_local.yaml --gpus 0, --logdir ./logs/ -sl 320 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
+```
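+
+Step 1 consumes the `extracted_label_patches_*.npy` files produced above. If you want to sanity-check those files first, the snippet below is a minimal sketch (the file name is one of the extraction outputs listed earlier; the expected shape follows from the extraction script, which concatenates 10 patches of 32 steps per sample):
+
+```python
+import numpy as np
+
+# Any of the extracted_label_patches_*.npy files works here.
+arr = np.load("extracted_label_patches_australian_electricity_demand.npy")
+# Expected shape: (num_samples, 320, 1), i.e. 10 consecutive patches of 32 steps.
+print(arr.shape, arr.dtype)
+```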
+
+Step 2. Train the time series foundation model.
+```bash
+python -m cli.train_val\
+    -cp conf/pretrain\
+    -cn default_ddp_val_enc\
+    model=encoder\
+    model.enable_influence_scoring=true\
+    data=lotsa100M_weighted\
+    val_data=all\
+    trainer.logger.project=TSFM_PRETRAIN\
+    run_name=encoder10M_etth1_develop\
+    model.generate_after_epoch=0\
+    model.influence_filter_ratio=1.0\
+    model.select_from_generated=false
+```
+
+Outputs: results are logged to Weights & Biases (wandb) and saved under `./outputs/pretrain/`.
+
+
diff --git a/OATS/assets/method.png b/OATS/assets/method.png
new file mode 100644
index 0000000..e857885
Binary files /dev/null and b/OATS/assets/method.png differ
diff --git a/OATS/environment.yml b/OATS/environment.yml
new file mode 100644
index 0000000..6d31a2c
--- /dev/null
+++ b/OATS/environment.yml
@@ -0,0 +1,174 @@
+name: oats
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - asttokens=3.0.0=pyhd8ed1ab_1
+  - bzip2=1.0.8=h5eee18b_6
+  - ca-certificates=2025.8.3=hbd8a1cb_0
+  - comm=0.2.3=pyhe01879c_0
+  - debugpy=1.8.11=py310h6a678d5_0
+  - decorator=5.2.1=pyhd8ed1ab_0
+  - exceptiongroup=1.3.0=pyhd8ed1ab_0
+  - executing=2.2.1=pyhd8ed1ab_0
+  - expat=2.7.1=h6a678d5_0
+  - importlib-metadata=8.7.0=pyhe01879c_1
+  - ipykernel=6.30.1=pyh82676e8_0
+  - ipython=8.37.0=pyh8f84b5b_0
+  - jedi=0.19.2=pyhd8ed1ab_1
+  - jupyter_client=8.6.3=pyhd8ed1ab_1
+  - jupyter_core=5.8.1=pyh31011fe_0
+  - ld_impl_linux-64=2.40=h12ee557_0
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libsodium=1.0.18=h36c2ea0_1
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libuuid=1.41.5=h5eee18b_0
+  - libxcb=1.17.0=h9b100fa_0
+  - matplotlib-inline=0.1.7=pyhd8ed1ab_1
+  - ncurses=6.5=h7934f7d_0
+  - nest-asyncio=1.6.0=pyhd8ed1ab_1
+  - openssl=3.0.17=h5eee18b_0
+  - packaging=25.0=pyh29332c3_1
+  - parso=0.8.5=pyhcf101f3_0
+  - pexpect=4.9.0=pyhd8ed1ab_1
+  - pickleshare=0.7.5=pyhd8ed1ab_1004
+  - pip=25.1=pyhc872135_2
+  - prompt-toolkit=3.0.52=pyha770c72_0
+  - psutil=5.9.0=py310h5eee18b_1
+  - pthread-stubs=0.3=h0ce48e5_1
+  - ptyprocess=0.7.0=pyhd8ed1ab_1
+  - pure_eval=0.2.3=pyhd8ed1ab_1
+  - pygments=2.19.2=pyhd8ed1ab_0
+  - python=3.10.18=h1a3bd86_0
+  - python-dateutil=2.9.0.post0=pyhe01879c_2
+  - pyzmq=26.2.0=py310h6a678d5_0
+  - readline=8.3=hc2a1206_0
+  - setuptools=78.1.1=py310h06a4308_0
+  - six=1.17.0=pyhe01879c_1
+  - sqlite=3.50.2=hb25bd0a_1
+  - stack_data=0.6.3=pyhd8ed1ab_1
+  - tk=8.6.15=h54e0aa7_0
+  - tornado=6.5.1=py310h5eee18b_0
+  - traitlets=5.14.3=pyhd8ed1ab_1
+  - typing_extensions=4.15.0=pyhcf101f3_0
+  - wcwidth=0.2.13=pyhd8ed1ab_1
+  - wheel=0.45.1=py310h06a4308_0
+  - xorg-libx11=1.8.12=h9b100fa_1
+  - xorg-libxau=1.0.12=h9b100fa_0
+  - xorg-libxdmcp=1.1.5=h9b100fa_0
+  - xorg-xorgproto=2024.1=h5eee18b_1
+  - xz=5.6.4=h5eee18b_1
+  - zeromq=4.3.5=h6a678d5_0
+  - zipp=3.23.0=pyhd8ed1ab_0
+  - zlib=1.2.13=h5eee18b_1
+  - pip:
+      - absl-py==2.3.1
+      - aiohappyeyeballs==2.6.1
+      - aiohttp==3.12.15
+      - aiosignal==1.4.0
+      - annotated-types==0.7.0
+      - antlr4-python3-runtime==4.9.3
+      - async-timeout==5.0.1
+      - attrs==25.3.0
+      - certifi==2025.8.3
+      - charset-normalizer==3.4.3
+      - click==8.2.1
+      - contourpy==1.3.2
+      - cycler==0.12.1
+      - datasets==2.17.1
+      - dattri==0.2.0
+      - dill==0.3.8
+      - einops==0.7.0
+      - fast-jl==0.1.3
+      - filelock==3.19.1
+      - fonttools==4.59.2
+      - frozenlist==1.7.0
+      - fsspec==2023.10.0
+      - gitdb==4.0.12
+      - gitpython==3.1.45
+      - gluonts==0.14.4
+      - grpcio==1.74.0
+      - hf-xet==1.1.8
+      - huggingface-hub==0.34.4
+      - hydra-core==1.3.0
+      - idna==3.10
+      - 
jax==0.6.1 + - jaxlib==0.6.1 + - jaxtyping==0.2.38 + - jinja2==3.1.6 + - joblib==1.5.2 + - kiwisolver==1.4.9 + - lightning==2.5.3 + - lightning-utilities==0.15.2 + - markdown==3.8.2 + - markupsafe==3.0.2 + - matplotlib==3.10.5 + - mido==1.3.3 + - ml-dtypes==0.5.3 + - mpmath==1.3.0 + - multidict==6.6.4 + - multiprocess==0.70.16 + - networkx==3.4.2 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.9.86 + - nvidia-nvtx-cu12==12.1.105 + - omegaconf==2.3.0 + - opt-einsum==3.4.0 + - orjson==3.11.2 + - pandas==2.1.4 + - patsy==1.0.1 + - pillow==11.3.0 + - platformdirs==4.3.8 + - pretty-midi==0.2.10 + - propcache==0.3.2 + - protobuf==6.32.0 + - pyarrow==21.0.0 + - pyarrow-hotfix==0.7 + - pydantic==2.11.7 + - pydantic-core==2.33.2 + - pyparsing==3.2.3 + - python-dotenv==1.0.0 + - pytorch-lightning==2.5.3 + - pytz==2025.2 + - pyyaml==6.0.2 + - requests==2.32.5 + - safetensors==0.6.2 + - scikit-learn==1.7.1 + - scipy==1.11.4 + - sentry-sdk==2.35.0 + - smmap==5.0.2 + - statsmodels==0.14.5 + - sympy==1.14.0 + - tensorboard==2.20.0 + - tensorboard-data-server==0.7.2 + - threadpoolctl==3.6.0 + - toolz==0.12.1 + - torch==2.4.1 + - torchaudio==2.4.1 + - torchmetrics==1.8.1 + - torchvision==0.19.1 + - tqdm==4.67.1 + - triton==3.0.0 + - typing-extensions==4.14.1 + - typing-inspection==0.4.1 + - tzdata==2025.2 + - urllib3==2.5.0 + - wadler-lindig==0.1.7 + - wandb==0.21.1 + - werkzeug==3.1.3 + - xxhash==3.5.0 + - yarl==1.20.1 diff --git a/OATS/models/.env b/OATS/models/.env new file mode 100644 index 0000000..b7a69cc --- /dev/null +++ b/OATS/models/.env @@ -0,0 +1,4 @@ +LOTSA_16B_PATH=PATH/TO/LOTSA_16B +LOTSA_1B_PATH=PATH/TO/LOTSA_1B +LOTSA_100M_PATH=./dataset_train/Lotsa100M/ +LOTSA_10M_PATH=PATH/TO/LOTSA_10M \ No newline at end of file diff --git a/OATS/models/cli/conf/eval/data/lsf_test.yaml b/OATS/models/cli/conf/eval/data/lsf_test.yaml new file mode 100644 index 0000000..96b4569 --- /dev/null +++ b/OATS/models/cli/conf/eval/data/lsf_test.yaml @@ -0,0 +1,4 @@ +_target_: tsfm.eval_util.data.get_lsf_test_dataset +dataset_name: ??? +prediction_length: ??? +mode: S \ No newline at end of file diff --git a/OATS/models/cli/conf/eval/data/monash.yaml b/OATS/models/cli/conf/eval/data/monash.yaml new file mode 100644 index 0000000..1957f83 --- /dev/null +++ b/OATS/models/cli/conf/eval/data/monash.yaml @@ -0,0 +1,4 @@ +_target_: tsfm.eval_util.data.get_gluonts_test_dataset +dataset_name: ??? +prediction_length: null +mode: S \ No newline at end of file diff --git a/OATS/models/cli/conf/eval/default.yaml b/OATS/models/cli/conf/eval/default.yaml new file mode 100644 index 0000000..95760c7 --- /dev/null +++ b/OATS/models/cli/conf/eval/default.yaml @@ -0,0 +1,24 @@ +hydra: + run: + dir: outputs/${hydra:job.name}/${hydra:runtime.choices.data}/${data.dataset_name}/${data.mode}/prediction_length=${data.prediction_length}/${run_name} +defaults: + - model: ??? + - data: ??? + - _self_ +run_name: ??? 
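+# Metrics evaluated on the test forecasts; MeanWeightedSumQuantileLoss is a CRPS-style
+# score averaged over the quantile levels listed below.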
+metrics: + - _target_: gluonts.ev.metrics.MSE + - _target_: tsfm.eval_util.metrics.MedianMSE + - _target_: gluonts.ev.metrics.MAE + - _target_: gluonts.ev.metrics.MASE + - _target_: gluonts.ev.metrics.MAPE + - _target_: gluonts.ev.metrics.SMAPE + - _target_: gluonts.ev.metrics.MSIS + - _target_: gluonts.ev.metrics.RMSE + - _target_: gluonts.ev.metrics.NRMSE + - _target_: gluonts.ev.metrics.ND + - _target_: gluonts.ev.metrics.MeanWeightedSumQuantileLoss + quantile_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] +batch_size: 512 +min_batch_size: 1 +device: auto \ No newline at end of file diff --git a/OATS/models/cli/conf/eval/model/decoder_lightning_ckpt.yaml b/OATS/models/cli/conf/eval/model/decoder_lightning_ckpt.yaml new file mode 100644 index 0000000..bfe9c07 --- /dev/null +++ b/OATS/models/cli/conf/eval/model/decoder_lightning_ckpt.yaml @@ -0,0 +1,5 @@ +_target_: tsfm.model.decoder.TransformerDecoderForecast.load_from_checkpoint +checkpoint_path: ... +num_samples: 100 +patch_size: 32 +context_length: ??? diff --git a/OATS/models/cli/conf/eval/model/encoder_lightning_ckpt.yaml b/OATS/models/cli/conf/eval/model/encoder_lightning_ckpt.yaml new file mode 100644 index 0000000..41f6e20 --- /dev/null +++ b/OATS/models/cli/conf/eval/model/encoder_lightning_ckpt.yaml @@ -0,0 +1,5 @@ +_target_: tsfm.model.encoder.TransformerEncoderForecast.load_from_checkpoint +checkpoint_path: ... +num_samples: 100 +patch_size: 32 +context_length: ??? diff --git a/OATS/models/cli/conf/pretrain/data/lotsa100M_weighted.yaml b/OATS/models/cli/conf/pretrain/data/lotsa100M_weighted.yaml new file mode 100644 index 0000000..aea125c --- /dev/null +++ b/OATS/models/cli/conf/pretrain/data/lotsa100M_weighted.yaml @@ -0,0 +1,117 @@ +_target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex +_args_: + - _target_: tsfm.data.builder.lotsa_v1.Buildings900KDatasetBuilder + datasets: ${cls_getattr:${._target_},dataset_list} + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + buildings_900k: 0.40 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.BuildingsBenchDatasetBuilder + datasets: ["sceaux", "bdg-2_panther", "bdg-2_fox", "bdg-2_bear"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + sceaux: 6.76 + bdg-2_panther: 1.73 + bdg-2_fox: 3.40 + bdg-2_bear: 3.22 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.CloudOpsTSFDatasetBuilder + datasets: ["azure_vm_traces_2017", "borg_cluster_data_2011", "alibaba_cluster_trace_2018"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + azure_vm_traces_2017: 1.04 + borg_cluster_data_2011: 0.58 + alibaba_cluster_trace_2018: 0.32 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.CMIP6DatasetBuilder + datasets: ['cmip6_1850', 'cmip6_1855'] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + cmip6_1850: 0.85 + cmip6_1855: 0.85 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.ERA5DatasetBuilder + datasets: ['era5_1989', 'era5_1990'] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + era5_1989: 1.56 + era5_1990: 1.56 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.GluonTSDatasetBuilder + datasets: 
["wiki-rolling_nips", "solar_power", "elecdemand", "kaggle_web_traffic_weekly", "traffic_weekly", "australian_electricity_demand"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + wiki-rolling_nips: 0.168 + solar_power: 1460.95 + elecdemand: 3.46 + kaggle_web_traffic_weekly: 0.02 + traffic_weekly: 0.02 + australian_electricity_demand: 45.57 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.LargeSTDatasetBuilder + datasets: ["largest_2017", "largest_2018", "largest_2019", "largest_2020", "largest_2021"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + largest_2017: 19.36 + largest_2018: 19.36 + largest_2019: 19.36 + largest_2020: 19.36 + largest_2021: 19.36 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.LibCityDatasetBuilder + datasets: ["LOOP_SEATTLE", "PEMS04", "PEMS07", "PEMS08", "PEMS_BAY", "Q-TRAFFIC"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + LOOP_SEATTLE: 20.76 + PEMS04: 3.35 + PEMS07: 5.57 + PEMS08: 3.53 + PEMS_BAY: 10.30 + Q-TRAFFIC: 1.16 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.OthersLOTSADatasetBuilder + datasets: ["favorita_sales", "AtrialFibrillation", "BIDMC32HR", "IEEEPPG", "MotorImagery", "PigArtPressure", "PigCVP", "SelfRegulationSCP1", "SelfRegulationSCP2", "TDBrain"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + favorita_sales: 0.25 + AtrialFibrillation: 0.13 + BIDMC32HR: 0.79 + IEEEPPG: 0.59 + MotorImagery: 0.59 + PigArtPressure: 0.40 + PigCVP: 0.40 + SelfRegulationSCP1: 0.18 + SelfRegulationSCP2: 0.23 + TDBrain: 0.51 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + - _target_: tsfm.data.builder.lotsa_v1.ProEnFoDatasetBuilder + datasets: ["gfc14_load", "gfc17_load", "spain", "pdb", "elf", "covid19_energy"] + storage_path: ${env_get:LOTSA_100M_PATH} + weight_map: + gfc14_load: 3.46 + gfc17_load: 3.46 + spain: 6.93 + pdb: 3.46 + elf: 4.30 + covid19_energy: 6.30 + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] \ No newline at end of file diff --git a/OATS/models/cli/conf/pretrain/default_ddp_val_enc.yaml b/OATS/models/cli/conf/pretrain/default_ddp_val_enc.yaml new file mode 100644 index 0000000..ab44682 --- /dev/null +++ b/OATS/models/cli/conf/pretrain/default_ddp_val_enc.yaml @@ -0,0 +1,88 @@ +hydra: + run: + # dir: outputs/pretrain/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name} + dir: ./outputs/pretrain/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name} +defaults: + - model: ??? + - data: ??? + - val_data: null + - _self_ +run_name: ??? +seed: 0 +tf32: true +compile: false # set to mode: default, reduce-overhead, max-autotune +ckpt_path: null # set to "last" to resume training +trainer: + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32 + logger: + _target_: lightning.pytorch.loggers.WandbLogger + save_dir: ${hydra:runtime.output_dir} + name: logs + project: ??? 
+  callbacks:
+    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
+      logging_interval: epoch
+    # - _target_: lightning.pytorch.callbacks.ModelCheckpoint
+    #   dirpath: ${hydra:runtime.output_dir}/checkpoints
+    #   filename: last
+    #   monitor: epoch
+    #   mode: max
+    #   save_top_k: 1
+    #   every_n_epochs: 10
+    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
+      dirpath: ${hydra:runtime.output_dir}/checkpoints
+      monitor: epoch
+      save_weights_only: true
+      mode: max
+      save_top_k: -1
+      every_n_epochs: ${floordiv:${trainer.max_epochs},10}
+  # Epoch-based training provides averaged metrics.
+  # Do not combine max_steps with epoch-based training: resuming from a checkpoint lands on the wrong epoch.
+  max_epochs: 100 # 1_000
+  enable_progress_bar: true
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: norm
+  limit_val_batches: 1.0
+  # check_val_every_n_epoch: 1
+  val_check_interval: 200 # Validate every 200 steps
+  num_sanity_val_steps: -1
+train_dataloader:
+  _target_: tsfm.data.loader.DataLoader
+  batch_size: 32 # 32
+  batch_size_factor: 2.0
+  cycle: true
+  num_batches_per_epoch: 400 # 400
+  shuffle: true
+  num_workers: 11
+  collate_fn:
+    _target_: tsfm.data.loader.PadCollateWithDatasetIndex
+    max_length: ${model.module_kwargs.max_seq_len}
+    seq_fields: ${cls_getattr:${model._target_},seq_fields}
+    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
+  pin_memory: true
+  drop_last: true
+  fill_last: false
+  worker_init_fn: null
+  prefetch_factor: 2
+  persistent_workers: true
val_dataloader:
+  _target_: torch.utils.data.DataLoader
+  batch_size: 32
+  shuffle: false
+  num_workers: 11
+  collate_fn:
+    _target_: tsfm.data.loader.PadCollateWithDatasetIndex
+    max_length: 60 # 48
+    seq_fields: ${cls_getattr:${model._target_},seq_fields}
+    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
+  pin_memory: false
+  drop_last: false
+  worker_init_fn: null
+  prefetch_factor: 2
+  persistent_workers: true
\ No newline at end of file
diff --git a/OATS/models/cli/conf/pretrain/model/encoder.yaml b/OATS/models/cli/conf/pretrain/model/encoder.yaml
new file mode 100644
index 0000000..29f870c
--- /dev/null
+++ b/OATS/models/cli/conf/pretrain/model/encoder.yaml
@@ -0,0 +1,40 @@
+_target_: tsfm.model.encoder.TransformerEncoderPretrain
+module_kwargs:
+  _target_: builtins.dict
+  distr_output:
+    _target_: tsfm.distribution.MixtureOutput
+    components:
+      - _target_: tsfm.distribution.StudentTOutput
+      - _target_: tsfm.distribution.StudentTOutput
+      - _target_: tsfm.distribution.StudentTOutput
+      - _target_: tsfm.distribution.StudentTOutput
+  d_model: 384
+  num_layers: 8
+  patch_size: 32
+  max_seq_len: 512
+  attn_dropout_p: 0.0
+  dropout_p: 0.0
+  scaling: true
+min_patches: 2
+min_mask_ratio: 0.15
+max_mask_ratio: 0.5
+max_dim: 1
+loss_func:
+  _target_: tsfm.loss.packed.PackedNLLLoss
+lr: 1e-3
+weight_decay: 1e-1
+beta1: 0.9
+beta2: 0.98
+num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}}
+num_warmup_steps: 10_000 # 10_000
+enable_influence_scoring: true # Set to false to disable influence scoring
+enable_dataset_contribution_logging: true
+enable_reweighting: true
+num_low_influence_to_remove: 16
+# influence_filter_frequency: 1
+influence_filter_ratio: 0.7
+use_cosine_similarity: false
+generate_after_epoch: 30
+select_from_generated: false
+add_noise: false
+mixup: true
\ No newline at end of file
diff --git a/OATS/models/cli/conf/pretrain/val_data/all_local.yaml b/OATS/models/cli/conf/pretrain/val_data/all_local.yaml
new file mode 100644
index 
0000000..90cc3ca --- /dev/null +++ b/OATS/models/cli/conf/pretrain/val_data/all_local.yaml @@ -0,0 +1,347 @@ +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: electricity + offset: 21044 + windows: 22.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: electricity + offset: 21044 + windows: 22.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: electricity + offset: 21044 + windows: 22.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: electricity + offset: 21044 + windows: 22.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1 + offset: 11520 + windows: 25.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1 + offset: 11520 + windows: 25.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1 + offset: 11520 + windows: 25.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1 + offset: 11520 + windows: 25.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh2 + offset: 
11520 + windows: 25.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh2 + offset: 11520 + windows: 25.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh2 + offset: 11520 + windows: 25.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] + +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh2 + offset: 11520 + windows: 25.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: weather + offset: 42156 + windows: 49.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: weather + offset: 42156 + windows: 49.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: weather + offset: 42156 + windows: 49.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: weather + offset: 42156 + windows: 49.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm1 + offset: 46080 + windows: 117.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex 
+ _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm1 + offset: 46080 + windows: 117.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm1 + offset: 46080 + windows: 117.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm1 + offset: 46080 + windows: 117.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm2 + offset: 46080 + windows: 117.0 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm2 + offset: 46080 + windows: 117.0 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm2 + offset: 46080 + windows: 117.0 + distance: 336 + prediction_length: 336 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] +- _target_: tsfm.data.builder.ConcatDatasetBuilderWithGlobalIndex + _args_: + - _target_: tsfm.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTm2 + offset: 46080 + windows: 117.0 + distance: 720 + prediction_length: 720 + context_length: 1000 + patch_size: 32 + storage_path: dataset_test/LSF/ + sample_time_series: + _target_: tsfm.data.dataset.SampleTimeSeriesType + _args_: ["proportional"] \ No newline at end of file diff --git a/OATS/models/cli/train_val.py b/OATS/models/cli/train_val.py new file mode 100644 index 0000000..2da5cee --- /dev/null +++ b/OATS/models/cli/train_val.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
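+#
+# Hydra entry point for pretraining the TSFM with periodic validation.
+# Invoked as in the README Quick Start, e.g.:
+#   python -m cli.train_val -cp conf/pretrain -cn default_ddp_val_enc model=encoder data=lotsa100M_weighted ...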
+ +from functools import partial +from typing import Callable, Optional + +import hydra +import lightning as L +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils._pytree import tree_map +from torch.utils.data import Dataset, DistributedSampler + +from tsfm.common import hydra_util # noqa: hydra resolvers +from tsfm.data.loader import DataLoader + + + +class DataModule(L.LightningDataModule): + def __init__( + self, + cfg: DictConfig, + train_dataset: Dataset, + val_dataset: Optional[Dataset | list[Dataset]], + data_builder=None, + ): + super().__init__() + self.cfg = cfg + self.train_dataset = train_dataset + self.data_builder = data_builder + + if val_dataset is not None: + self.val_dataset = val_dataset + self.val_dataloader = self._val_dataloader + + @staticmethod + def get_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + num_batches_per_epoch=num_batches_per_epoch, + ) + + def train_dataloader(self) -> DataLoader: + return self.get_dataloader( + self.train_dataset, + instantiate(self.cfg.train_dataloader, _partial_=True), + self.cfg.train_dataloader.shuffle, + self.trainer.world_size, + self.train_batch_size, + num_batches_per_epoch=self.train_num_batches_per_epoch, + ) + + @staticmethod + def get_torch_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + ) + + def _val_dataloader(self) -> DataLoader | list[DataLoader]: + return tree_map( + partial( + self.get_torch_dataloader, + dataloader_func=instantiate(self.cfg.val_dataloader, _partial_=True), + shuffle=self.cfg.val_dataloader.shuffle, + world_size=self.trainer.world_size, + batch_size=self.val_batch_size, + num_batches_per_epoch=None, + ), + self.val_dataset, + ) + + @property + def train_batch_size(self) -> int: + return self.cfg.train_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def val_batch_size(self) -> int: + return self.cfg.val_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def train_num_batches_per_epoch(self) -> int: + return ( + self.cfg.train_dataloader.num_batches_per_epoch + * self.trainer.accumulate_grad_batches + ) + + +@hydra.main(version_base="1.3", config_name="default.yaml") +def main(cfg: DictConfig): + if cfg.tf32: + assert cfg.trainer.precision == 32 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model: L.LightningModule = instantiate(cfg.model, _convert_="all") + + if cfg.compile: + model.module.compile(mode=cfg.compile) + trainer: L.Trainer = instantiate(cfg.trainer) + + # Instantiate 
the data builder and create dataset + data_builder = instantiate(cfg.data) + train_dataset: Dataset = data_builder.load_dataset(model.train_transform_map) + + val_dataset: Optional[Dataset | list[Dataset]] = ( + tree_map( + lambda ds: ds.load_dataset(model.val_transform_map), + instantiate(cfg.val_data, _convert_="all"), + ) + if "val_data" in cfg + else None + ) + L.seed_everything(cfg.seed, workers=True) + print("train_dataset:", train_dataset) + print("val_dataset:", val_dataset) + print("train_dataset size:", len(train_dataset)) + # print("val_dataset size:", len(val_dataset[0]), len(val_dataset[1]), len(val_dataset[2])) + print("val_dataset size:", len(val_dataset[0])) + trainer.fit( + model, + datamodule=DataModule(cfg, train_dataset, val_dataset, data_builder), # Pass data_builder + ckpt_path=cfg.ckpt_path, + ) + + +if __name__ == "__main__": + main() diff --git a/OATS/models/extract_data_generation.py b/OATS/models/extract_data_generation.py new file mode 100644 index 0000000..aa36e1d --- /dev/null +++ b/OATS/models/extract_data_generation.py @@ -0,0 +1,275 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from functools import partial +from typing import Callable, Optional +import os + +import hydra +import lightning as L +import torch +import numpy as np +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils._pytree import tree_map +from torch.utils.data import Dataset, DistributedSampler + +from tsfm.common import hydra_util # noqa: hydra resolvers +from tsfm.data.loader import DataLoader + + +class DataModule(L.LightningDataModule): + def __init__( + self, + cfg: DictConfig, + train_dataset: Dataset, + val_dataset: Optional[Dataset | list[Dataset]], + data_builder=None, + ): + super().__init__() + self.cfg = cfg + self.train_dataset = train_dataset + self.data_builder = data_builder + + if val_dataset is not None: + self.val_dataset = val_dataset + self.val_dataloader = self._val_dataloader + + @staticmethod + def get_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + num_batches_per_epoch=num_batches_per_epoch, + ) + + def train_dataloader(self) -> DataLoader: + return self.get_dataloader( + self.train_dataset, + instantiate(self.cfg.train_dataloader, _partial_=True), + self.cfg.train_dataloader.shuffle, + self.trainer.world_size, + self.train_batch_size, + num_batches_per_epoch=self.train_num_batches_per_epoch, + ) + + @staticmethod + def get_torch_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + ) + + def _val_dataloader(self) -> DataLoader | list[DataLoader]: + return tree_map( + partial( + 
self.get_torch_dataloader, + dataloader_func=instantiate(self.cfg.val_dataloader, _partial_=True), + shuffle=self.cfg.val_dataloader.shuffle, + world_size=self.trainer.world_size, + batch_size=self.val_batch_size, + num_batches_per_epoch=None, + ), + self.val_dataset, + ) + + @property + def train_batch_size(self) -> int: + return self.cfg.train_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def val_batch_size(self) -> int: + return self.cfg.val_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def train_num_batches_per_epoch(self) -> int: + return ( + self.cfg.train_dataloader.num_batches_per_epoch + * self.trainer.accumulate_grad_batches + ) + + +@hydra.main(version_base="1.3", config_name="default.yaml") +def main(cfg: DictConfig): + if cfg.tf32: + assert cfg.trainer.precision == 32 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model: L.LightningModule = instantiate(cfg.model, _convert_="all") + + if cfg.compile: + model.module.compile(mode=cfg.compile) + trainer: L.Trainer = instantiate(cfg.trainer) + + # Instantiate the data builder and create dataset + data_builder = instantiate(cfg.data) + train_dataset: Dataset = data_builder.load_dataset(model.train_transform_map) + + val_dataset: Optional[Dataset | list[Dataset]] = ( + tree_map( + lambda ds: ds.load_dataset(model.val_transform_map), + instantiate(cfg.val_data, _convert_="all"), + ) + if "val_data" in cfg + else None + ) + L.seed_everything(cfg.seed, workers=True) + + print("train_dataset size:", len(train_dataset)) + + datamodule = DataModule(cfg, train_dataset, [val_dataset[0]], data_builder) + + # only work when epoch=0 + trainer.fit( + model, + datamodule=datamodule, # Pass data_builder + ckpt_path=cfg.ckpt_path, + ) + + # Initialize dictionary to collect processed patches by dataset + dataset_patches = {} + + print("Processing training dataloader...") + print("Expected to process approximately 15000 samples") + + total_samples = 0 + + # Go through the whole training dataloader + for batch_idx, batch in enumerate(datamodule.train_dataloader()): + if batch_idx % 100 == 0: + print(f"Processing batch {batch_idx}, total samples so far: {total_samples}") + + # Get the label data - shape should be (batch_size, 512, 32) + labels = batch['label'] # (batch_size, 512, 32) + batch_size = labels.shape[0] + + # Get the time_id data to check for valid positions + time_ids = batch['time_id'] # (batch_size, 512) + + # Get the dataset indices to determine which subdataset each sample belongs to + dataset_indices = batch.get('_dataset_idx', None) + + # Process each sample in the batch + for sample_idx in range(batch_size): + label_sample = labels[sample_idx] # (512, 32) + time_id_sample = time_ids[sample_idx] # (512,) + + # Determine which subdataset this sample belongs to + if dataset_indices is not None: + global_idx = dataset_indices[sample_idx].item() + # Get dataset name using the data_builder's method + if hasattr(data_builder, 'get_dataset_name_for_global_index'): + dataset_name = data_builder.get_dataset_name_for_global_index(global_idx) + else: + dataset_name = f"dataset_{global_idx}" + else: + dataset_name = "train" + + # Initialize list for this dataset if not exists + if dataset_name not in dataset_patches: + dataset_patches[dataset_name] = [] + + # Sample 10 successive patches from the second dimension (512) + # We need to ensure we don't go out of bounds and 
time_ids are valid
+            if label_sample.shape[0] >= 10:
+                # Find valid starting positions where the time_id is non-zero for at least 10 consecutive positions
+                # (except that the first one may be 0)
+                valid_start_positions = []
+
+                for potential_start in range(label_sample.shape[0] - 9):
+                    # Check that the 9 positions after the first have a non-zero time_id
+                    time_slice = time_id_sample[potential_start:potential_start + 10]
+                    # Allow the first one to be 0, but the rest must be non-zero
+                    if torch.sum(time_slice[1:] != 0) == 9:
+                        valid_start_positions.append(potential_start)
+
+                if len(valid_start_positions) > 0:
+                    # Randomly select from the valid positions
+                    start_idx = valid_start_positions[torch.randint(0, len(valid_start_positions), (1,)).item()]
+
+                    # Extract 10 successive patches
+                    patches = []
+                    for i in range(10):
+                        patch = label_sample[start_idx + i]  # (32,)
+                        patches.append(patch)
+
+                    # Concatenate the 10 patches: [patch1(32), patch2(32), ..., patch10(32)]
+                    concatenated_patches = torch.cat(patches, dim=0)  # (320,)
+
+                    # Reshape to (320, 1) and add a batch dimension to get (1, 320, 1)
+                    processed_sample = concatenated_patches.unsqueeze(-1).unsqueeze(0)  # (1, 320, 1)
+
+                    # Convert to numpy and add to the collection for this dataset
+                    dataset_patches[dataset_name].append(processed_sample.cpu().numpy())
+                    total_samples += 1
+                else:
+                    print(f"Skipping sample {sample_idx} from dataset {dataset_name} because it has no run of 10 valid positions")
+
+    print(f"Total samples processed: {total_samples}")
+    print(f"Found {len(dataset_patches)} subdatasets")
+
+    # Save each subdataset's patches to a separate file
+    for dataset_name, patches_list in dataset_patches.items():
+        if patches_list:
+            # Concatenate all samples along the first dimension
+            final_array = np.concatenate(patches_list, axis=0)  # (total_samples, 320, 1)
+            print(f"Dataset '{dataset_name}': {final_array.shape}")
+
+            # Save to an npy file next to this script
+            output_path = os.path.join(os.path.dirname(__file__), f"extracted_label_patches_{dataset_name}.npy")
+            np.save(output_path, final_array)
+            print(f"Saved processed patches for '{dataset_name}' to: {output_path}")
+        else:
+            print(f"No samples were processed successfully for dataset '{dataset_name}'")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/OATS/models/gen_model/configs/multi_domain_timedp_local.yaml b/OATS/models/gen_model/configs/multi_domain_timedp_local.yaml
new file mode 100644
index 0000000..8310c8c
--- /dev/null
+++ b/OATS/models/gen_model/configs/multi_domain_timedp_local.yaml
@@ -0,0 +1,111 @@
+seq_length: &seqlen 96
+model:
+  base_learning_rate: 0.001
+  target: ldm.models.diffusion.ddpm_time.LatentDiffusion
+  params:
+    linear_start: 0.0005
+    linear_end: 0.1
+    num_timesteps_cond: 1
+    log_every_t: 40
+    timesteps: 200
+    loss_type: l1
+    first_stage_key: "context"
+    cond_stage_key: "context"
+    seq_len: *seqlen
+    channels: 1
+    cond_stage_trainable: True
+    concat_mode: False
+    scale_by_std: False # True
+    monitor: 'val/loss_simple_ema'
+    conditioning_key: crossattn
+    cond_drop_prob: 0.5
+    class_condition: True
+
+    scheduler_config: # 1000 warm-up steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [1000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [ 1.]
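+        # LambdaLinearScheduler scales the base LR: linear warm-up from f_start to f_max
+        # over warm_up_steps, then linear decay toward f_min across cycle_lengths
+        # (see ldm/lr_scheduler.py).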
+ + unet_config: + target: ldm.modules.diffusionmodules.ts_unet.UNetModel + params: + seq_len: *seqlen + dims: 1 + in_channels: 1 + out_channels: 1 + model_channels: 64 + attention_resolutions: [ 1, 2, 4] + num_res_blocks: 2 + channel_mult: [ 1,2,4,4 ] + num_heads: 8 + use_scale_shift_norm: True + resblock_updown: True + context_dim: 32 + repre_emb_channels: 32 + latent_unit: 1 + use_spatial_transformer: true + use_pam: true + class_conditionnal_number: 20 + + first_stage_config: # no first stage model for ts data + target: ldm.models.autoencoder.IdentityFirstStage + + cond_stage_config: + target: ldm.modules.encoders.modules.DomainUnifiedPrototyper + params: + dim: 32 + window: *seqlen + latent_dim: 32 # 32 * 3 + num_latents: 16 + num_channels: 1 + +data: + target: ldm.data.tsg_dataset.TSGDataModule + params: + data_path_dict: + CMIP6_dataset: "extracted_label_patches_CMIP6_dataset.npy" + CloudOpsTSF_dataset: "extracted_label_patches_CloudOpsTSF_dataset.npy" + ERA5_dataset: "extracted_label_patches_ERA5_dataset.npy" + LOOP_SEATTLE: "extracted_label_patches_LOOP_SEATTLE.npy" + LibCity_dataset: "extracted_label_patches_LibCity_dataset.npy" + OthersLOTSA_dataset: "extracted_label_patches_OthersLOTSA_dataset.npy" + PEMS07: "extracted_label_patches_PEMS07.npy" + PEMS_BAY: "extracted_label_patches_PEMS_BAY.npy" + Q-TRAFFIC: "extracted_label_patches_Q-TRAFFIC.npy" + australian_electricity_demand: "extracted_label_patches_australian_electricity_demand.npy" + azure_vm_traces_2017: "extracted_label_patches_azure_vm_traces_2017.npy" + buildings_900k: "extracted_label_patches_buildings_900k.npy" + favorita_sales: "extracted_label_patches_favorita_sales.npy" + largest_2017: "extracted_label_patches_largest_2017.npy" + largest_2018: "extracted_label_patches_largest_2018.npy" + largest_2019: "extracted_label_patches_largest_2019.npy" + largest_2020: "extracted_label_patches_largest_2020.npy" + largest_2021: "extracted_label_patches_largest_2021.npy" + solar_power: "extracted_label_patches_solar_power.npy" + wiki-rolling_nips: "extracted_label_patches_wiki-rolling_nips.npy" + window: *seqlen + val_portion: 0.1 + batch_size: 256 + num_workers: 8 + normalize: centered_pit + drop_last: True + reweight: True + input_channels: 1 + +lightning: + callbacks: + image_logger: + target: utils.callback_utils.TSLogger + params: + batch_frequency: 2000 + max_images: 8 + increase_log_steps: false + + + trainer: + benchmark: True + max_steps: 50000 \ No newline at end of file diff --git a/OATS/models/gen_model/ldm/data/tsg_dataset.py b/OATS/models/gen_model/ldm/data/tsg_dataset.py new file mode 100644 index 0000000..a74d000 --- /dev/null +++ b/OATS/models/gen_model/ldm/data/tsg_dataset.py @@ -0,0 +1,370 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl +import numpy as np +from datetime import datetime +import pandas as pd +from distutils.util import strtobool +from statsmodels.distributions.empirical_distribution import ECDF +from torch.utils.data import WeightedRandomSampler +from einops import rearrange + +class TSGDataset(Dataset): # For generation task. Unified Univariate Generation Dataset + def __init__(self, data_dict: dict): + for key, data in data_dict.items(): + assert data.ndim == 3, f"Data must be 3D, but {key} got {data.ndim}D." + assert data.shape[2] == 1, f"Only univariate time series are supported, but {key} got {data.shape[2]} channels." 
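+        # Keep the per-dataset arrays; cal_data_stats() precomputes cumulative offsets
+        # so a single flat index can address samples across all datasets.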
+        self.data_dict = data_dict
+        self.cal_data_stats()
+
+    def cal_data_stats(self):
+        total_items = 0
+        n_items_dict = {}
+        key_list = []
+        key_idx_list = []
+
+        for key, data in self.data_dict.items():
+            num_items = data.shape[0]
+            total_items += num_items
+            n_items_dict[key] = num_items
+            key_list.append(key)
+            key_idx_list.append(total_items)
+
+        self.total_items = total_items
+        self.items_dict = n_items_dict
+        self.key_list = key_list
+        self.key_idx_list = np.array(key_idx_list)
+
+    def get_reweight_sampler(self):
+        dataset_weights = np.array([1 / self.items_dict[key] for key in self.key_list], dtype=np.float32)
+        sample_weights = np.repeat(dataset_weights, [self.items_dict[key] for key in self.key_list])
+        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=self.total_items, replacement=True)
+        return sampler
+
+    def __len__(self):
+        return self.total_items  # self.num_slices
+
+    def __getitem__(self, idx):
+        assert idx < self.total_items, f"Index({idx}) must be less than number of items({self.total_items})."
+        data_key = np.where(self.key_idx_list > idx)[0].min()  # np.argmin(self.key_idx_list > idx)
+        data_start_idx = self.key_idx_list[data_key-1] if data_key > 0 else 0
+        data: np.ndarray = self.data_dict[self.key_list[data_key]]
+
+        valid_idx = idx - data_start_idx
+        context = data[valid_idx,:,0]
+
+        return {
+            'context': context,  # shape: (window,)
+            'data_key': data_key
+        }
+
+
+class TSGDataModule(pl.LightningDataModule):
+    '''
+    Data module for the unified time series generation task.
+    Slicing is also done in this module, so the train/val split is i.i.d. within the training dataset.
+    '''
+    def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False, input_channels=1, **kwargs):
+        super().__init__()
+        self.data_path_dict = data_path_dict  # {data_name: data_path}
+        self.data_dict = {}
+        self.norm_data_dict = {}
+        self.normalizer_dict = {}
+        self.norm_train_dict = {}
+        self.norm_val_dict = {}
+        self.window = window
+        self.val_portion = val_portion
+        self.as_tensor = as_tensor
+        assert normalize in [None, 'zscore', 'robust_iqr', 'robust_mad', 'pit', 'centered_pit', 'minmax'], f"Normalize({normalize}) must be one of (None, zscore, robust_iqr, robust_mad, pit, centered_pit, minmax)."
+ self.normalize = normalize + + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.kwargs = kwargs + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.reweight = reweight + self.input_channels = input_channels + # self.transform = None + self.kwargs = kwargs + self.key_list = [] + self.drop_last = drop_last + + def prepare_data(self) -> None: + print(f"Normalizing data with: {self.normalize}") + self.key_list = [] + for data_name, data_path in self.data_path_dict.items(): + self.key_list.append(data_name) + this_data = load_data_from_file(data_path).astype(np.float32) + if this_data.ndim == 3: # in shape (N, T, C) + this_data = rearrange(this_data, 'n t c -> (n c) t 1') # to shape (N*C, T) + elif this_data.ndim == 2: + this_data = this_data[..., np.newaxis] # make shape (N, T, 1) + else: + raise ValueError(f"Unsupported data shape: {this_data.shape}") + # first normalize, then split + normalizer = self.fit_normalizer(this_data) + self.data_dict[data_name] = this_data + self.normalizer_dict[data_name] = normalizer + norm_data = self.transform(this_data, normalizer) + self.norm_data_dict[data_name] = norm_data + train_data, val_data = self.split_train_val(norm_data) # slice and split here + self.norm_train_dict[data_name] = train_data + self.norm_val_dict[data_name] = val_data + + print(f"Loaded data: {data_name}; Train shape: {train_data.shape}, Validation shape: {val_data.shape}.") + print(f"With normalizer fit as: {normalizer}") + + def split_train_val(self, data: np.ndarray): + # By default, data are sliced into non-overlapped sequences. + # shuffle stack_data, only along the first dimension + np.random.shuffle(data) + total_instances = data.shape[0] + num_val_instances = int(total_instances * self.val_portion) + train_data = data[:-num_val_instances] + val_data = data[-num_val_instances:] + + return train_data, val_data + + def train_dataloader(self): + train_dataset = TSGDataset(self.norm_train_dict) + sampler = None + if self.reweight: + sampler = train_dataset.get_reweight_sampler() + return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=self.drop_last, sampler=sampler, **self.kwargs) + else: + return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=True, drop_last=self.drop_last, **self.kwargs) + + def val_dataloader(self): + val_dataset = TSGDataset(self.norm_val_dict) + return DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, **self.kwargs) + + def test_dataloader(self, **kwargs): + return None + + def fit_normalizer(self, data: np.ndarray): + normalizer = {} + data = data.flatten() + if self.normalize == 'zscore': + normalizer['mean'] = np.nanmean(data) + normalizer['std'] = np.nanstd(data) + elif self.normalize == 'robust_iqr': + normalizer['median'] = np.median(data) + normalizer['iqr'] = np.subtract(*np.percentile(data, [75, 25])) + elif self.normalize == 'robust_mad': + normalizer['median'] = np.median(data) + normalizer['mad'] = np.median(np.abs(data - normalizer['median'])) + elif self.normalize == 'minmax': + normalizer['min'] = np.nanmin(data) + normalizer['max'] = np.nanmax(data) + elif self.normalize == 'pit' or self.normalize == 'centered_pit': + ecdf = ECDF(data) + normalizer['ecdf'] = ecdf + return normalizer + + def transform(self, data: np.ndarray, normalizer=None, 
data_name=None): + # if data_name is specified, the normalizer argument will be ignored. + assert normalizer is not None or data_name is not None, "Must specify either normalizer or data name." + if data_name is not None: + assert data_name in self.normalizer_dict.keys(), f"Data name({data_name}) must be in normalizer dict key." + normalizer = self.normalizer_dict[data_name] + if self.normalize == 'zscore': + return (data - normalizer['mean']) / (normalizer['std'] + 1e-8) + elif self.normalize == 'robust_iqr': + return (data - normalizer['median']) / (normalizer['iqr'] + 1e-8) + elif self.normalize == 'robust_mad': + return (data - normalizer['median']) / (normalizer['mad'] + 1e-8) + if self.normalize == 'minmax': + return (data - normalizer['min']) / (normalizer['max'] - normalizer['min'] + 1e-8) + elif self.normalize == 'pit' or self.normalize == 'centered_pit': + data_shape = data.shape + norm_data = normalizer['ecdf'](data.flatten()).reshape(data_shape) + if self.normalize == 'centered_pit': + norm_data = norm_data * 2 - 1 + return norm_data + + def inverse_transform(self, data: np.ndarray, normalizer=None, data_name=None): + # if data_name is specified, the normalizer argument will be ignored. + assert normalizer is not None or data_name is not None, "Must specify either normalizer or data name." + if data_name is not None: + assert data_name in self.normalizer_dict.keys(), f"Data name({data_name}) must be in normalizer dict key." + normalizer = self.normalizer_dict[data_name] + if self.normalize == 'zscore': + return data * normalizer['std'] + normalizer['mean'] + elif self.normalize == 'robust_iqr': + return data * normalizer['iqr'] + normalizer['median'] + elif self.normalize == 'robust_mad': + return data * normalizer['mad'] + normalizer['median'] + if self.normalize == 'minmax': + return data * (normalizer['max'] - normalizer['min']) + normalizer['min'] + elif self.normalize == 'pit' or self.normalize == 'centered_pit': + ecdf: ECDF = normalizer['ecdf'] + ecdf.x[0] = ecdf.x[1] + if self.normalize == 'centered_pit': + data = (data + 1) / 2 + return np.interp(data, ecdf.y, ecdf.x) + +def load_data_from_file(file_path: str): + if file_path.endswith(".csv"): + loaded_data = pd.read_csv(file_path) + return loaded_data.values # no index columns, by default. 
+ elif file_path.endswith(".tsf"): + loaded_data, frequency, forecast_horizon, contain_missing_values, contain_equal_length = convert_tsf_to_dataframe( + file_path, + replace_missing_vals_with="NaN", + value_column_name="series_value" + ) + data = np.stack(loaded_data['series_value'].values).T + return data # no date column + elif file_path.endswith(".npy"): + loaded_data = np.load(file_path) # shape like (N, T) by default + return loaded_data.T + + +def convert_tsf_to_dataframe( + full_file_path_and_name, + replace_missing_vals_with="NaN", + value_column_name="series_value", +): + col_names = [] + col_types = [] + all_data = {} + line_count = 0 + frequency = None + forecast_horizon = None + contain_missing_values = None + contain_equal_length = None + found_data_tag = False + found_data_section = False + started_reading_data_section = False + + with open(full_file_path_and_name, "r", encoding="cp1252") as file: + for line in file: + # Strip white space from start/end of line + line = line.strip() + + if line: + if line.startswith("@"): # Read meta-data + if not line.startswith("@data"): + line_content = line.split(" ") + if line.startswith("@attribute"): + if ( + len(line_content) != 3 + ): # Attributes have both name and type + raise Exception("Invalid meta-data specification.") + + col_names.append(line_content[1]) + col_types.append(line_content[2]) + else: + if ( + len(line_content) != 2 + ): # Other meta-data have only values + raise Exception("Invalid meta-data specification.") + + if line.startswith("@frequency"): + frequency = line_content[1] + elif line.startswith("@horizon"): + forecast_horizon = int(line_content[1]) + elif line.startswith("@missing"): + contain_missing_values = bool( + strtobool(line_content[1]) + ) + elif line.startswith("@equallength"): + contain_equal_length = bool(strtobool(line_content[1])) + + else: + if len(col_names) == 0: + raise Exception( + "Missing attribute section. Attribute section must come before data." + ) + + found_data_tag = True + elif not line.startswith("#"): + if len(col_names) == 0: + raise Exception( + "Missing attribute section. Attribute section must come before data." + ) + elif not found_data_tag: + raise Exception("Missing @data tag.") + else: + if not started_reading_data_section: + started_reading_data_section = True + found_data_section = True + all_series = [] + + for col in col_names: + all_data[col] = [] + + full_info = line.split(":") + + if len(full_info) != (len(col_names) + 1): + raise Exception("Missing attributes/values in series.") + + series = full_info[len(full_info) - 1] + series = series.split(",") + + if len(series) == 0: + raise Exception( + "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol" + ) + + numeric_series = [] + + for val in series: + if val == "?": + numeric_series.append(replace_missing_vals_with) + else: + numeric_series.append(float(val)) + + if numeric_series.count(replace_missing_vals_with) == len( + numeric_series + ): + raise Exception( + "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series." 
+ ) + + all_series.append(pd.Series(numeric_series).array) + + for i in range(len(col_names)): + att_val = None + if col_types[i] == "numeric": + att_val = int(full_info[i]) + elif col_types[i] == "string": + att_val = str(full_info[i]) + elif col_types[i] == "date": + att_val = datetime.strptime( + full_info[i], "%Y-%m-%d %H-%M-%S" + ) + else: + raise Exception( + "Invalid attribute type." + ) # Currently, the code supports only numeric, string and date types. Extend this as required. + + if att_val is None: + raise Exception("Invalid attribute value.") + else: + all_data[col_names[i]].append(att_val) + + line_count = line_count + 1 + + if line_count == 0: + raise Exception("Empty file.") + if len(col_names) == 0: + raise Exception("Missing attribute section.") + if not found_data_section: + raise Exception("Missing series information under data section.") + + all_data[value_column_name] = all_series + loaded_data = pd.DataFrame(all_data) + + return ( + loaded_data, + frequency, + forecast_horizon, + contain_missing_values, + contain_equal_length, + ) diff --git a/OATS/models/gen_model/ldm/lr_scheduler.py b/OATS/models/gen_model/ldm/lr_scheduler.py new file mode 100644 index 0000000..3eaf5ec --- /dev/null +++ b/OATS/models/gen_model/ldm/lr_scheduler.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. 
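+        # cum_cycles[i] is the global step at which cycle i starts, e.g.
+        # cycle_lengths=[10000, 20000] -> cum_cycles=[0, 10000, 30000];
+        # find_in_interval() maps a global step n to its cycle index, and
+        # schedule() then works with the offset of n inside that cycle.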
+ self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/OATS/models/gen_model/ldm/models/autoencoder.py b/OATS/models/gen_model/ldm/models/autoencoder.py new file mode 100644 index 0000000..4730f12 --- /dev/null +++ b/OATS/models/gen_model/ldm/models/autoencoder.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/OATS/models/gen_model/ldm/models/diffusion/ddim_time.py b/OATS/models/gen_model/ldm/models/diffusion/ddim_time.py new file mode 100644 index 0000000..d37904c --- /dev/null +++ b/OATS/models/gen_model/ldm/models/diffusion/ddim_time.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
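+
+# DDIM sampling (Song et al., 2021) in brief: sample along a short sub-sequence
+# of the trained DDPM's timesteps using
+#     pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
+#     x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * z
+# where eta scales sigma_t: eta = 0 gives the deterministic DDIM update and
+# eta = 1 recovers DDPM-like stochasticity. p_sample_ddim below implements
+# exactly this update.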
+ +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.modules.diffusionmodules.util import return_wrap + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev, ddim_coef = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_coef', ddim_coef) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + img_callback=None, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, W = shape + size = (batch_size, C, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(cond = conditioning, shape=size, + callback=callback, + img_callback=img_callback, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, callback=None, timesteps=None, + mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., + noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim(x = img, c=cond, mask=mask, t=ts, index=index, use_original_steps=ddim_use_original_steps, + temperature=temperature, noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, **kwargs) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c, **kwargs) + e_t = return_wrap(e_t, torch.full((b, 1, 1), self.ddim_coef[index], 
device=device)) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/OATS/models/gen_model/ldm/models/diffusion/ddpm_time.py b/OATS/models/gen_model/ldm/models/diffusion/ddpm_time.py new file mode 100644 index 0000000..f9be9c1 --- /dev/null +++ b/OATS/models/gen_model/ldm/models/diffusion/ddpm_time.py @@ -0,0 +1,951 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
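+
+# DDPM (Ho et al., 2020) in brief: the forward process has the closed form
+#     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps,  eps ~ N(0, I)
+# (see q_sample); training minimizes a (weighted) MSE between eps and the
+# network's noise prediction (see p_losses / get_loss), and sampling runs the
+# learned reverse chain from pure noise (see p_sample / p_sample_loop).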
+ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +# from pytorch_lightning.utilities.distributed import rank_zero_only +from lightning_utilities.core.rank_zero import rank_zero_only + +from ldm.util import exists, default, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim_time import DDIMSampler +from ldm.modules.diffusionmodules.util import return_wrap +import copy + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + seq_len=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.seq_len = seq_len # try conv? 
+ self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer("shift_coef", - to_torch(np.sqrt(alphas)) * (1. 
- self.alphas_cumprod_prev) / torch.sqrt(1. - self.alphas_cumprod)) + self.register_buffer("ddim_coef", -self.sqrt_one_minus_alphas_cumprod) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + self.load_epoch = sd['epoch'] + self.load_step = sd["global_step"] + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + eps_pred = return_wrap(model_out, extract_into_tensor(self.ddim_coef, t, x.shape)) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + if clip_denoised: + x_recon.clamp_(-1., 1.) 
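+        # q_posterior evaluates the exact posterior q(x_{t-1} | x_t, x_0), whose
+        # mean is posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t, at the
+        # model's current estimate x_recon of x_0.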
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        seq = torch.randn(shape, device=device)
+        intermediates = [seq]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            seq = self.p_sample(seq, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(seq)
+        if return_intermediates:
+            return seq, intermediates
+        return seq
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        seq_len = self.seq_len
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, seq_len),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        eps_pred = return_wrap(model_out, extract_into_tensor(self.shift_coef, t, x_start.shape))
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(eps_pred, target, mean=False).mean(dim=[1, 2])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 2:
+            x = x[..., None]
+        x = rearrange(x, 'b t c -> b c t')
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x
= self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # pass + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + cond_drop_prob = None, + class_condition=False, + *args, **kwargs): + self.class_condition = class_condition + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.cond_drop_prob = cond_drop_prob + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + # def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and 
self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + if hasattr(self.model.diffusion_model,"scale_factor"): + del self.scale_factor + self.register_buffer('scale_factor', self.model.diffusion_model.scale_factor) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING Pre-Trained STD-RESCALING ###") + else: + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c, return_mask=False): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + elif self.cond_stage_model is None: + c, mask = None, None + else: + c, mask = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + if return_mask: + return c, mask + return c + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_mask=False): + x = super().get_input(batch, k) + if bs is not None: 
+ x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c, mask = self.get_learned_conditioning(xc, return_mask=True) + else: + c, mask = self.get_learned_conditioning(xc.to(self.device), return_mask=True) + else: + c = xc + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + mask = None + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + if return_mask: + out.append(mask) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b t c -> b c t').contiguous() + + z = 1. / self.scale_factor * z + + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + kwargs['data_key'] = batch['data_key'].to(self.device) + loss = self(x, c, **kwargs) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c, return_mask=True) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, mask, cfg_scale=1, cond_drop_prob=None, + sampled_concept= None, sampled_index= None, sub_scale=None, data_key=None, **kwargs): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond, 'mask': mask} + + # if data_key is not None: + # print("class condition is ON") + # else: + # print("class condition is OFF") + + if cond_drop_prob is None: + x_recon = self.model.cfg_forward(x_noisy, t, cfg_scale=cfg_scale, sampled_concept=sampled_concept, sampled_index=sampled_index, sub_scale=sub_scale, data_key=data_key, **cond) + else: + x_recon = self.model.forward(x_noisy, t, cond_drop_prob=cond_drop_prob, + sampled_concept=sampled_concept, sampled_index=sampled_index, sub_scale=sub_scale, data_key=data_key, **cond) + + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def p_losses(self, x_start, condmask, t, noise=None, data_key=None): + # 
print("data_key:", data_key, data_key.shape) # (b,) + noise = default(noise, lambda: torch.randn_like(x_start)) + if condmask is not None: + cond, mask = condmask + else: + cond = None + mask = None + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + + if self.class_condition: + model_output = self.apply_model(x_noisy, t, cond, mask, cond_drop_prob=self.cond_drop_prob, data_key=data_key) + else: + model_output = self.apply_model(x_noisy, t, cond, mask, cond_drop_prob=self.cond_drop_prob) + + eps_pred = return_wrap(model_output, extract_into_tensor(self.shift_coef, t, x_start.shape)) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(eps_pred, target, mean=False).mean([1, 2]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.cpu()].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(eps_pred, target, mean=False).mean(dim=(1, 2)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f'{prefix}/epoch_num': self.current_epoch}) + loss_dict.update({f'{prefix}/step_num': self.global_step}) + + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, m, clip_denoised: bool, return_x0=False, + score_corrector=None, corrector_kwargs=None, **kwargs): + t_in = t + model_out = self.apply_model(x, t_in, c, m,**kwargs) + + eps_pred = return_wrap(model_out,extract_into_tensor(self.ddim_coef, t, x.shape)) + + if score_corrector is not None: + assert self.parameterization == "eps" + eps_pred = score_corrector.modify_score(self, eps_pred, x, t, c, **corrector_kwargs) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, m, clip_denoised=False, repeat_noise=False, + return_x0=False, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None,**kwargs): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, m=m, clip_denoised=clip_denoised, + return_x0=return_x0, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs,**kwargs) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + seq_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None,**kwargs): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + seq = torch.randn(shape, device=self.device) + else: + seq = x_T + inter_recons = [] + inter_seqs = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + seq, x0_partial = self.p_sample(seq, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs,**kwargs) + if mask is not None: + assert x0 is not None + seq_orig = self.q_sample(x0, ts) + seq = seq_orig * mask + (1. 
- mask) * seq + + if i % log_every_t == 0 or i == timesteps - 1: + inter_recons.append(x0_partial) + inter_seqs.append(seq) + if callback: callback(i) + if seq_callback: seq_callback(seq, i) + return seq, inter_seqs, inter_recons + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, seq_callback=None, start_T=None, + log_every_t=None,**kwargs): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + seq = torch.randn(shape, device=device) + else: + seq = x_T + + intermediates = [seq] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + + seq = self.p_sample(seq, cond, ts, mask, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised,**kwargs) + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(seq) + if callback: callback(i) + if seq_callback: seq_callback(seq, i) + + if return_intermediates: + return seq, intermediates + return seq + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.seq_len) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, shape, return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, mask=mask, x0=x0,**kwargs) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps=20,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.seq_len) + samples, intermediates =ddim_sampler.sample(S = ddim_steps,batch_size = batch_size, + shape = shape,conditioning = cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=8, sample=True, plot_reconstruction=False, + ddim_steps=20, ddim_eta=1., return_keys=None, **kwargs): + + use_ddim = ddim_steps is not None + # use_ddim = False + # plot_swapped_concepts = True + + log = dict() + z, c, x, xrec, xc, mask = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, return_mask=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x # batchsize, channel, window + if plot_reconstruction: + log["reconstruction"] = xrec + + if sample: + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, mask=mask, data_key=batch['data_key'][:N]) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + with self.ema_scope("Uncond Plotting"): + samples, z_denoise_row = 
self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, cfg_scale=0, mask=mask) + x_samples = self.decode_first_stage(samples) + log["uncond_samples"] = x_samples + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def parameters(self): + return self.diffusion_model.parameters() + + def forward(self, x, t, c_crossattn: list = None, cond_drop_prob = 0., mask=None, **kwargs): + + if (c_crossattn is not None) and (not None in c_crossattn): + cc = torch.cat(c_crossattn, 1) + else: + cc = None + out = self.diffusion_model(x, t, context=cc, mask=mask, cond_drop_prob=cond_drop_prob, **{"data_key": kwargs.get("data_key", None)}) + + return out + + def cfg_forward(self, x, t, c_crossattn: list = None, mask=None, **kwargs): + + if (c_crossattn is not None) and (not None in c_crossattn): + cc = torch.cat(c_crossattn, 1) + else: + cc = None + out = self.diffusion_model.forward_with_cfg(x, t, context=cc, mask=mask, **kwargs) + + return out + diff --git a/OATS/models/gen_model/ldm/modules/attention.py b/OATS/models/gen_model/ldm/modules/attention.py new file mode 100644 index 0000000..56aa7fe --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/attention.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
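+
+# Scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V, computed per head.
+# In CrossAttention below, queries always come from x; keys/values also come from x
+# when no context is given (self-attention) and from the conditioning sequence
+# otherwise (cross-attention).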
+ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., use_pam=False): + super().__init__() + self.use_pam = use_pam + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + if len(context.shape) == 2: + k = self.to_k(context)[:,None] + v = self.to_v(context)[:,None] + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + # print("q, k, v shapes:", q.shape, k.shape, v.shape) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            if self.use_pam:
+                # PAM path: positions with mask > 0 keep their additive bias; all
+                # other positions are pushed to -inf so softmax ignores them.
+                mask_of_mask = torch.where(mask > 0, torch.zeros_like(mask), torch.ones_like(mask))
+                max_neg_value = -torch.finfo(mask.dtype).max
+                mask = mask_of_mask * max_neg_value + mask
+            sim = sim + mask
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False, use_pam=False):
+        super().__init__()
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # a self-attention layer
+        # the feed-forward input is widened by 64 channels to make room for the
+        # optional class embedding concatenated in _forward below
+        self.ff = FeedForward(dim + 64, dim_out=dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+                                    heads=n_heads, dim_head=d_head, dropout=dropout, use_pam=use_pam)  # self-attn if context is None
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None, mask=None, emb_class=None):
+        return checkpoint(self._forward, (x, context, mask, emb_class), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None, mask=None, emb_class=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+
+        if emb_class is not None:
+            emb_class = emb_class[:, None, :]  # -> (B, 1, D); broadcasts across the sequence dim
+            emb_class_expanded = emb_class.expand(-1, x.shape[1], -1)
+            x_mod = torch.cat((self.norm3(x), emb_class_expanded), dim=-1)
+        else:
+            # no class condition: pad with zeros on the same device as x rather
+            # than a hard-coded "cuda" so the block also runs on CPU
+            x_mod = torch.cat((self.norm3(x), torch.zeros(x.shape[0], x.shape[1], 64, device=x.device)), dim=-1)
+        x = self.ff(x_mod) + x
+        return x
+
+
+class Spatial1DTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+ Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, use_pam=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_pam=use_pam) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None, mask=None, emb_class=None): + # note: if no context is given, cross-attention defaults to self-attention + # if emb_class is not None: + # print("I received emb_class:", emb_class.shape) + # else: + # print("emb_class is None, no class condition") + b, c, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c w -> b w c') + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask, emb_class=emb_class) + x = rearrange(x, 'b w c -> b c w', w=w) + x = self.proj_out(x) + return x + x_in diff --git a/OATS/models/gen_model/ldm/modules/diffusionmodules/ts_unet.py b/OATS/models/gen_model/ldm/modules/diffusionmodules/ts_unet.py new file mode 100644 index 0000000..5993599 --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/diffusionmodules/ts_unet.py @@ -0,0 +1,769 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import Spatial1DTransformer +from ldm.util import default +from .util import Return + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + +def prob_mask_like(shape, prob, device): + if prob == 1: + return th.ones(shape, device = device, dtype = th.bool) + elif prob == 0: + return th.zeros(shape, device = device, dtype = th.bool) + else: + return th.zeros(shape, device = device).float().uniform_(0, 1) < prob + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, mask=None, emb_class=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, Spatial1DTransformer): + x = layer(x, context, mask=mask, emb_class=emb_class) + else: + x = layer(x) + return x + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + cond_emb_channels=None, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + if cond_emb_channels is not None: + self.cond_emb_layers = nn.Sequential( + nn.SiLU(), + linear( + cond_emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + else: + self.cond_emb_layers = None + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + +class AttentionBlock(nn.Module): + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + seq_len, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + class_conditionnal_number=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + legacy=True, + repre_emb_channels=32, + latent_unit=6, + use_cfg=True, + cond_drop_prob=0.5, + use_pam=False + ): + super().__init__() + # if use_spatial_transformer: + # assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.seq_len = seq_len + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.use_cfg = use_cfg + self.cond_drop_prob = cond_drop_prob + self.latent_unit = latent_unit + self.latent_dim = repre_emb_channels + self.use_pam = use_pam + self.class_conditionnal_number= class_conditionnal_number + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + if self.class_conditionnal_number is not None: + self.class_embedding_dim = 64 + self.class_emb = nn.Embedding(class_conditionnal_number, self.class_embedding_dim) + + if self.use_cfg: + self.cond_emb_channels = repre_emb_channels * latent_unit if self.use_cfg else None + self.null_classes_emb = nn.Parameter(th.randn(1, repre_emb_channels)) + else: + self.cond_emb_channels = None + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in 
attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else Spatial1DTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_pam=self.use_pam + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else Spatial1DTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_pam=self.use_pam + ), + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else Spatial1DTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_pam=self.use_pam + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=out_ch, + dims=dims, + 
use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def _forward(self, x, timesteps=None, context=None, mask=None, y=None, cond_drop_prob=0, **kwargs):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn/adaln.
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        # Class conditioning is optional; only embed it when a data_key is given.
+        if kwargs.get("data_key") is not None:
+            emb_class = self.class_emb(kwargs["data_key"])
+        else:
+            emb_class = None
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        bs, device = x.shape[0], x.device
+        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
+        if context is not None:
+            c_num = context.shape[1]
+
+            if cond_drop_prob > 0:
+                keep_mask = prob_mask_like((bs, c_num, 1), 1 - cond_drop_prob, device=device)
+                null_classes_emb = repeat(self.null_classes_emb, '1 d -> b n d', b=bs, n=c_num)
+                context_emb = context * keep_mask + (~keep_mask) * null_classes_emb
+            else:
+                context_emb = context
+        else:
+            context_emb = None
+
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape == (x.shape[0],)
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context_emb, mask=mask, emb_class=emb_class)
+            hs.append(h)
+        h = self.middle_block(h, emb, context_emb, mask=mask, emb_class=emb_class)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context_emb, mask=mask, emb_class=emb_class)
+        h = h.type(x.dtype)
+        pred = self.out(h)
+        return Return(pred=pred)
+
+    def forward(self, x, timesteps=None, context=None, mask=None, y=None, cond_drop_prob=0, **kwargs):
+        out = self._forward(x, timesteps, context, mask, y, cond_drop_prob, **kwargs)
+        return out
+
+    def forward_with_cfg(self, x, timesteps=None, context=None, y=None, cfg_scale=1, **kwargs):
+        model_out = self._forward(x=x, timesteps=timesteps, context=context, y=y, cond_drop_prob=0., **kwargs)
+        if cfg_scale == 1:
+            return model_out
+
+        null_context_out = self._forward(x=x, timesteps=timesteps, context=context, y=y, cond_drop_prob=1., **kwargs)
+        cfg_grad = model_out.pred - null_context_out.pred
+        scaled_out = null_context_out.pred + cfg_scale * cfg_grad
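+        # Sampling-time sketch (illustrative call; the variable names are hypothetical):
+        #   out = model.forward_with_cfg(x_t, timesteps=t, context=ctx,
+        #                                cfg_scale=3.0, data_key=None)
+        # cfg_scale == 1 returns the conditional prediction unchanged (early
+        # return above); larger scales extrapolate away from the null context.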
+ + return Return(pred=scaled_out) diff --git a/OATS/models/gen_model/ldm/modules/diffusionmodules/util.py b/OATS/models/gen_model/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000..7611cc8 --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,287 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config +from typing import NamedTuple + +class Return(NamedTuple): + pred: torch.Tensor + +class Return_grad(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + +class Return_grad_full(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + sub_grad: torch.Tensor + +class Return_grad_cfg(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + sub_grad: torch.Tensor + null_pred: torch.Tensor + +def return_wrap(inp, coef): + if isinstance(inp, Return) or isinstance(inp, Return_grad_cfg): + return inp.pred + elif isinstance(inp, Return_grad) or isinstance(inp, Return_grad_full): + # return inp.out_grad + return inp.pred + coef * inp.out_grad + + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + ddim_coef = -np.sqrt(1. 
- alphas) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev, ddim_coef + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
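+
+    Example (illustrative):
+        emb = timestep_embedding(torch.arange(4), 8)   # -> (4, 8)
+        # columns 0..3 hold the cos terms, columns 4..7 the matching sin terms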
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/OATS/models/gen_model/ldm/modules/distributions/distributions.py b/OATS/models/gen_model/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000..10690be --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/distributions/distributions.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def kl_splits(self, latent_unit=6): + mean_splits = self.mean.chunk(latent_unit, dim=-1) + var_splits = self.var.chunk(latent_unit, dim=-1) + logvar_splits = self.logvar.chunk(latent_unit, dim=-1) + kl_loss = 0 + for mean, var, logvar in zip(mean_splits, var_splits, logvar_splits): + kl_split = 0.5 * torch.sum(torch.pow(mean, 2) + + var - 1.0 - logvar, + dim=-1) + kl_loss += torch.sum(kl_split) / kl_split.shape[0] + return kl_loss/latent_unit + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/OATS/models/gen_model/ldm/modules/ema.py b/OATS/models/gen_model/ldm/modules/ema.py new file mode 100644 index 0000000..7201b23 --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/ema.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
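+#
+# Usage sketch (illustrative, assuming a module `model` whose parameters are
+# being trained):
+#     ema = LitEma(model, decay=0.9999)
+#     ema(model)                        # after each optimizer step
+#     ema.store(model.parameters())     # before an EMA evaluation
+#     ema.copy_to(model)                # swap in the shadow weights
+#     ...                               # validate / save
+#     ema.restore(model.parameters())   # swap the live weights back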
+ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/OATS/models/gen_model/ldm/modules/encoders/modules.py b/OATS/models/gen_model/ldm/modules/encoders/modules.py new file mode 100644 index 0000000..e063db0 --- /dev/null +++ b/OATS/models/gen_model/ldm/modules/encoders/modules.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
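+#
+# Shape sketch (illustrative): DomainUnifiedEncoder maps a batch of shape
+# (B, num_channels, window) to a single latent of shape (B, 1, latent_dim);
+# DomainUnifiedPrototyper instead returns the shared prototype bank repeated
+# per sample, (B, num_latents, latent_dim), together with soft-assignment
+# logits of shape (B, num_latents).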
+
+import torch
+import torch.nn as nn
+from einops import repeat
+import copy
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+class ResBlockTime(nn.Module):
+    def __init__(self, in_channels, out_channels, mid_channels=None, bn=False):
+        super(ResBlockTime, self).__init__()
+
+        if mid_channels is None:
+            mid_channels = out_channels
+
+        layers = [
+            nn.ReLU(),
+            nn.Conv1d(in_channels, mid_channels,
+                      kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv1d(mid_channels, out_channels,
+                      kernel_size=1, stride=1, padding=0)
+        ]
+        if bn:
+            layers.insert(2, nn.BatchNorm1d(out_channels))
+        self.convs = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return x + self.convs(x)
+
+class View(nn.Module):
+    def __init__(self, size):
+        super(View, self).__init__()
+        self.size = size
+
+    def forward(self, tensor):
+        return tensor.view(self.size)
+
+class DomainUnifiedEncoder(nn.Module):
+    '''
+    The input is encoded into two parts: an invariant part and a specific part.
+    The specific part is generated by attending to a randomly initialized
+    latent vector pool. The two parts have equal length in this implementation.
+    '''
+    def __init__(self, dim, window, num_channels=3, latent_dim=32, bn=True, **kwargs):
+        super().__init__()
+        dim_out = latent_dim
+        flatten_dim = int(dim * window / 4)
+        self.in_encoder = nn.Sequential(
+            nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm1d(dim),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm1d(dim),
+            nn.ReLU(inplace=True)
+        )
+
+        self.out_encoder = nn.Sequential(
+            ResBlockTime(dim, dim, bn=bn),
+            nn.BatchNorm1d(dim),
+            nn.ReLU(inplace=True),
+            ResBlockTime(dim, dim, bn=bn),
+            View((-1, flatten_dim)),  # (batch_size, dim * window / 4)
+            nn.Linear(flatten_dim, dim_out)
+        )
+
+    def forward(self, x):
+        h = self.in_encoder(x)
+        mask = None
+
+        out = self.out_encoder(h)[:, None]  # (b, 1, d)
+        return out, mask
+
+class DomainUnifiedPrototyper(nn.Module):
+    '''
+    The input is encoded into two parts: an invariant part and a specific part.
+    The specific part is generated by attending to a randomly initialized
+    latent vector pool. The two parts have equal length in this implementation.
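+
+    Example (illustrative shapes, assuming num_latents=16 and latent_dim=32):
+        prototypes, logits = prototyper(x)   # x: (B, num_channels, window)
+        # prototypes: (B, 16, 32); logits: (B, 16)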
+ ''' + def __init__(self, dim, window, num_latents=16, num_channels=3, latent_dim=32, bn=True, **kwargs): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True) + ) + self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), requires_grad=False) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + latents = repeat(self.latents, 'n d -> b n d', b = b) + mask_logit = self.mask_ffn(h) + mask = mask_logit # soft assign + + out = latents # mask + return out, mask + diff --git a/OATS/models/gen_model/ldm/util.py b/OATS/models/gen_model/ldm/util.py new file mode 100644 index 0000000..f0e4f03 --- /dev/null +++ b/OATS/models/gen_model/ldm/util.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import importlib + +import torch + +from inspect import isfunction + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + diff --git a/OATS/models/gen_model/main_train.py b/OATS/models/gen_model/main_train.py new file mode 100644 index 0000000..e2db694 --- /dev/null +++ b/OATS/models/gen_model/main_train.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
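+#
+# Note: this script reads the DATA_ROOT environment variable at startup and
+# raises a KeyError if it is unset; export it first, e.g. (path is a
+# placeholder):
+#     export DATA_ROOT=/path/to/dataset_root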
+ +import os, sys +from pytorch_lightning.trainer import Trainer +from utils.cli_utils import get_parser +from utils.init_utils import init_model_data_trainer +from utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen + + +if __name__ == "__main__": + + data_root = os.environ['DATA_ROOT'] + + parser = get_parser() + parser.add_argument("--max_steps", type=int, default=50000) + parser.add_argument("--benchmark", type=bool, default=True) + # parser = Trainer.add_argparse_args(parser) + + model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser) + + # run + if opt.train: + try: + trainer.logger.experiment.config.update(opt) + trainer.fit(model, data) + except Exception: + melk() + raise + if not opt.no_test and not trainer.interrupted: + if opt.uncond: + test_model_uncond(model, data, trainer, opt, logdir) + else: + test_model_with_dp(model, data, trainer, opt, logdir) + test_model_unseen(model, data, trainer, opt, logdir) + diff --git a/OATS/models/gen_model/metrics/feature_distance_eval.py b/OATS/models/gen_model/metrics/feature_distance_eval.py new file mode 100644 index 0000000..47c595c --- /dev/null +++ b/OATS/models/gen_model/metrics/feature_distance_eval.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import numpy as np +import torch +from torch import nn +from scipy.stats import entropy +from sklearn.metrics.pairwise import rbf_kernel + +def cal_distances(gt, sp): + gt = gt[~np.isnan(gt)] + # gt = gt[gt= 0.).float() + + for i in range(x_fake.shape[2]): + for t in range(x_fake.shape[1]): + loc = self.locs[i][t].view(1, -1).to(x_fake.device) + x_ti = x_fake[:, t, i].contiguous( + ).view(-1, 1).repeat(1, loc.shape[1]) + dist = torch.abs(x_ti - loc) + left_counter = ((self.deltas[i][t].to(x_fake.device) / 2. - (loc - x_ti)) == 0.).float() + counter = (relu(self.deltas[i][t].to(x_fake.device) / 2. - dist) > 0.).float() + left_counter + density = counter.mean(0) / self.deltas[i][t].to(x_fake.device) + abs_metric = torch.abs( + density - self.densities[i][t].to(x_fake.device)) + loss.append(torch.mean(abs_metric, 0)) + loss_componentwise = torch.stack(loss) + return loss_componentwise + + +def get_mdd_eval(ori_data, generated_data, n_bins=20): + x_real = torch.Tensor(ori_data) + x_fake = torch.Tensor(generated_data) + mdd = (HistoLoss(x_real, n_bins=n_bins, name='marginal_distribution')(x_fake)).detach().cpu().numpy() + + return mdd diff --git a/OATS/models/gen_model/metrics/metrics_sets.py b/OATS/models/gen_model/metrics/metrics_sets.py new file mode 100644 index 0000000..9142618 --- /dev/null +++ b/OATS/models/gen_model/metrics/metrics_sets.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
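+#
+# Usage sketch (illustrative, hypothetical names): score generated windows
+# against the held-out 'solar' set at sequence length 24:
+#     results = run_metrics('solar', 24, 'my_model', gen_data, scale='zscore')
+#     # results[('my_model', 'solar', 24, 0)] -> {'mdd': ..., ...}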
+
+import os
+import numpy as np
+from utils.data_utils import test_data_loading
+from metrics.feature_distance_eval import get_mdd_eval, mmd_metric, get_flat_distance
+
+
+data_root = os.environ['DATA_ROOT']
+
+
+def calculate_one(gen_data, scaled_ori, model_name, repeat, data_name, seq_len, uni_data_sub, uni_data_div, n_samples):
+    this_metrics = {}
+    print(model_name, gen_data.shape)
+    scaled_gen = (gen_data - uni_data_sub) / uni_data_div
+    scaled_gen = scaled_gen[:n_samples, :, None]
+    this_metrics = update_metrics_dict(this_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat)
+    return this_metrics
+
+def update_metrics_dict(the_dict, key, data_name, seq_len, ori_data, gen_data, repeat_id=0):
+    if (key, data_name, seq_len, repeat_id) in the_dict:
+        print(f'{key} {data_name} {seq_len} {repeat_id} already in the dict, skip!')
+        return the_dict
+
+    mdd = get_mdd_eval(ori_data, gen_data)
+    the_dict[(key, data_name, seq_len, repeat_id)] = {
+        'mdd': mdd,
+    }
+    flat_sk_result = get_flat_distance(ori_data, gen_data)
+    the_dict[(key, data_name, seq_len, repeat_id)].update(flat_sk_result)
+    the_dict[(key, data_name, seq_len, repeat_id)].update(mmd_metric(ori_data, gen_data))
+    return the_dict
+
+def run_metrics(data_name, seq_len, model_name, gen_data, scale='zscore', exist_dict=None, repeat_id=0):
+    # Avoid a mutable default argument: results would otherwise accumulate
+    # across unrelated calls.
+    extend_metrics = exist_dict if exist_dict is not None else {}
+
+    uni_ori_data = test_data_loading(data_name, seq_len, stride=seq_len, univar=True)
+    uni_data_min, uni_data_max = np.min(uni_ori_data), np.max(uni_ori_data)
+    uni_data_mean, uni_data_std = np.mean(uni_ori_data), np.std(uni_ori_data)
+    if scale == 'minmax':
+        uni_data_sub, uni_data_div = uni_data_min, uni_data_max - uni_data_min + 1e-7
+    elif scale == 'zscore':
+        uni_data_sub, uni_data_div = uni_data_mean, uni_data_std + 1e-7
+    elif scale == 'raw':
+        uni_data_sub, uni_data_div = 0, 1
+    elif scale == 'robust_zscore':
+        median = np.median(uni_ori_data)
+        mad = np.median(np.abs(uni_ori_data - median))
+        uni_data_sub, uni_data_div = median, 1.4826 * mad + 1e-7
+    else:
+        raise ValueError(f"Unknown scale '{scale}'; expected 'minmax', 'zscore', 'raw' or 'robust_zscore'.")
+    uni_scaled_ori = (uni_ori_data - uni_data_sub) / uni_data_div
+    print(data_name, 'univar', uni_scaled_ori.shape)
+    scaled_ori = uni_scaled_ori
+    scaled_gen = (gen_data - uni_data_sub) / uni_data_div
+    extend_metrics = update_metrics_dict(extend_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat_id)
+    return extend_metrics
diff --git a/OATS/models/gen_model/utils/callback_utils.py b/OATS/models/gen_model/utils/callback_utils.py
new file mode 100644
index 0000000..a1e278c
--- /dev/null
+++ b/OATS/models/gen_model/utils/callback_utils.py
@@ -0,0 +1,337 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
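+#
+# These callbacks are assembled by prepare_trainer_configs(...) below, which
+# returns the logger and callback kwargs consumed by
+# pytorch_lightning.Trainer(**trainer_kwargs).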
+ +import os +import numpy as np +import time +import torch +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +import wandb +from pytorch_lightning.callbacks import Callback +# from pytorch_lightning.utilities.distributed import rank_zero_only +from lightning_utilities.core.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.util import instantiate_from_config +import matplotlib.pyplot as plt + +class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + +def plot_naming(k, bi, ni): # the ni-th row, bi-th column + name = '' + if k == 'diffusion_row': + name = f'sample{bi} diffstep {ni*200}' + return name + + +class TSLogger(Callback): + def __init__(self, batch_frequency, max_images, clamp=False, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + + @rank_zero_only + def log_local(self, save_dir, split, images, + global_step, current_epoch, batch_idx, + key_list, dm, logger=None): + root = os.path.join(save_dir, "images", split) + image_dict = {} + for k in images: # assume inverse normalization has been applied + grid = images[k] + + grid = grid.numpy() # shape: num_samples, channels, window + if len(grid.shape) == 3: + b, c, w = grid.shape # batchsize, channels, window + for i in range(b): + grid[i] = dm.inverse_transform(grid[i], data_name=key_list[i]) 
+ fig, axs = plt.subplots(c, b, figsize=(b * 4, c * 4)) # c rows, b columns + for bi in range(b): # transposed plotting + if c == 1: # typically 1 x 8 + axs[bi].plot(grid[bi, 0]) + else: + for ci in range(c): + axs[ci, bi].plot(grid[bi, ci]) + elif len(grid.shape) == 4: # compare across rows, so batchsize as num of columns + n, b, c, w = grid.shape + for i in range(b): + grid[:,i] = dm.inverse_transform(grid[:,i], data_name=key_list[i]) + fig, axs = plt.subplots(n, b, figsize=(b * 4, n * 4)) # n rows, b columns + for bi in range(b): + if n == 1: + for ci in range(c): + axs[bi].plot(grid[0, bi, ci]) + axs[bi].set_title(plot_naming(k, bi, n)) + else: + for ni in range(n): + for ci in range(c): + axs[ni, bi].plot(grid[ni, bi, ci]) + axs[ni, bi].set_title(plot_naming(k, bi, ni)) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( + k, + global_step, + current_epoch, + batch_idx) + plt.suptitle(filename) + image_dict[k] = wandb.Image(fig) + plt.close() + logger.experiment.log(image_dict, step=global_step) + + def log_img(self, pl_module, batch, batch_idx, split="train", n_row=8): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + # self.log_steps = [1000, check_idx] + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, n_row=n_row, split=split, **self.log_images_kwargs) + key_list = pl_module.trainer.datamodule.key_list + batch_key_list = [] + for i in range(n_row): + batch_key_list.append(key_list[batch['data_key'][i].detach().cpu().numpy()]) + + for k in images: + if k != "samples_swapping" and k != "samples_swapping_partial": # TODO: change to swapping intercept + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) # should clamp to [0,1]? or modify data loader to [-1,1] + else: + images[k] = torch.clamp(images[k], -2., 2.) 
+ + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx, + batch_key_list, pl_module.trainer.datamodule,logger=pl_module.logger) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + +def prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, model, now, cfgdir, config, trainer_opt): + trainer_kwargs = dict() + + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": f"{nowname}_{now}", + "save_dir": logdir, + "offline": opt.debug, + "id": f"{nowname}_{now}", + "project": "TimeDP", + } + } + } + default_logger_cfg = default_logger_cfgs["wandb"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = OmegaConf.create() + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}-{val/loss_simple_ema:.4f}", + "verbose": True, + "save_last": True, + "auto_insert_metric_name": False + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + default_modelckpt_cfg["params"]["mode"] = "min" + if default_modelckpt_cfg["params"]["monitor"] == "train/step_num": + 
default_modelckpt_cfg["params"]["every_n_train_steps"] = 2000 + default_modelckpt_cfg["params"]["every_n_epochs"] = None + default_modelckpt_cfg["params"]["filename"] = "{step:09}" + default_modelckpt_cfg["params"]["mode"] = "max" + + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "utils.callback_utils.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "learning_rate_logger": { + "target": "pytorch_lightning.callbacks.LearningRateMonitor", + "params": { + "logging_interval": "step", + } + }, + # "cuda_callback": { + # "target": "utils.callback_utils.CUDACallback" + # }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': + {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + return trainer_kwargs diff --git a/OATS/models/gen_model/utils/cli_utils.py b/OATS/models/gen_model/utils/cli_utils.py new file mode 100644 index 0000000..a4c5d62 --- /dev/null +++ b/OATS/models/gen_model/utils/cli_utils.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +from pytorch_lightning import Trainer + +def get_parser(**parser_kwargs): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument("-n","--name",type=str,const=True,default="",nargs="?",help="postfix for logdir") + parser.add_argument("-b","--base",nargs="*",metavar="base_config.yaml",help="paths to base configs. 
Loaded from left-to-right.", default=list(),)
+    parser.add_argument("-t","--train",type=str2bool,const=True,default=True,nargs="?",help="train",)
+    parser.add_argument("-r","--resume",type=str2bool,const=True,default=False,nargs="?",help="resume and test",)
+    parser.add_argument("--no-test",type=str2bool,const=True,default=False,nargs="?",help="disable test",)
+    parser.add_argument("-d","--debug",type=str2bool,nargs="?",const=True,default=False,help="debug mode",)
+    parser.add_argument("-s","--seed",type=int,default=23,help="seed for seed_everything",)
+    parser.add_argument("-f","--postfix",type=str,default="",help="post-postfix for default name",)
+    parser.add_argument("-l","--logdir",type=str,default="./logs",help="directory for log files",)
+    parser.add_argument("--scale_lr",type=str2bool,nargs="?",const=True,default=False,help="scale base-lr by ngpu * batch_size * n_accumulate",)
+    parser.add_argument("--ckpt_name",type=str,default="last",help="ckpt name to resume",)
+    parser.add_argument("-sl","--seq_len", type=int, const=True, default=24,nargs="?", help="sequence length")
+    parser.add_argument("-uc","--uncond", action='store_true', help="unconditional generation")
+    parser.add_argument("-up","--use_pam", action='store_true', help="use prototype (PAM) conditioning")
+    parser.add_argument("-bs","--batch_size", type=int, const=True, default=128,nargs="?", help="batch_size")
+    parser.add_argument("-nl","--num_latents", type=int, const=True, default=16,nargs="?", help="number of prototypes")
+    parser.add_argument("-lr","--overwrite_learning_rate", type=float, const=True, default=None, nargs="?", help="learning rate")
+
+    return parser
+
+def nondefault_trainer_args(opt):
+    # Returns the names of trainer args that differ from the reference parser's
+    # defaults. With Trainer.add_argparse_args commented out, the reference
+    # parser is empty, so this currently returns an empty list.
+    parser = argparse.ArgumentParser()
+    # parser = Trainer.add_argparse_args(parser)
+    args = parser.parse_args([])
+    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
diff --git a/OATS/models/gen_model/utils/data_utils.py b/OATS/models/gen_model/utils/data_utils.py
new file mode 100644
index 0000000..ca90179
--- /dev/null
+++ b/OATS/models/gen_model/utils/data_utils.py
@@ -0,0 +1,184 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pandas as pd
+import numpy as np
+from einops import rearrange
+from distutils.util import strtobool
+from datetime import datetime
+from pathlib import Path
+import os
+
+prefix = ''
+if 'DATA_ROOT' in os.environ and os.path.exists(os.environ['DATA_ROOT']):
+    prefix = Path(os.environ['DATA_ROOT'])
+else:
+    print("DATA_ROOT is not defined or does not exist!")
+
+test_data_map = {
+    'solar': 'solar_{seq_len}_val.npy',
+    'electricity': 'electricity_{seq_len}_val.npy',
+    'traffic': 'traffic_{seq_len}_val.npy',
+    'kddcup': 'kddcup_{seq_len}_val.npy',
+    'taxi': 'taxi_{seq_len}_val.npy',
+    'exchange': 'exchange_{seq_len}_val.npy',
+    'fred_md': 'fred_md_{seq_len}_val.npy',
+    'nn5': 'nn5_{seq_len}_val.npy',
+    'web': 'web_{seq_len}_test_sample.npy',
+    'stock': 'stock_{seq_len}_test_sample.npy',
+    'temp': 'temp_{seq_len}_val.npy',
+    'rain': 'rain_{seq_len}_val.npy',
+    'pedestrian': 'pedestrian_{seq_len}_val.npy',
+    'wind_4_seconds': 'wind_4_seconds_{seq_len}_val.npy'
+}
+
+def test_data_loading(data_name, seq_len, stride=1, univar=False):
+    data_path = prefix / test_data_map[data_name].format(seq_len=seq_len, stride=stride)
+    ori_data = np.load(data_path)  # (n, t, c)
+    if univar:
+        ori_data = rearrange(ori_data, 'n t c -> (n c) t 1')
+    return ori_data
+
+def convert_tsf_to_dataframe(
+    full_file_path_and_name,
+    replace_missing_vals_with="NaN",
+    value_column_name="series_value",
+):
+    col_names = []
+    col_types = []
+    all_data = {}
+    line_count = 0
+    frequency = None
+    forecast_horizon = None
+    contain_missing_values = None
+    contain_equal_length = None
+    found_data_tag = False
+    found_data_section = False
+    started_reading_data_section = False
+
+    with open(full_file_path_and_name, "r", encoding="cp1252") as file:
+        for line in file:
+            # Strip white space from start/end of line
+            line = line.strip()
+
+            if line:
+                if line.startswith("@"):  # Read meta-data
+                    if not line.startswith("@data"):
+                        line_content = line.split(" ")
+                        if line.startswith("@attribute"):
+                            if (
+                                len(line_content) != 3
+                            ):  # Attributes have both name and type
+                                raise Exception("Invalid meta-data specification.")
+
+                            col_names.append(line_content[1])
+                            col_types.append(line_content[2])
+                        else:
+                            if (
+                                len(line_content) != 2
+                            ):  # Other meta-data have only values
+                                raise Exception("Invalid meta-data specification.")
+
+                            if line.startswith("@frequency"):
+                                frequency = line_content[1]
+                            elif line.startswith("@horizon"):
+                                forecast_horizon = int(line_content[1])
+                            elif line.startswith("@missing"):
+                                contain_missing_values = bool(
+                                    strtobool(line_content[1])
+                                )
+                            elif line.startswith("@equallength"):
+                                contain_equal_length = bool(strtobool(line_content[1]))
+
+                    else:
+                        if len(col_names) == 0:
+                            raise Exception(
+                                "Missing attribute section. Attribute section must come before data."
+                            )
+
+                        found_data_tag = True
+                elif not line.startswith("#"):
+                    if len(col_names) == 0:
+                        raise Exception(
+                            "Missing attribute section. Attribute section must come before data."
+                        )
+                    elif not found_data_tag:
+                        raise Exception("Missing @data tag.")
+                    else:
+                        if not started_reading_data_section:
+                            started_reading_data_section = True
+                            found_data_section = True
+                            all_series = []
+
+                            for col in col_names:
+                                all_data[col] = []
+
+                        full_info = line.split(":")
+
+                        if len(full_info) != (len(col_names) + 1):
+                            raise Exception("Missing attributes/values in series.")
+
+                        series = full_info[len(full_info) - 1]
+                        series = series.split(",")
+
+                        if len(series) == 0:
+                            raise Exception(
+                                "A given series should contain a set of comma-separated numeric values. 
At least one numeric value must be present in each series. Missing values should be indicated with the ? symbol."
+                            )
+
+                        numeric_series = []
+
+                        for val in series:
+                            if val == "?":
+                                numeric_series.append(replace_missing_vals_with)
+                            else:
+                                numeric_series.append(float(val))
+
+                        if numeric_series.count(replace_missing_vals_with) == len(
+                            numeric_series
+                        ):
+                            raise Exception(
+                                "All series values are missing. A given series should contain a set of comma-separated numeric values. At least one numeric value must be present in each series."
+                            )
+
+                        all_series.append(pd.Series(numeric_series).array)
+
+                        for i in range(len(col_names)):
+                            att_val = None
+                            if col_types[i] == "numeric":
+                                att_val = int(full_info[i])
+                            elif col_types[i] == "string":
+                                att_val = str(full_info[i])
+                            elif col_types[i] == "date":
+                                att_val = datetime.strptime(
+                                    full_info[i], "%Y-%m-%d %H-%M-%S"
+                                )
+                            else:
+                                raise Exception(
+                                    "Invalid attribute type."
+                                )  # Currently, the code supports only numeric, string and date types. Extend this as required.
+
+                            if att_val is None:
+                                raise Exception("Invalid attribute value.")
+                            else:
+                                all_data[col_names[i]].append(att_val)
+
+                line_count = line_count + 1
+
+    if line_count == 0:
+        raise Exception("Empty file.")
+    if len(col_names) == 0:
+        raise Exception("Missing attribute section.")
+    if not found_data_section:
+        raise Exception("Missing series information under data section.")
+
+    all_data[value_column_name] = all_series
+    loaded_data = pd.DataFrame(all_data)
+
+    return (
+        loaded_data,
+        frequency,
+        forecast_horizon,
+        contain_missing_values,
+        contain_equal_length,
+    )
diff --git a/OATS/models/gen_model/utils/init_utils.py b/OATS/models/gen_model/utils/init_utils.py
new file mode 100644
index 0000000..2b0bbc5
--- /dev/null
+++ b/OATS/models/gen_model/utils/init_utils.py
@@ -0,0 +1,253 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import sys
+import argparse
+from pytorch_lightning import Trainer
+from omegaconf import OmegaConf
+from utils.cli_utils import nondefault_trainer_args
+from utils.callback_utils import prepare_trainer_configs
+from pytorch_lightning import seed_everything
+from ldm.util import instantiate_from_config
+from pathlib import Path
+import datetime
+import time
+
+data_root = os.environ['DATA_ROOT']
+
+def init_model_data_trainer(parser):
+
+    opt, unknown = parser.parse_known_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+    # cfg_name is also needed below (for logdir), even when a custom --name is given
+    cfg_name = ""
+    if opt.base:
+        cfg_fname = os.path.split(opt.base[0])[-1]
+        cfg_name = os.path.splitext(cfg_fname)[0]
+    if opt.name:
+        name = opt.name
+    elif opt.base:
+        name = cfg_name
+    else:
+        name = ""
+
+    seed_everything(opt.seed)
+
+    # init and save configs
+    configs = [OmegaConf.load(cfg) for cfg in opt.base]
+    cli = OmegaConf.from_dotlist(unknown)
+    config = OmegaConf.merge(*configs, cli)
+
+    # Customize config from opt:
+    config.model['params']['seq_len'] = opt.seq_len
+    config.model['params']['unet_config']['params']['seq_len'] = opt.seq_len
+    config.data['params']['window'] = opt.seq_len
+    config.data['params']['batch_size'] = opt.batch_size
+    bs = opt.batch_size
+    if getattr(opt, 'max_steps', None):  # --max_steps is not defined by the default parser in cli_utils
+        config.lightning['trainer']['max_steps'] = opt.max_steps
+    if opt.debug:
+        config.lightning['trainer']['max_steps'] = 10
+        config.lightning['callbacks']['image_logger']['params']['batch_frequency'] = 5
+    if opt.overwrite_learning_rate is not None:
+        config.model['base_learning_rate'] = opt.overwrite_learning_rate
+        print(f"Setting learning rate (overwriting config file) to {opt.overwrite_learning_rate:.2e}")
+        base_lr = opt.overwrite_learning_rate
+    else:
+        base_lr = config.model['base_learning_rate']
+
+    nowname = f"{name.split('-')[-1]}_{opt.seq_len}_nl_{opt.num_latents}_lr{base_lr:.1e}_bs{opt.batch_size}_time{time.time()}"
+
+    if opt.uncond:
+        config.model['params']['cond_stage_config'] = "__is_unconditional__"
+        config.model['params']['cond_stage_trainable'] = False
+        config.model['params']['unet_config']['params']['context_dim'] = None
+        nowname += "_uncond"
+    else:
+        config.model['params']['cond_stage_config']['params']['window'] = opt.seq_len
+
+    if opt.use_pam:
+        config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedPrototyper"
+        config.model['params']['cond_stage_config']['params']['num_latents'] = opt.num_latents
+        config.model['params']['unet_config']['params']['latent_unit'] = opt.num_latents
+        config.model['params']['unet_config']['params']['use_pam'] = True
+        nowname += "_pam"
+    else:
+        config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedEncoder"
+        config.model['params']['unet_config']['params']['use_pam'] = False
+
+    nowname += f"_seed{opt.seed}"
+    logdir = os.path.join(opt.logdir, cfg_name, nowname)
+    ckptdir = os.path.join(logdir, "checkpoints")
+    cfgdir = os.path.join(logdir, "configs")
+
+    metrics_dir = Path(logdir) / 'metric_dict.pkl'
+    if metrics_dir.exists():
+        print(f"Metric exists! Skipping {nowname}")
+        sys.exit(0)
+    lightning_config = config.pop("lightning", OmegaConf.create())
+    # merge trainer cli with config
+    trainer_config = lightning_config.get("trainer", OmegaConf.create())
+    # default to GPU training
+    trainer_config["accelerator"] = "gpu"
+    for k in nondefault_trainer_args(opt):
+        trainer_config[k] = getattr(opt, k)
+    if "gpus" not in trainer_config:
+        del trainer_config["accelerator"]
+        cpu = True
+    else:
+        gpuinfo = trainer_config["gpus"]
+        print(f"Running on GPUs {gpuinfo}")
+        cpu = False
+    trainer_opt = argparse.Namespace(**trainer_config)
+    lightning_config.trainer = trainer_config
+
+    # model
+    if opt.resume:
+        # logdir is a plain str here, so build the path with os.path.join
+        ckpt_path = os.path.join(logdir, 'checkpoints', 'last.ckpt')
+        config.model['params']['ckpt_path'] = ckpt_path
+    model = instantiate_from_config(config.model)
+
+    # trainer and callbacks
+    trainer_kwargs = prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, model, now, cfgdir, config, trainer_opt)
+    # trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+    trainer = Trainer(benchmark=trainer_opt.benchmark, max_steps=trainer_opt.max_steps, **trainer_kwargs)
+    trainer.logdir = logdir
+
+    # data
+    for k, v in config.data.params.data_path_dict.items():
+        config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', str(opt.seq_len))
+    data = instantiate_from_config(config.data)
+    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+    # calling these ourselves should not be necessary but it is.
+    # lightning still takes care of proper multiprocessing though
+    data.prepare_data()
+    data.setup("fit")
+    assert config.data.params.input_channels == 1, \
+        "Assertion failed: Only univariate input is supported. Please ensure input_channels == 1."
+    print("#### Data Preparation Finished #####")
+    if not cpu:
+        ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+    else:
+        ngpu = 1
+    if 'accumulate_grad_batches' in lightning_config.trainer:
+        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
+    else:
+        accumulate_grad_batches = 1
+    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+    if opt.scale_lr:
+        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+        print(
+            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
+    else:
+        model.learning_rate = base_lr
+        print("++++ NOT USING LR SCALING ++++")
+        print(f"Setting learning rate to {model.learning_rate:.2e}")
+
+    # allow checkpointing via USR1
+    def melk(*args, **kwargs):
+        # run all checkpoint hooks
+        if trainer.global_rank == 0:
+            print("Summoning checkpoint.")
+            ckpt_path = os.path.join(ckptdir, "last.ckpt")
+            trainer.save_checkpoint(ckpt_path)
+
+    def divein(*args, **kwargs):
+        if trainer.global_rank == 0:
+            import pudb  # type: ignore
+            pudb.set_trace()
+
+    import signal
+
+    signal.signal(signal.SIGUSR1, melk)
+    signal.signal(signal.SIGUSR2, divein)
+
+    return model, data, trainer, opt, logdir, melk
+
+
+def load_model_data(parser):
+
+    opt, unknown = parser.parse_known_args()
+
+    # cfg_name is also needed below (for logdir), even when a custom --name is given
+    cfg_name = ""
+    if opt.base:
+        cfg_fname = os.path.split(opt.base[0])[-1]
+        cfg_name = os.path.splitext(cfg_fname)[0]
+    if opt.name:
+        name = opt.name
+    elif opt.base:
+        name = cfg_name
+    else:
+        name = ""
+
+    seed_everything(opt.seed)
+
+    # init and save configs
+    configs = [OmegaConf.load(cfg) for cfg in opt.base]
+    cli = OmegaConf.from_dotlist(unknown)
+    config = OmegaConf.merge(*configs, cli)
+
+    # Customize config from opt:
+    config.model['params']['seq_len'] = opt.seq_len
+    config.model['params']['unet_config']['params']['seq_len'] = opt.seq_len
+    config.data['params']['window'] = opt.seq_len
+    config.data['params']['batch_size'] = opt.batch_size
+    bs = opt.batch_size
+    if getattr(opt, 'max_steps', None):  # --max_steps is not defined by the default parser in cli_utils
+        config.lightning['trainer']['max_steps'] = opt.max_steps
+    if opt.debug:
+        config.lightning['trainer']['max_steps'] = 10
+        config.lightning['callbacks']['image_logger']['params']['batch_frequency'] = 5
+    if opt.overwrite_learning_rate is not None:
+        config.model['base_learning_rate'] = opt.overwrite_learning_rate
+        print(f"Setting learning rate (overwriting config file) to {opt.overwrite_learning_rate:.2e}")
+        base_lr = opt.overwrite_learning_rate
+    else:
+        base_lr = config.model['base_learning_rate']
+
+    nowname = f"{name.split('-')[-1]}_{opt.seq_len}_nl_{opt.num_latents}_lr{base_lr:.1e}_bs{opt.batch_size}"
+
+    if opt.uncond:
+        config.model['params']['cond_stage_config'] = "__is_unconditional__"
+        config.model['params']['cond_stage_trainable'] = False
+        config.model['params']['unet_config']['params']['context_dim'] = None
+        nowname += "_uncond"
+    else:
+        config.model['params']['cond_stage_config']['params']['window'] = opt.seq_len
+
+    if opt.use_pam:
+        config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedPrototyper"
+        config.model['params']['cond_stage_config']['params']['num_latents'] = opt.num_latents
+        config.model['params']['unet_config']['params']['latent_unit'] = opt.num_latents
+        config.model['params']['unet_config']['params']['use_pam'] = True
+        nowname += "_pam"
+    else:
config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedEncoder"
+        config.model['params']['unet_config']['params']['use_pam'] = False
+
+    nowname += f"_seed{opt.seed}"
+    logdir = os.path.join(opt.logdir, cfg_name, nowname)
+
+    # model
+    ckpt_name = opt.ckpt_name
+    # logdir is a plain str here, so build the path with os.path.join
+    ckpt_path = os.path.join(logdir, 'checkpoints', f'{ckpt_name}.ckpt')
+    config.model['params']['ckpt_path'] = ckpt_path
+    model = instantiate_from_config(config.model)
+
+    # data
+    for k, v in config.data.params.data_path_dict.items():
+        config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', str(opt.seq_len))
+    data = instantiate_from_config(config.data)
+    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+    # calling these ourselves should not be necessary but it is.
+    # lightning still takes care of proper multiprocessing though
+    data.prepare_data()
+    data.setup()
+    print("#### Data Preparation Finished #####")
+
+    return model, data, opt, logdir
diff --git a/OATS/models/gen_model/utils/pkl_utils.py b/OATS/models/gen_model/utils/pkl_utils.py
new file mode 100644
index 0000000..d7f4c7e
--- /dev/null
+++ b/OATS/models/gen_model/utils/pkl_utils.py
@@ -0,0 +1,18 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pickle
+from pathlib import Path
+
+
+def load_pkl(path: Path):
+    """Load pkl from path."""
+    with open(path, "rb") as infile:
+        data = pickle.load(infile)
+    return data
+
+
+def save_pkl(data: object, path: Path):
+    """Save pkl to path."""
+    with open(path, "wb") as outfile:
+        pickle.dump(data, outfile)
diff --git a/OATS/models/gen_model/utils/prepare_datasets.py b/OATS/models/gen_model/utils/prepare_datasets.py
new file mode 100644
index 0000000..52c259e
--- /dev/null
+++ b/OATS/models/gen_model/utils/prepare_datasets.py
@@ -0,0 +1,289 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+ +from gluonts.evaluation.backtest import make_evaluation_predictions +from gluonts.dataset.multivariate_grouper import MultivariateGrouper +import numpy as np +import pandas as pd +from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset +import requests +import zipfile +import os +import pandas as pd +import numpy as np +from einops import rearrange +from distutils.util import strtobool +from datetime import datetime +import os +import sys +from pathlib import Path +sys.path.append('..') +from utils.data_utils import convert_tsf_to_dataframe + +PREFIX = './data/' +def download_monash_dataset(): + url_map = { + 'temprain': 'https://zenodo.org/records/5129091/files/temperature_rain_dataset_without_missing_values.zip?download=1', + 'wind_4_seconds': 'https://zenodo.org/records/4656032/files/wind_4_seconds_dataset.zip?download=1', + 'pedestrian': 'https://zenodo.org/records/4656626/files/pedestrian_counts_dataset.zip?download=1' + } + for dataset_name in ['temprain', 'wind_4_seconds', 'pedestrian']: + # download the zip file + url = url_map[dataset_name] + zip_path = f"{dataset_name}.zip" + + print("Downloading dataset...") + response = requests.get(url, stream=True) + with open(zip_path, "wb") as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + file.write(chunk) + print("Download complete.") + + # Unzip the ZIP file + extract_path = "./data" + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_path) + print(f"Files extracted to {extract_path}") + + # Deleting the ZIP file + os.remove(zip_path) + print("ZIP file deleted.") + +def download_timegan_stock_dataset(): + url = 'https://raw.githubusercontent.com/jsyoon0823/TimeGAN/refs/heads/master/data/stock_data.csv' + + response = requests.get(url, stream=True) + with open('stock_data.csv', "wb") as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + file.write(chunk) + print("Download complete.") + +def create_dataset_csv(dataset_simple,): + print('--------------------------------------------------') + print(f'create {dataset_simple} dataset csv') + dataset_alias = { + 'solar':'solar_nips', + 'electricity': 'electricity', + 'traffic':'traffic_nips', + 'kddcup':'kdd_cup_2018_without_missing', + 'taxi':'taxi_30min', + 'exchange':'exchange_rate_nips', + 'fred_md':'fred_md', + 'nn5': 'nn5_daily_without_missing', + 'web': 'kaggle_web_traffic_without_missing' + } + + dataset_name = dataset_alias[dataset_simple] + dataset = get_dataset(dataset_name, regenerate=False) + metadata, train_data, test_data = dataset.metadata, dataset.train, dataset.test + print("metadata", metadata) + train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality))) + test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), + max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality))) + + print("prepare the dataset") + print(f'len(train_data): {len(train_data)}, len(test_data): {len(test_data)}') + + # group the dataset + train_data=train_grouper(dataset.train) + if dataset_simple == 'kddcup': # contains some special cases + test_data = list(test_data) + for i in range(len(test_data)): + if len(test_data[i]['target']) == 10898: + test_data[i]['target'] = np.concatenate((test_data[i]['target'], np.zeros(8))) + test_data = test_grouper(test_data) + else: + test_data=test_grouper(dataset.test) + + # merge the train and test data + train_data = list(train_data) + test_data = 
list(test_data) + print(f'train_data.shape: {train_data[0]["target"].shape}') + print(f'test_data.shape: {test_data[-1]["target"].shape}') + train_data_T = np.array(train_data[0]['target']).T + test_data_T = np.array(test_data[-1]['target']).T + print(f'train_data_T.shape: {train_data_T.shape}') + print(f'test_data_T.shape: {test_data_T.shape}') + + print(f'train_data_T[-1][:10]: {train_data_T[-1][:10]}') + print(f'test_data_T[-1][:10]: {test_data_T[-1][:10]}') + + + prediction_length = metadata.prediction_length + test_length = len(test_data)*prediction_length + if dataset_simple =='taxi': + # no train overlap + test_data_T_unic = test_data_T[-test_length-prediction_length:] + else: + # train overlap + test_data_T_unic = test_data_T[-test_length:] + + print((train_data_T[-1][-10:])) + print((test_data_T[-len(test_data)*prediction_length-1][-10:])) + + data_all = np.concatenate((train_data_T, test_data_T_unic), axis=0) + print(f'data_all.shape: {data_all.shape}') + + # generate dataframe + metadata = dataset.metadata + print("metadata", metadata) + freq = metadata.freq + + start = pd.Timestamp("2012-01-01 00:00:00") # Assume starting at 2012-01-01 00:00:00 + index = pd.date_range(start=start, freq=freq, periods=len(data_all)) # generate time series, interval is freq, length is len(data_all) + df = pd.DataFrame(data_all, index=index, columns=range(data_all.shape[1])) # create a dataframe, index is time series, columns is 0,1,2,3,4,5,6,7,8,9 + df.index.name = 'date' + print(f'df.shape: {df.shape}') + + test_len = len(test_data)*prediction_length + valid_len = min(7* prediction_length, test_len) + train_len = len(df) - test_len - valid_len + + if dataset_simple == 'taxi': + train_len = len(df) - test_len - valid_len - prediction_length # exclude extra test data + + print("train_len", train_len) + print("valid_len", valid_len) + print("test_len", test_len) + print("prediction_length", prediction_length) + + df.to_csv(f'./data/{dataset_simple}.csv', index=False) + print(f'./data/{dataset_simple}.csv saved') + +def more_data_loading(data_name, seq_len=168, stride=1, univar=False): + if data_name in data_name_path_map: + data_path = PREFIX + data_name_path_map[data_name] + if data_name in ['stock']: + ori_data = np.loadtxt(data_path, delimiter = ",",skiprows = 1) + else: # no index column + ori_data = pd.read_csv(data_path).values + elif data_name in monash_map: + data_path = PREFIX + monash_map[data_name] + loaded_data, *_ = convert_tsf_to_dataframe(data_path) + ori_data = np.stack(loaded_data['series_value'].values).T + # return ori_data + temp_data = [] + # Cut data by sequence length + for i in range(0, len(ori_data) - seq_len + 1, stride): # we do some slicing here + _x = ori_data[None, i:i + seq_len] + temp_data.append(_x) + + data = np.vstack(temp_data) + if univar: + data = rearrange(data, 'n t c -> (n c) t 1') + + return data # , ori_data + +monash_map = { + 'temprain': './data/temperature_rain_dataset_without_missing_values.tsf', + 'wind_4_seconds': './data/wind_4_seconds_dataset.tsf', + 'pedestrian': './data/pedestrian_counts_dataset.tsf' +} +data_name_path_map = { + 'stock': 'stock_data.csv', + 'solar': './data/solar.csv', + 'electricity': './data/electricity.csv', + 'traffic': './data/traffic.csv', + 'kddcup': './data/kddcup.csv', + 'taxi': './data/taxi.csv', + 'exchange': './data/exchange.csv', + 'fred_md': './data/fred_md.csv', + 'nn5': './data/nn5.csv', + 'web': './data/web.csv', + 'temp': './data/temp.csv', + 'rain': './data/rain.csv', + 'pedestrian': './data/pedestrian.csv' +} + 
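+# Illustrative usage of more_data_loading (hypothetical values; assumes the
+# CSVs above have been generated): non-overlapping 24-step windows of the
+# solar dataset, flattened to univariate series:
+#   windows = more_data_loading('solar', seq_len=24, stride=24, univar=True)
+#   print(windows.shape)  # -> (num_windows * num_channels, 24, 1)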
+
+if __name__ == '__main__':
+    download_monash_dataset()
+    download_timegan_stock_dataset()
+    transformed_dataset = ['solar','electricity','traffic','kddcup','taxi','exchange','fred_md','nn5','web']
+    for dataset in transformed_dataset:
+        try:
+            create_dataset_csv(dataset)
+        except Exception as e:
+            print(f'{dataset} failed: {e}')
+
+    data_path = PREFIX + monash_map['temprain']
+    loaded_data, *_ = convert_tsf_to_dataframe(data_path)
+    print(loaded_data.head())
+
+    rain = np.stack(loaded_data.loc[loaded_data['obs_or_fcst'] == 'PRCP_SUM']['series_value'].values)
+    print(rain.shape)
+    df = pd.DataFrame(rain.T)
+    df.to_csv(PREFIX + 'rain.csv', index=False)
+
+    temp = np.stack(loaded_data.loc[loaded_data['obs_or_fcst'] == 'T_MEAN']['series_value'].values)
+    print(temp.shape)
+    df = pd.DataFrame(temp.T)
+    df.to_csv(PREFIX + 'temp.csv', index=False)
+
+    mix_dataset = [
+        'solar', 'electricity', 'traffic', 'kddcup', 'taxi', 'exchange', 'fred_md', 'nn5', 'temp', 'rain', 'wind_4_seconds'
+    ]
+    for seq_len in [24, 96, 168, 336]:
+        stride = seq_len
+        for data_name in mix_dataset:
+            ori_data = more_data_loading(data_name, seq_len, stride)
+            print(data_name, ori_data.shape)
+            test_portion = max(1, int(ori_data.shape[0] * 0.1))
+            train_data = ori_data[:-test_portion]
+            val_data = ori_data[-test_portion:]
+
+            np.save(PREFIX + f'{data_name}_{seq_len}_train.npy', train_data)
+            np.save(PREFIX + f'{data_name}_{seq_len}_val.npy', val_data)
+            print(train_data.shape, val_data.shape)
+
+        # pedestrian series have inconsistent lengths, so handle them separately
+        data_path = PREFIX + monash_map['pedestrian']
+        loaded_data, *_ = convert_tsf_to_dataframe(data_path)
+        print(loaded_data.head())
+
+        train_part = []
+        val_part = []
+        for _, x in loaded_data['series_value'].items():
+            num_segments = x.shape[0] // stride
+            val_num_segments = max(num_segments // 10, 1)
+            all_segments = []
+            for i in range(0, x.shape[0] - seq_len + 1, stride):  # slice each series into fixed-length windows
+ _x = x[None, i:i + seq_len] + all_segments.append(_x) + train_part += all_segments[:-val_num_segments] + val_part += all_segments[-val_num_segments:] + print(num_segments, val_num_segments) + + pedestrian_train = np.vstack(train_part) + pedestrian_val = np.vstack(val_part) + print(pedestrian_train.shape, pedestrian_val.shape) + + np.save(PREFIX + f'pedestrian_{seq_len}_train.npy', pedestrian_train[:,:,None]) + np.save(PREFIX + f'pedestrian_{seq_len}_val.npy', pedestrian_val[:,:,None]) + + zero_shot_schedule = [3, 10, 100] + for data_name in ['web', 'stock']: # 'web', 'stock' + for seq_len in [24, 96, 168, 336]: + if data_name == 'stock': + ori_data = more_data_loading(data_name, seq_len, 1, univar=False) + uni_ori_data = ori_data[:,:,0,None] + uni_ori_data /= uni_ori_data[:, :1, :] + else: + ori_data = more_data_loading(data_name, seq_len, seq_len, univar=False) + uni_ori_data = rearrange(ori_data, 'b t c -> (b c) t 1') # uniori_data[:,:,0,None]# + + zero_shot_data_path = Path(f'{PREFIX}/ts_data/new_zero_shot_data') + zero_shot_data_path.mkdir(exist_ok=True, parents=True) + + print(len(uni_ori_data)) + np.random.seed(0) + k_idx = np.random.choice(len(uni_ori_data), 2000+max(zero_shot_schedule)) + zero_shot_test_data = uni_ori_data[k_idx[-2000:]] + np.save(zero_shot_data_path/f'{data_name}_{seq_len}_test_sample.npy', zero_shot_test_data) + for k in zero_shot_schedule: + zero_shot_prompt = uni_ori_data[k_idx[:k]] + np.save(zero_shot_data_path/f'{data_name}_{seq_len}_k_{k}_sample.npy', zero_shot_prompt) + + pd.DataFrame(zero_shot_prompt[:,:,0].T).to_csv(zero_shot_data_path/f'{data_name}_dim0_{seq_len}_k_{k}_sample.csv', index=False) diff --git a/OATS/models/gen_model/utils/test_utils.py b/OATS/models/gen_model/utils/test_utils.py new file mode 100644 index 0000000..55f7f53 --- /dev/null +++ b/OATS/models/gen_model/utils/test_utils.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+import numpy as np
+import torch
+from pathlib import Path
+from utils.pkl_utils import save_pkl
+from metrics.metrics_sets import run_metrics, calculate_one
+from ldm.data.tsg_dataset import TSGDataset
+import os
+
+data_root = os.environ['DATA_ROOT']
+
+def test_model_with_dp(model, data, trainer, opt, logdir):
+    if trainer.callbacks[-1].best_model_path:
+        best_ckpt_path = trainer.callbacks[-1].best_model_path
+        print(f"Loading best model from {best_ckpt_path} for sampling")
+        model.init_from_ckpt(best_ckpt_path)
+    model = model.cuda()
+    model.eval()
+    save_path = Path(logdir) / 'generated_samples'
+    save_path.mkdir(exist_ok=True, parents=True)
+    seq_len = data.window
+    num_dp = 100  # number of samples for constructing domain prompts
+    all_metrics = {}
+    for dataset in data.norm_train_dict:
+        dataset_data = TSGDataset({dataset: data.norm_train_dict[dataset]})
+        dataset_samples = []
+        for idx in np.random.randint(len(dataset_data), size=num_dp):  # randomly sample num_dp samples from the dataset
+            dataset_samples.append(dataset_data[idx]['context'])
+        dataset_samples = np.vstack(dataset_samples)
+
+        x = torch.tensor(dataset_samples).to('cuda').float().unsqueeze(1)[:num_dp]
+        c, mask = model.get_learned_conditioning(x, return_mask=True)
+        repeats = int(1000 / num_dp) if not opt.debug else 1
+
+        if c is None:
+            mask_repeat = None
+            cond = None
+        elif mask is None:
+            cond = torch.repeat_interleave(c, repeats, dim=0)
+            mask_repeat = None
+        else:
+            cond = torch.repeat_interleave(c, repeats, dim=0)
+            mask_repeat = torch.repeat_interleave(mask, repeats, dim=0)
+
+        all_gen = []
+        for _ in range(5 if not opt.debug else 1):  # iterate to reduce maximum memory usage
+            samples, _ = model.sample_log(cond=cond, batch_size=1000 if not opt.debug else 100, ddim=False, cfg_scale=1, mask=mask_repeat)
+            norm_samples = model.decode_first_stage(samples).detach().cpu().numpy()
+            inv_samples = data.inverse_transform(norm_samples, data_name=dataset)
+            all_gen.append(inv_samples)
+        generated_data = np.vstack(all_gen).transpose(0, 2, 1)
+        # save data in original scale, for fairness in evaluation
+        tmp_name = f'{dataset}_{seq_len}_generation'
+        np.save(save_path / f'{tmp_name}.npy', generated_data)
+        all_metrics = run_metrics(data_name=dataset, seq_len=seq_len, model_name=tmp_name, gen_data=generated_data, scale='zscore', exist_dict=all_metrics)
+        print(all_metrics)
+    save_pkl(all_metrics, Path(logdir) / 'metric_dict.pkl')
+
+
+def test_model_uncond(model, data, trainer, opt, logdir):
+    if trainer.callbacks[-1].best_model_path:
+        best_ckpt_path = trainer.callbacks[-1].best_model_path
+        print(f"Loading best model from {best_ckpt_path} for sampling")
+        model.init_from_ckpt(best_ckpt_path)
+    model = model.cuda()
+    model.eval()
+    save_path = Path(logdir) / 'generated_samples'
+    save_path.mkdir(exist_ok=True, parents=True)
+    seq_len = data.window
+    all_metrics = {}
+    for dataset in data.norm_train_dict:
+
+        all_gen = []
+        for _ in range(5 if not opt.debug else 1):
+            samples, _ = model.sample_log(cond=None, batch_size=1000 if not opt.debug else 100, ddim=False, cfg_scale=1)
+            norm_samples = model.decode_first_stage(samples).detach().cpu().numpy()
+            inv_samples = data.inverse_transform(norm_samples, data_name=dataset)
+            all_gen.append(inv_samples)
+        generated_data = np.vstack(all_gen).transpose(0, 2, 1)
+        # save data in original scale, for fair use in evaluation
+        tmp_name = f'{dataset}_{seq_len}_uncond_generation'
+        np.save(save_path / f'{tmp_name}.npy', generated_data)
+        all_metrics = run_metrics(data_name=dataset, seq_len=seq_len, model_name=tmp_name, gen_data=generated_data, scale='zscore', exist_dict=all_metrics)
+        print(all_metrics)
+    save_pkl(all_metrics, Path(logdir) / 'metric_dict.pkl')
+
+def zero_shot_k_repeat(samples, model, train_data_module, num_gen_samples=1000):
+    data = train_data_module
+    k_samples = samples.transpose(0,2,1)
+    k = k_samples.shape[0]
+    normalizer = data.fit_normalizer(k_samples)
+
+    norm_k_samples = data.transform(k_samples, normalizer=normalizer)
+
+    x = torch.tensor(norm_k_samples).float().to('cuda')
+    c, mask = model.get_learned_conditioning(x, return_mask=True)
+
+    # tile the k conditioning vectors so that num_gen_samples samples are produced in total
+    repeats = int(num_gen_samples / k)
+    extra = num_gen_samples - repeats * k
+
+    cond = torch.repeat_interleave(c, repeats, dim=0)
+    cond = torch.cat([cond, c[:extra]], dim=0)
+    mask_repeat = torch.repeat_interleave(mask, repeats, dim=0)
+    mask_repeat = torch.cat([mask_repeat, mask[:extra]], dim=0)
+
+    samples, z_denoise_row = model.sample_log(cond=cond, batch_size=cond.shape[0], ddim=False, cfg_scale=1, mask=mask_repeat)
+    norm_samples = model.decode_first_stage(samples).detach().cpu().numpy()
+    inv_samples = data.inverse_transform(norm_samples, normalizer=normalizer)
+    gen_data = inv_samples.transpose(0,2,1)
+
+    return gen_data, k_samples.transpose(0,2,1)
+
+def merge_dicts(dicts):
+    result = {}
+    for d in dicts:
+        for k, v in d.items():
+            result[k] = v
+    return result
+
+def test_model_unseen(model, data, trainer, opt, logdir):
+    all_metrics = {}
+    seq_len = opt.seq_len
+    for data_name in ['stock', 'web']:
+        data_result_dicts = []
+        uni_ori_data = np.load(f'{data_root}/ts_data/new_zero_shot_data/{data_name}_{seq_len}_test_sample.npy')
+        if data_name == 'web':
+            uni_ori_data = uni_ori_data[...]
+        ...
+
+def generate_conditional_batch_decoder(
+    prompt: Union[np.ndarray, torch.Tensor, List[float]],
+    subset_id,
+    model,
+    data_module,
+    dataset_name: str,
+    dataset_idx: int,
+    num_samples: int = 1,
+    ddim: bool = True,
+    ddim_steps: int = 40,
+    device: str = 'cuda',
+    context_length: int = 7, # Number of observed timesteps
+    prediction_length: int = 3, # Number of prediction timesteps
+    total_length: int = 512, # Total sequence length
+    patch_size: int = 32, # Patch size
+    sample_id: int = 1
+) -> dict:
+    """
+    Generate conditional samples and wrap them in the expected batch dictionary format. 
+ + Args: + prompt: Input prompt as numpy array, torch tensor, or list + model: The loaded TimeDP model + data_module: The data module for normalization + dataset_name: Name of the dataset for normalization + dataset_idx: Dataset index for metadata + num_samples: Number of samples to generate (batch size) + ddim: Whether to use DDIM sampling + ddim_steps: Number of DDIM steps + device: Device to run on + context_length: Number of observed timesteps (183 in example) + prediction_length: Number of prediction timesteps (64 in example) + total_length: Total sequence length (512) + patch_size: Patch size (32) + sample_id: Sample ID for metadata + + Returns: + Dictionary with keys: target, observed_mask, time_id, variate_id, prediction_mask, + patch_size, label, label_observed_mask, _dataset_idx, sample_id, dataset_index + """ + + # Generate samples using the existing function + generated_samples = generate_conditional_samples_with_multiple_prompts( + prompts=prompt, + subset_id=subset_id, + model=model, + data_module=data_module, + dataset_name=dataset_name, + num_samples=num_samples, + ddim=ddim, + ddim_steps=ddim_steps, + device=device + ) + + # Convert to torch tensors if needed + if isinstance(generated_samples, np.ndarray): + generated_samples = torch.tensor(generated_samples, dtype=torch.float32) + + # Ensure we have the right shape: (num_samples, sequence_length) + if generated_samples.dim() == 1: + generated_samples = generated_samples.unsqueeze(0) + + batch_size = generated_samples.shape[0] + sequence_length = generated_samples.shape[1] + + # Reshape generated samples to (batch_size, total_length, patch_size) + # We need to pad or truncate to fit the expected format + if sequence_length < total_length * patch_size: + # Pad with zeros + padding_needed = total_length * patch_size - sequence_length + padding = torch.zeros(batch_size, padding_needed, dtype=torch.float32) + padded_samples = torch.cat([generated_samples, padding], dim=1) + else: + # Truncate to fit + padded_samples = generated_samples[:, :total_length * patch_size] + + # Reshape to (batch_size, total_length, patch_size) + target = padded_samples.view(batch_size, total_length, patch_size) + + # Label is the same as target for this case + label = target.clone() + # move 0-8 patch as 1-9 patch (move forward), delete the 10th patch + label[:, :-1, :] = label[:, 1:, :].clone() + + # make the last non-padding patch to be 0 + target[:, sequence_length//patch_size-1, :] = 0 + label[:, sequence_length//patch_size-1, :] = 0 + + # Create observed_mask: True for observed timesteps, False for padding + observed_mask = torch.zeros(batch_size, total_length, patch_size, dtype=torch.bool) + observed_timesteps = min(context_length, total_length) + observed_mask[:, :observed_timesteps+prediction_length-1, :] = True + + # Create time_id: sequential timestep IDs + time_id = torch.zeros(batch_size, total_length, dtype=torch.long) + for i in range(batch_size): + time_id[i, :sequence_length//patch_size-1] = torch.arange(sequence_length//patch_size-1) + # Padding timesteps get 0 + + # Create variate_id: all zeros (univariate) + variate_id = torch.zeros(batch_size, total_length, dtype=torch.long) + + # Create prediction_mask: True for prediction timesteps + prediction_mask = torch.zeros(batch_size, total_length, dtype=torch.bool) + pred_start = context_length + pred_end = min(context_length + prediction_length-1, total_length) + prediction_mask[:, pred_start-1:pred_end] = True + + # Create patch_size tensor + patch_size_tensor = 
torch.full((batch_size, total_length), patch_size, dtype=torch.long) + # Set padding areas to 0 + patch_size_tensor[:, sequence_length//patch_size-1:] = 0 + + + # Label observed mask is the same as observed mask + label_observed_mask = observed_mask.clone() + label_observed_mask[:, :context_length+prediction_length-1 :] = True + + # Dataset index tensor + assert isinstance(dataset_idx, list) + _dataset_idx = torch.tensor(dataset_idx, dtype=torch.long) + + # Sample ID tensor + sample_id_tensor = torch.full((batch_size, total_length), sample_id, dtype=torch.long) + # Set padding areas to 0 + sample_id_tensor[:, sequence_length//patch_size-1:] = 0 + + # Dataset index for each timestep + # dataset_idx is now a list of integers, one per batch item + dataset_index = torch.zeros((batch_size, total_length), dtype=torch.long) + for i in range(batch_size): + dataset_index[i, :] = dataset_idx[i] + # Set padding areas to -1 + dataset_index[:, sequence_length//patch_size-1:] = -1 + + # Create the batch dictionary + batch = { + 'target': target, + 'observed_mask': observed_mask, + 'time_id': time_id, + 'variate_id': variate_id, + 'prediction_mask': prediction_mask, + 'patch_size': patch_size_tensor, + 'label': label, + 'label_observed_mask': label_observed_mask, + '_dataset_idx': _dataset_idx, + 'sample_id': sample_id_tensor, + 'dataset_index': dataset_index + } + + return batch + + +def generate_conditional_batch( + prompt: Union[np.ndarray, torch.Tensor, List[float]], + subset_id, + model, + data_module, + dataset_name: str, + dataset_idx: int, + num_samples: int = 1, + ddim: bool = True, + ddim_steps: int = 40, + device: str = 'cuda', + context_length: int = 7, # Number of observed timesteps + prediction_length: int = 3, # Number of prediction timesteps + total_length: int = 512, # Total sequence length + patch_size: int = 32, # Patch size + sample_id: int = 1 +) -> dict: + """ + Generate conditional samples and wrap them in the expected batch dictionary format. 
+ + Args: + prompt: Input prompt as numpy array, torch tensor, or list + model: The loaded TimeDP model + data_module: The data module for normalization + dataset_name: Name of the dataset for normalization + dataset_idx: Dataset index for metadata + num_samples: Number of samples to generate (batch size) + ddim: Whether to use DDIM sampling + ddim_steps: Number of DDIM steps + device: Device to run on + context_length: Number of observed timesteps (183 in example) + prediction_length: Number of prediction timesteps (64 in example) + total_length: Total sequence length (512) + patch_size: Patch size (32) + sample_id: Sample ID for metadata + + Returns: + Dictionary with keys: target, observed_mask, time_id, variate_id, prediction_mask, + patch_size, label, label_observed_mask, _dataset_idx, sample_id, dataset_index + """ + + # Generate samples using the existing function + generated_samples = generate_conditional_samples_with_multiple_prompts( + prompts=prompt, + subset_id=subset_id, + model=model, + data_module=data_module, + dataset_name=dataset_name, + num_samples=num_samples, + ddim=ddim, + ddim_steps=ddim_steps, + device=device + ) + + # Convert to torch tensors if needed + if isinstance(generated_samples, np.ndarray): + generated_samples = torch.tensor(generated_samples, dtype=torch.float32) + + # Ensure we have the right shape: (num_samples, sequence_length) + if generated_samples.dim() == 1: + generated_samples = generated_samples.unsqueeze(0) + + batch_size = generated_samples.shape[0] + sequence_length = generated_samples.shape[1] + + # Reshape generated samples to (batch_size, total_length, patch_size) + # We need to pad or truncate to fit the expected format + if sequence_length < total_length * patch_size: + # Pad with zeros + padding_needed = total_length * patch_size - sequence_length + padding = torch.zeros(batch_size, padding_needed, dtype=torch.float32) + padded_samples = torch.cat([generated_samples, padding], dim=1) + else: + # Truncate to fit + padded_samples = generated_samples[:, :total_length * patch_size] + + # Reshape to (batch_size, total_length, patch_size) + target = padded_samples.view(batch_size, total_length, patch_size) + + # Create observed_mask: True for observed timesteps, False for padding + observed_mask = torch.zeros(batch_size, total_length, patch_size, dtype=torch.bool) + observed_timesteps = min(context_length, total_length) + observed_mask[:, :observed_timesteps, :] = True + + # Create time_id: sequential timestep IDs + time_id = torch.zeros(batch_size, total_length, dtype=torch.long) + for i in range(batch_size): + time_id[i, :sequence_length//patch_size] = torch.arange(sequence_length//patch_size) + # Padding timesteps get 0 + + # Create variate_id: all zeros (univariate) + variate_id = torch.zeros(batch_size, total_length, dtype=torch.long) + + # Create prediction_mask: True for prediction timesteps + prediction_mask = torch.zeros(batch_size, total_length, dtype=torch.bool) + pred_start = context_length + pred_end = min(context_length + prediction_length, total_length) + prediction_mask[:, pred_start:pred_end] = True + + # Create patch_size tensor + patch_size_tensor = torch.full((batch_size, total_length), patch_size, dtype=torch.long) + # Set padding areas to 0 + patch_size_tensor[:, sequence_length//patch_size:] = 0 + + # Label is the same as target for this case + label = target.clone() + + # Label observed mask is the same as observed mask + label_observed_mask = observed_mask.clone() + label_observed_mask[:, 
:context_length + prediction_length, :] = True
+
+    # Dataset index tensor
+    assert isinstance(dataset_idx, list)
+    _dataset_idx = torch.tensor(dataset_idx, dtype=torch.long)
+
+    # Sample ID tensor
+    sample_id_tensor = torch.full((batch_size, total_length), sample_id, dtype=torch.long)
+    # Set padding areas to 0
+    sample_id_tensor[:, sequence_length//patch_size:] = 0
+
+    # Dataset index for each timestep
+    # dataset_idx is now a list of integers, one per batch item
+    dataset_index = torch.zeros((batch_size, total_length), dtype=torch.long)
+    for i in range(batch_size):
+        dataset_index[i, :] = dataset_idx[i]
+    # Set padding areas to -1
+    dataset_index[:, sequence_length//patch_size:] = -1
+
+    # Create the batch dictionary
+    batch = {
+        'target': target,
+        'observed_mask': observed_mask,
+        'time_id': time_id,
+        'variate_id': variate_id,
+        'prediction_mask': prediction_mask,
+        'patch_size': patch_size_tensor,
+        'label': label,
+        'label_observed_mask': label_observed_mask,
+        '_dataset_idx': _dataset_idx,
+        'sample_id': sample_id_tensor,
+        'dataset_index': dataset_index
+    }
+
+    return batch
+
+
+def generate_conditional_samples_with_multiple_prompts(
+    prompts: Union[np.ndarray, torch.Tensor, List[float]],
+    subset_id,
+    model,
+    data_module,
+    dataset_name,
+    num_samples: int = 1,
+    ddim: bool = True,
+    ddim_steps: int = 20,
+    device: str = 'cuda'
+) -> np.ndarray:
+    """
+    Generate conditional samples using a trained TimeDP model.
+
+    Args:
+        prompts: List of input prompts (numpy arrays or torch tensors), each of shape (sequence_length,)
+        model: The loaded TimeDP model
+        data_module: The data module for normalization
+        dataset_name: List of dataset names used for normalization (one per prompt)
+        num_samples: Number of samples to generate per prompt
+        ddim_steps: Number of DDIM steps
+        device: Device to run on ('cuda' or 'cpu')
+
+    Returns:
+        Generated samples as numpy array of shape (num_samples, sequence_length)
+    """
+    # Convert prompt to numpy array if needed
+    # if isinstance(prompt, torch.Tensor):
+    #     prompt = prompt.detach().cpu().numpy()
+    # elif isinstance(prompt, list):
+    #     prompt = np.array(prompt)
+
+    assert isinstance(prompts, list)
+    assert isinstance(dataset_name, list)
+    assert isinstance(subset_id, list)
+    num_samples *= len(prompts)  # Total samples to generate across all prompts
+
+    # Ensure prompt is the right shape
+    # if prompt.ndim == 2:
+    #     prompt = prompt.reshape(prompt.shape[0], 1, -1) # (batch, 1, sequence_length)
+
+    # normalize each prompt with its dataset's fitted normalizer
+    normalized_prompts = []
+    for i, prompt in enumerate(prompts):
+        normalizer = data_module.normalizer_dict[dataset_name[i]]
+        normalized_prompt = data_module.transform(prompt.cpu(), normalizer)
+        normalized_prompts.append(normalized_prompt)
+
+    # Convert to tensor and move to device
+    normalized_prompt = np.array(normalized_prompts).reshape(len(prompts), 1, -1)
+    x = torch.tensor(normalized_prompt).to(device).float()
+    subset_id = torch.tensor(subset_id, dtype=torch.long).to(device)  # Dummy subset_id for now
+
+    # Get conditioning
+    c, mask = model.get_learned_conditioning(x, return_mask=True)
+
+    print("c.shape", c.shape if c is not None else "None")
+    print("mask.shape", mask.shape if mask is not None else "None")
+
+    # One conditioning vector per prompt; the repeat factors of 1 keep the batch as-is
+    if c is None:
+        mask_repeat = None
+        cond = None
+    elif mask is None:
+        cond = c.repeat(1, 1, 1)
+        mask_repeat = None
+    else:
+        cond = c.repeat(1, 1, 1)
+        mask_repeat = mask.repeat(1, 1)
+
+    # Generate samples
+    with torch.no_grad():
+        samples, _ = 
model.sample_log( + cond=cond, + batch_size=num_samples, + ddim=ddim, + ddim_steps=ddim_steps, + cfg_scale=1.0, # 5.0 + mask=mask_repeat, + data_key=subset_id, + ) + norm_samples = model.decode_first_stage(samples).detach().cpu().numpy() + + inv_samples = [] + for i in range(len(dataset_name)): + inv_sample = data_module.inverse_transform(norm_samples[i], data_name=dataset_name[i]) + inv_samples.append(inv_sample) + + inv_samples = np.array(inv_samples).reshape(num_samples, -1) + + return inv_samples.squeeze() + + +def load_timedp_model( + config_path: str, + ckpt_path: str, + seq_len: int = 320, + num_latents: int = 16, + batch_size: int = 16, + use_pam: bool = True, + uncond: bool = False, + seed: int = 0, + debug: bool = False, + overwrite_learning_rate: float = None, + device: str = 'cuda' +): + """ + Load TimeDP model and data module without using parser. + + Args: + config_path: Path to the config YAML file + ckpt_path: Path to the model checkpoint + seq_len: Sequence length (default: 320) + num_latents: Number of latents for PAM (default: 16) + batch_size: Batch size (default: 16) + use_pam: Whether to use PAM (default: True) + uncond: Whether to use unconditional model (default: False) + seed: Random seed (default: 0) + debug: Debug mode (default: False) + overwrite_learning_rate: Override learning rate (optional) + device: Device to load on (default: 'cuda') + + Returns: + Tuple of (model, data, config_name, logdir) + """ + from omegaconf import OmegaConf + from pytorch_lightning import seed_everything + from ldm.util import instantiate_from_config + + # Get data root from environment + data_root = os.environ['DATA_ROOT'] + + # Set seed + seed_everything(seed) + + # Load and merge configs + config = OmegaConf.load(config_path) + + # Get config name from path + cfg_fname = os.path.split(config_path)[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + + # Customize config from parameters + config.model['params']['seq_len'] = seq_len + config.model['params']['unet_config']['params']['seq_len'] = seq_len + config.data['params']['window'] = seq_len + config.data['params']['batch_size'] = batch_size + + # Set max steps + config.lightning['trainer']['max_steps'] = 50000 + if debug: + config.lightning['trainer']['max_steps'] = 10 + config.lightning['callbacks']['image_logger']['params']['batch_frequency'] = 5 + + # Handle learning rate + if overwrite_learning_rate is not None: + config.model['base_learning_rate'] = overwrite_learning_rate + print(f"Setting learning rate (overwriting config file) to {overwrite_learning_rate:.2e}") + base_lr = overwrite_learning_rate + else: + base_lr = config.model['base_learning_rate'] + + # Create experiment name + nowname = f"{cfg_name.split('-')[-1]}_{seq_len}_nl_{num_latents}_lr{base_lr:.1e}_bs{batch_size}" + + # Configure conditional/unconditional setup + if uncond: + config.model['params']['cond_stage_config'] = "__is_unconditional__" + config.model['params']['cond_stage_trainable'] = False + config.model['params']['unet_config']['params']['context_dim'] = None + nowname += f"_uncond" + else: + config.model['params']['cond_stage_config']['params']['window'] = seq_len + + if use_pam: + config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedPrototyper" + config.model['params']['cond_stage_config']['params']['num_latents'] = num_latents + config.model['params']['unet_config']['params']['latent_unit'] = num_latents + config.model['params']['unet_config']['params']['use_pam'] = True + nowname += f"_pam" + else: + 
config.model['params']['cond_stage_config']['target'] = "ldm.modules.encoders.modules.DomainUnifiedEncoder" + config.model['params']['unet_config']['params']['use_pam'] = False + + nowname += f"_seed{seed}" + logdir = os.path.join('./logs', cfg_name, nowname) + + # Set checkpoint path in config + config.model['params']['ckpt_path'] = ckpt_path + + # Instantiate model + model = instantiate_from_config(config.model) + + # Instantiate data module + # Replace placeholders in data paths + for k, v in config.data.params.data_path_dict.items(): + config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', str(seq_len)) + + data = instantiate_from_config(config.data) + + # Prepare data + data.prepare_data() + data.setup("predict") + print("#### Data Preparation Finished #####") + + return model, data, cfg_name, logdir + +def load_model(ckpt_path: str, config_path: str = None, device: str = 'cuda', seed: int = 0, **kwargs): + """ + Load the TimeDP model and data module. + + Args: + ckpt_path: Path to the model checkpoint + config_path: Path to the config file (optional) + device: Device to load on + **kwargs: Additional arguments passed to load_timedp_model + + Returns: + Tuple of (model, data) + """ + if config_path is None: + config_path = 'models/gen_model/configs/multi_domain_timedp_local.yaml' + + model, data, cfg_name, logdir = load_timedp_model( + config_path=config_path, + ckpt_path=ckpt_path, + device=device, + seed=seed, + **kwargs + ) + + model.init_from_ckpt(ckpt_path) + model = model.to(device) + model.eval() + + return model, data + +def main(model, data, index=0): + """ + Test function for conditional generation. + """ + + # print(data.key_list) + # for key, _ in data.data_dict.items(): + # print(key) + + test_prompts = [] + subset_ids = [] + dataset_names = [] + for dataset in data.key_list: + print(f"Testing dataset: {dataset}") + dataset_name = dataset # "australian_electricity_demand" / "largest_2021" + subset_id = data.key_list.index(dataset_name) + # subset_id = 2 + print(dataset_name, subset_id, index) + test_prompt = np.load(f'extracted_label_patches_{dataset_name}.npy')[index].reshape(320,) + + print("Generating conditional samples...") + # print(f"Prompt shape: {test_prompt.shape}") + test_prompts.append(torch.Tensor(test_prompt)) + subset_ids.append(subset_id) + dataset_names.append(dataset_name) + + # Generate samples + generated_samples = generate_conditional_batch_decoder( + prompt=test_prompts, + subset_id=subset_ids, + model=model, + dataset_name=dataset_names, + data_module=data, + num_samples=1, + ddim=True, + ddim_steps=40, + dataset_idx=[9494]*len(test_prompts), # Example dataset index, adjust as needed + ) + + # print other information + # print(generated_samples) + # save the first samples of each key to a file + for key, value in generated_samples.items(): + torch.save(value[0], f"batch_{key}_generated.pt") + + print(f"Generated {len(generated_samples['label'])} samples") + generated_samples = generated_samples['label'][:, :10, :].reshape(-1, 320) + + print(f"Generated samples shape: {generated_samples.shape}") + + # # plot all the generated samples in sub figures as well as the test prompt + import matplotlib.pyplot as plt + + # plot 20 test_prompts and 20 generated samples in a 5*8 grid + plt.figure(figsize=(20, 10)) + + # Plot test prompts + for i, prompt in enumerate(test_prompts): + plt.subplot(5, 8, 2*i + 1) + plt.plot(prompt, label=f'Test Prompt {i+1}', color="orange") + # plt.legend() + plt.title(f'Test Prompt {i+1}') 
+ + # Plot generated samples + for i, sample in enumerate(generated_samples): + plt.subplot(5, 8, 2*i + 2) + plt.plot(sample, label=f'Generated Sample {i+1}') + # plt.legend() + plt.title(f'Generated Sample {i+1}') + + # plt.subplots(len(generated_samples), 1, figsize=(10, 4 * len(generated_samples))) + # for i, sample in enumerate(generated_samples): + # plt.subplot(len(generated_samples), 1, i + 1) + # if i != 0: + # plt.plot(sample, label=f'Generated Sample {i+1}') + # if i == 0: # Only plot the prompt on the first subplot + # plt.plot(test_prompt, label='Test Prompt', color="orange") + # plt.legend() + # plt.title(f'Generated Sample {i+1}') + + plt.tight_layout() + # plt.savefig(f'new_generated_samples_{dataset_name}_conditionclass_promotidx_{index}_cfg10.png') + plt.savefig(f'play_cfg1_index{index}_ddim.png') + +if __name__ == '__main__': + ckpt_path = "000060-0.0853.ckpt" + + # Load model once + print("Loading model...") + model, data = load_model(ckpt_path, seed=0) + + main(model, data, index=1) diff --git a/OATS/models/tsfm/common/__init__.py b/OATS/models/tsfm/common/__init__.py new file mode 100644 index 0000000..a442066 --- /dev/null +++ b/OATS/models/tsfm/common/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/OATS/models/tsfm/common/core.py b/OATS/models/tsfm/common/core.py new file mode 100644 index 0000000..8f926d7 --- /dev/null +++ b/OATS/models/tsfm/common/core.py @@ -0,0 +1,54 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +from collections.abc import Callable +from typing import TypeVar + +T = TypeVar("T") + + +def abstract_class_property(*names: str) -> Callable[[type[T], ...], type[T]]: + def _func(cls: type[T]) -> type[T]: + original_init_subclass = cls.__init_subclass__ + + def _init_subclass(_cls, **kwargs): + # The default implementation of __init_subclass__ takes no + # positional arguments, but a custom implementation does. + # If the user has not reimplemented __init_subclass__ then + # the first signature will fail and we try the second. 
+            try:
+                original_init_subclass(_cls, **kwargs)
+            except TypeError:
+                original_init_subclass(**kwargs)
+
+            # Check that each attribute is defined.
+            for name in names:
+                if not hasattr(_cls, name):
+                    raise NotImplementedError(
+                        f"{name} has not been defined for {_cls.__name__}"
+                    )
+                if getattr(_cls, name, NotImplemented) is NotImplemented:
+                    raise NotImplementedError(
+                        f"{name} has not been defined for {_cls.__name__}"
+                    )
+
+        cls.__init_subclass__ = classmethod(_init_subclass)
+        return cls
+
+    return _func
diff --git a/OATS/models/tsfm/common/env.py b/OATS/models/tsfm/common/env.py
new file mode 100644
index 0000000..22aa1e5
--- /dev/null
+++ b/OATS/models/tsfm/common/env.py
@@ -0,0 +1,43 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import warnings
+from pathlib import Path
+from typing import Optional
+
+from dotenv import load_dotenv
+
+
+def get_path_var(var: str) -> Optional[Path]:
+    if (path := os.getenv(var)) is not None:
+        return Path(path)
+    return None
+
+
+class Env:
+    _instance: Optional["Env"] = None
+    path_vars: list[str] = [
+        "LOTSA_10M_PATH",
+        "LOTSA_100M_PATH",
+        "LOTSA_1B_PATH",
+        "LOTSA_16B_PATH",
+        "LOTSA_V1_PATH",
+        "CUSTOM_DATA_PATH",
+    ]
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            if not load_dotenv():
+                warnings.warn("Failed to load .env file.")
+            cls.monkey_patch_path_vars()
+        return cls._instance
+
+    @classmethod
+    def monkey_patch_path_vars(cls):
+        for var in cls.path_vars:
+            setattr(cls, var, get_path_var(var))
+
+
+env = Env()
\ No newline at end of file
diff --git a/OATS/models/tsfm/common/hydra_util.py b/OATS/models/tsfm/common/hydra_util.py
new file mode 100644
index 0000000..9865ee4
--- /dev/null
+++ b/OATS/models/tsfm/common/hydra_util.py
@@ -0,0 +1,59 @@
+# Copyright (c), Salesforce 2024,
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
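Stepping back to `core.py` just above: a minimal usage sketch of `abstract_class_property` (the class names here are hypothetical, not part of the codebase).

```python
from tsfm.common.core import abstract_class_property


@abstract_class_property("dataset_list")
class Base:
    dataset_list: list[str] = NotImplemented


class Good(Base):  # fine: the abstract class property is defined
    dataset_list = ["etth1"]


class Bad(Base):  # raises NotImplementedError at class-definition time:
    pass          # "dataset_list has not been defined for Bad"
```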
+ +from collections.abc import Callable +from typing import Any + +from hydra.utils import get_class +from omegaconf import OmegaConf +from tsfm.common.env import env + + +def register_resolver(name: str) -> Callable[[Callable], Callable]: + def decorator(resolver: Callable) -> Callable: + OmegaConf.register_new_resolver(name, resolver) + return resolver + + return decorator + + +@register_resolver("as_tuple") +def resolve_as_tuple(ls: list) -> tuple: + return tuple(ls) + + +@register_resolver("cls_getattr") +def resolve_cls_getattr(cls_name: str, attribute_name: str) -> Any: + if cls_name.endswith(".load_from_checkpoint"): + cls_name = cls_name[: -len(".load_from_checkpoint")] + cls = get_class(cls_name) + return getattr(cls, attribute_name) + + +@register_resolver("floordiv") +def resolve_floordiv(a: int, b: int) -> int: + return a // b + + +@register_resolver("mul") +def resolve_mul(a: float, b: float) -> float: + return a * b + +@register_resolver("env_get") +def resolve_env_attr(attr_name: str) -> Any: + return getattr(env, attr_name, None) \ No newline at end of file diff --git a/OATS/models/tsfm/common/sampler.py b/OATS/models/tsfm/common/sampler.py new file mode 100644 index 0000000..c05d310 --- /dev/null +++ b/OATS/models/tsfm/common/sampler.py @@ -0,0 +1,58 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
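For the resolvers registered above, a small sketch of how they expand inside an OmegaConf/Hydra config; the config keys are invented for illustration, and importing the module is what registers the resolvers.

```python
from omegaconf import OmegaConf

import tsfm.common.hydra_util  # noqa: F401 -- the import registers the resolvers

cfg = OmegaConf.create(
    {
        "seq_len": 320,
        "patch_size": 16,
        # Nested interpolations resolve first, then the resolver is applied.
        "num_patches": "${floordiv:${seq_len},${patch_size}}",
        "d_model": "${mul:${patch_size},8}",
    }
)
assert cfg.num_patches == 20
assert cfg.d_model == 128
```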
+ +from collections.abc import Callable +from functools import partial +from typing import cast + +import numpy as np + +Sampler = Callable[[int | np.ndarray], int | np.ndarray] + + +def uniform_sampler(n: int | np.ndarray) -> int | np.ndarray: + return np.random.randint(1, n + 1) + + +def binomial_sampler(n: int | np.ndarray, p: float = 0.5) -> int | np.ndarray: + return np.random.binomial(n - 1, p) + 1 + + +def beta_binomial_sampler( + n: int | np.ndarray, a: float = 1, b: float = 1 +) -> int | np.ndarray: + # equivalent to uniform_sampler when a = b = 1 + if isinstance(n, np.ndarray): + p = np.random.beta(a, b, size=n.shape) + else: + p = np.random.beta(a, b) + return np.random.binomial(n - 1, p) + 1 + + +def get_sampler(distribution: str, **kwargs) -> Sampler: + if distribution == "uniform": + return uniform_sampler + elif distribution == "binomial": + p = kwargs.get("p", 0.5) + return cast(Sampler, partial(binomial_sampler, p=p)) + elif distribution == "beta_binomial": + a = kwargs.get("a", 1) + b = kwargs.get("b", 1) + return cast(Sampler, partial(beta_binomial_sampler, a=a, b=b)) + else: + raise NotImplementedError(f"distribution {distribution} not implemented") diff --git a/OATS/models/tsfm/common/torch_util.py b/OATS/models/tsfm/common/torch_util.py new file mode 100644 index 0000000..7322378 --- /dev/null +++ b/OATS/models/tsfm/common/torch_util.py @@ -0,0 +1,124 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
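A quick sketch of how `get_sampler` is typically called; the printed values are illustrative since the draws are random.

```python
import numpy as np

from tsfm.common.sampler import get_sampler

sampler = get_sampler("beta_binomial", a=2.0, b=5.0)

# Scalar usage: one random size in [1, n].
print(sampler(64))

# Vectorized usage: one draw per entry, the i-th lying in [1, n_i].
print(sampler(np.array([32, 64, 128])))  # e.g. [ 5 14 31]
```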
+ +from typing import Optional + +import numpy as np +import torch +from jaxtyping import Bool, Float, Int + +numpy_to_torch_dtype_dict = { + bool: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +def packed_attention_mask( + sample_id: Int[torch.Tensor, "*batch seq_len"] +) -> Bool[torch.Tensor, "*batch seq_len seq_len"]: + sample_id = sample_id.unsqueeze(-1) + attention_mask = sample_id.eq(sample_id.mT) + return attention_mask + +def packed_attention_mask_andop( + id: Bool[torch.Tensor, "*batch seq_len"], +) -> Bool[torch.Tensor, "*batch seq_len seq_len"]: + id = id.unsqueeze(-1) + attention_mask = id & id.mT + return attention_mask + +def mask_fill( + tensor: Float[torch.Tensor, "*batch dim"], + mask: Bool[torch.Tensor, "*batch"], + value: Float[torch.Tensor, "dim"], +) -> Float[torch.Tensor, "*batch dim"]: + mask = mask.unsqueeze(-1) + return tensor * ~mask + value * mask + + +def safe_div( + numer: torch.Tensor, + denom: torch.Tensor, +) -> torch.Tensor: + return numer / torch.where( + denom == 0, + 1.0, + denom, + ) + + +def size_to_mask( + max_size: int, + sizes: Int[torch.Tensor, "*batch"], +) -> Bool[torch.Tensor, "*batch max_size"]: + mask = torch.arange(max_size, device=sizes.device) + return torch.lt(mask, sizes.unsqueeze(-1)) + + +def fixed_size( + value: Float[torch.Tensor, "*batch max_size"] +) -> Int[torch.Tensor, "*batch"]: + sizes = torch.ones_like(value[..., 0], dtype=torch.long) * value.shape[-1] + return sizes + + +def sized_mean( + value: Float[torch.Tensor, "*batch max_size"], + sizes: Optional[Int[torch.Tensor, "*batch"]], + dim: Optional[int | tuple[int, ...]] = None, + keepdim: bool = False, + size_keepdim: bool = False, + correction: int = 0, +) -> Float[torch.Tensor, "..."]: + value = value * size_to_mask(value.shape[-1], sizes) + div_val = safe_div( + value.sum(dim=-1).sum(dim, keepdim=keepdim), + torch.clamp(sizes.sum(dim, keepdim=keepdim) - correction, min=0), + ) + if size_keepdim: + div_val = div_val.unsqueeze(-1) + return div_val + + +def masked_mean( + value: Float[torch.Tensor, "..."], + mask: Bool[torch.Tensor, "..."], + dim: Optional[int | tuple[int, ...]] = None, + keepdim: bool = False, + correction: int = 0, +) -> Float[torch.Tensor, "..."]: + return safe_div( + (value * mask).sum(dim=dim, keepdim=keepdim), + torch.clamp(mask.float().sum(dim, keepdim=keepdim) - correction, min=0), + ) + + +def unsqueeze_trailing_dims(x: torch.Tensor, shape: torch.Size) -> torch.Tensor: + if x.ndim > len(shape) or x.shape != shape[: x.ndim]: + raise ValueError + dim = (...,) + (None,) * (len(shape) - x.ndim) + return x[dim] diff --git a/OATS/models/tsfm/common/typing.py b/OATS/models/tsfm/common/typing.py new file mode 100644 index 0000000..d584777 --- /dev/null +++ b/OATS/models/tsfm/common/typing.py @@ -0,0 +1,56 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections.abc import Callable, Iterable +from typing import Any + +import numpy as np +import torch +from jaxtyping import AbstractDtype, Num + + +class DateTime64(AbstractDtype): + dtypes = ["datetime64"] + + +class Character(AbstractDtype): + dtypes = ["str_"] + + +# Data preparation +GenFunc = Callable[[], Iterable[dict[str, Any]]] +SliceableGenFunc = Callable[..., Iterable[dict[str, Any]]] + + +# Indexer +DateTime = DateTime64[np.ndarray, ""] +BatchedDateTime = DateTime64[np.ndarray, "batch"] +String = np.character +BatchedString = Character[np.ndarray, "batch"] +UnivarTimeSeries = Num[np.ndarray, "time"] +MultivarTimeSeries = Num[np.ndarray, "var time"] +Data = DateTime | String | UnivarTimeSeries | MultivarTimeSeries +BatchedData = ( + BatchedDateTime | BatchedString | list[UnivarTimeSeries] | list[MultivarTimeSeries] +) +FlattenedData = DateTime | String | list[UnivarTimeSeries] + + +# Loader +Sample = dict[str, Num[torch.Tensor, "*sample"]] +BatchedSample = dict[str, Num[torch.Tensor, "batch *sample"]] diff --git a/OATS/models/tsfm/data/__init__.py b/OATS/models/tsfm/data/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/OATS/models/tsfm/data/__init__.py @@ -0,0 +1 @@ + diff --git a/OATS/models/tsfm/data/builder/__init__.py b/OATS/models/tsfm/data/builder/__init__.py new file mode 100644 index 0000000..b1bb663 --- /dev/null +++ b/OATS/models/tsfm/data/builder/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from ._base import ConcatDatasetBuilder, DatasetBuilder, ConcatDatasetBuilderWithGlobalIndex + +__all__ = [ + "DatasetBuilder", + "ConcatDatasetBuilder", + "ConcatDatasetBuilderWithGlobalIndex", +] diff --git a/OATS/models/tsfm/data/builder/_base.py b/OATS/models/tsfm/data/builder/_base.py new file mode 100644 index 0000000..ee729a0 --- /dev/null +++ b/OATS/models/tsfm/data/builder/_base.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc +from typing import Any, Callable + +from torch.utils.data import ConcatDataset, Dataset + +from tsfm.transform import Transformation + +import logging +log = logging.getLogger(__name__) + +# TODO: Add __repr__ +class DatasetBuilder(abc.ABC): + @abc.abstractmethod + def build_dataset(self, *args, **kwargs): ... + + @abc.abstractmethod + def load_dataset( + self, transform_map: dict[Any, Callable[..., Transformation]] + ) -> Dataset: ... + + +class ConcatDatasetBuilder(DatasetBuilder): + def __init__(self, *builders: DatasetBuilder): + super().__init__() + assert len(builders) > 0, "Must provide at least one builder to ConcatBuilder" + assert all( + isinstance(builder, DatasetBuilder) for builder in builders + ), "All builders must be instances of DatasetBuilder" + self.builders: tuple[DatasetBuilder, ...] = builders + + def build_dataset(self): + raise ValueError( + "Do not use ConcatBuilder to build datasets, build sub datasets individually instead." 
+ ) + + def load_dataset( + self, transform_map: dict[Any, Callable[..., Transformation]] + ) -> ConcatDataset: + return ConcatDataset( + [builder.load_dataset(transform_map) for builder in self.builders] + ) + + +class ConcatDatasetBuilderWithGlobalIndex(DatasetBuilder): + """ConcatDatasetBuilder that calculates and passes global offsets to sub-datasets.""" + + def __init__(self, *builders: DatasetBuilder): + super().__init__() + assert len(builders) > 0, "Must provide at least one builder to ConcatBuilder" + assert all( + isinstance(builder, DatasetBuilder) for builder in builders + ), "All builders must be instances of DatasetBuilder" + self.builders: tuple[DatasetBuilder, ...] = builders + self.dataset_ranges = [] # Store list of (start, end, name, builder_type) tuples + + def build_dataset(self): + raise ValueError( + "Do not use ConcatBuilder to build datasets, build sub datasets individually instead." + ) + + def _get_dataset_name(self, builder, dataset): + """Extract a meaningful dataset name from builder and dataset.""" + # Method 1: Try to get the dataset name from dataset.info if it exists + dataset_info_name = None + if hasattr(dataset, 'info') and hasattr(dataset.info, 'dataset_name'): + dataset_info_name = dataset.info.dataset_name + + # Method 2: Try to get from dataset indexer + indexer_name = None + if hasattr(dataset, 'indexer') and hasattr(dataset.indexer, 'dataset'): + if hasattr(dataset.indexer.dataset, 'info') and hasattr(dataset.indexer.dataset.info, 'dataset_name'): + indexer_name = dataset.indexer.dataset.info.dataset_name + + # Method 3: Look for dataset names in builder's dataset_list + builder_dataset_name = None + if hasattr(builder, 'dataset_list') and hasattr(builder, 'datasets'): + # For LOTSA builders, try to match the dataset path or find in datasets list + if hasattr(dataset, 'indexer') and hasattr(dataset.indexer, 'dataset_path'): + dataset_path = str(dataset.indexer.dataset_path) + for dataset_name in builder.dataset_list: + if dataset_name in dataset_path: + builder_dataset_name = dataset_name + break + elif len(builder.dataset_list) == 1: + # If builder only has one dataset, use that name + builder_dataset_name = builder.dataset_list[0] + + # Method 4: Check if we can extract from file paths or other attributes + path_name = None + if hasattr(dataset, 'indexer') and hasattr(dataset.indexer, 'dataset'): + if hasattr(dataset.indexer.dataset, '_data_dir_cache') and dataset.indexer.dataset._data_dir_cache: + # Extract name from data directory path + import os + path_name = os.path.basename(dataset.indexer.dataset._data_dir_cache.rstrip('/')) + + # Prioritize the most specific and meaningful name + best_name = None + + # First preference: builder dataset name (most specific) + if builder_dataset_name and builder_dataset_name != 'generator': + best_name = builder_dataset_name + # Second preference: path name (usually specific) + elif path_name and path_name != 'generator': + best_name = path_name + # Third preference: indexer name (if not generic) + elif indexer_name and indexer_name != 'generator': + best_name = indexer_name + # Fourth preference: dataset info name (even if generic) + elif dataset_info_name: + best_name = dataset_info_name + + # If we still have a generic name, try to make it more specific using builder info + if best_name == 'generator' or best_name is None: + builder_name = builder.__class__.__name__.replace('DatasetBuilder', '') + + # For builders with multiple datasets, try to get more specific info + if hasattr(builder, 'datasets') 
and builder.datasets: + # If the builder is loading specific datasets, try to match + if len(builder.datasets) == 1: + best_name = builder.datasets[0] + else: + # Multiple datasets - we need to figure out which one this is + # This is tricky without more context, but we can use the builder name + best_name = f"{builder_name}_dataset" + else: + best_name = builder_name + + return best_name or "Unknown" + + def _flatten_datasets(self, dataset, global_offset, builder=None, dataset_name=None): + """Recursively flatten nested ConcatDatasets and set global offsets.""" + flattened = [] + current_offset = global_offset + + if isinstance(dataset, ConcatDataset): + for i, sub_dataset in enumerate(dataset.datasets): + sub_flattened, current_offset = self._flatten_datasets( + sub_dataset, current_offset, builder, f"{dataset_name}_sub{i}" if dataset_name else None + ) + flattened.extend(sub_flattened) + else: + # Single dataset - set its global offset and record metadata + if hasattr(dataset, 'set_global_offset'): + dataset.set_global_offset(current_offset) + + # Extract meaningful dataset name + if dataset_name is None: + dataset_name = self._get_dataset_name(builder, dataset) + + # Store range metadata for this dataset + # Use num_ts (actual time series count) instead of len(dataset) (weighted count) + # because the dataset classes use modulo num_ts for global_idx calculation + if hasattr(dataset, 'num_ts'): + actual_size = dataset.num_ts + else: + actual_size = len(dataset) # fallback for datasets without num_ts + + self.dataset_ranges.append({ + 'global_start': current_offset, + 'global_end': current_offset + actual_size - 1, + 'size': actual_size, + 'dataset_name': dataset_name, + 'builder_type': builder.__class__.__name__ if builder else 'Unknown' + }) + + log.info(f"Set global offset {current_offset} for dataset '{dataset_name}' with {actual_size} samples (range: {current_offset}-{current_offset + actual_size - 1})") + log.info(f"Dataset ts_num: {getattr(dataset, 'num_ts', 'N/A')}.") + flattened.append(dataset) + current_offset += actual_size + + return flattened, current_offset + + def load_dataset( + self, transform_map: dict[Any, Callable[..., Transformation]] + ) -> ConcatDataset: + all_datasets = [] + global_offset = 0 + self.dataset_ranges = [] # Reset ranges + + for builder in self.builders: + # Load the dataset + dataset = builder.load_dataset(transform_map) + + # Flatten and set offsets + flattened, global_offset = self._flatten_datasets(dataset, global_offset, builder) + all_datasets.extend(flattened) + + log.info(f"Total concatenated dataset size: {global_offset}") + log.info(f"Dataset ranges: {len(self.dataset_ranges)} datasets") + return ConcatDataset(all_datasets) + + def get_dataset_name_for_global_index(self, global_idx: int) -> str: + """Get the sub-dataset name for a given global index.""" + # Bounds check: ensure global_idx is not negative + if global_idx < 0: + return f"Invalid_idx_{global_idx}" + + # Find the dataset that contains this global index + for range_info in self.dataset_ranges: + if range_info['global_start'] <= global_idx <= range_info['global_end']: + return range_info['dataset_name'] + + # If we reach here, the index is out of bounds + # Check if it's beyond the maximum known range + if self.dataset_ranges: + max_range = max(r['global_end'] for r in self.dataset_ranges) + total_size = max_range + 1 + + if global_idx > max_range: + log.warning(f"Global index {global_idx} exceeds dataset bounds (max: {max_range}, total: {total_size}). 
" + f"This might indicate stale influence scores from a previous run with larger dataset.") + return f"OutOfBounds_idx_{global_idx}_max_{max_range}" + + return f"Unknown_idx_{global_idx}" + + def get_all_dataset_ranges(self) -> list: + """Get all dataset range metadata for debugging/analysis.""" + return self.dataset_ranges.copy() + + +class SafeConcatDatasetBuilder(ConcatDatasetBuilder): + def __init__(self, *builders: DatasetBuilder): + super().__init__(*builders) + self.builders = builders + + def load_dataset( + self, transform_map: dict[Any, Callable[..., Transformation]] + ) -> ConcatDataset: + datasets = [] + for builder in self.builders: + try: + datasets.append(builder.safeload_dataset(transform_map)) + except Exception as e: + log.error(f"Error loading dataset from {builder}: {e}") + continue + + return ConcatDataset(datasets) + diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/README.md b/OATS/models/tsfm/data/builder/lotsa_v1/README.md new file mode 100644 index 0000000..e69de29 diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/__init__.py b/OATS/models/tsfm/data/builder/lotsa_v1/__init__.py new file mode 100644 index 0000000..3a8c4bb --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from ._base import LOTSADatasetBuilder +from .buildings_bench import Buildings900KDatasetBuilder, BuildingsBenchDatasetBuilder +from .cloudops_tsf import CloudOpsTSFDatasetBuilder +from .cmip6 import CMIP6DatasetBuilder +from .era5 import ERA5DatasetBuilder +from .gluonts import GluonTSDatasetBuilder +from .largest import LargeSTDatasetBuilder +from .lib_city import LibCityDatasetBuilder +from .others import OthersLOTSADatasetBuilder +from .proenfo import ProEnFoDatasetBuilder +from .subseasonal import SubseasonalDatasetBuilder + +__all__ = [ + "LOTSADatasetBuilder", + "Buildings900KDatasetBuilder", + "BuildingsBenchDatasetBuilder", + "CloudOpsTSFDatasetBuilder", + "CMIP6DatasetBuilder", + "ERA5DatasetBuilder", + "GluonTSDatasetBuilder", + "LargeSTDatasetBuilder", + "LibCityDatasetBuilder", + "OthersLOTSADatasetBuilder", + "ProEnFoDatasetBuilder", + "SubseasonalDatasetBuilder", +] diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/__main__.py b/OATS/models/tsfm/data/builder/lotsa_v1/__main__.py new file mode 100644 index 0000000..55b1607 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/__main__.py @@ -0,0 +1,120 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import argparse +import traceback +from pathlib import Path + +from tsfm.common.env import env + +from . import ( + Buildings900KDatasetBuilder, + BuildingsBenchDatasetBuilder, + CloudOpsTSFDatasetBuilder, + CMIP6DatasetBuilder, + ERA5DatasetBuilder, + GluonTSDatasetBuilder, + LargeSTDatasetBuilder, + LibCityDatasetBuilder, + OthersLOTSADatasetBuilder, + ProEnFoDatasetBuilder, + SubseasonalDatasetBuilder, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "builder", + type=str, + choices=[ + "buildings_900k", + "buildings_bench", + "cloudops_tsf", + "cmip6", + "era5", + "gluonts", + "largest", + "lib_city", + "others", + "proenfo", + "subseasonal", + ], +) +parser.add_argument( + "--datasets", + type=str, + nargs="+", + default=None, + help="The datasets to generate", +) +parser.add_argument( + "--storage_path", + type=Path, + default=env.LOTSA_V1_PATH, + help="Path of directory to save the datasets", +) +parser.add_argument( + "--overwrite", + action="store_true", +) +args = parser.parse_args() + +Builder = { + "buildings_900k": Buildings900KDatasetBuilder, + "buildings_bench": BuildingsBenchDatasetBuilder, + "cloudops_tsf": CloudOpsTSFDatasetBuilder, + "cmip6": CMIP6DatasetBuilder, + "era5": ERA5DatasetBuilder, + "gluonts": GluonTSDatasetBuilder, + "largest": LargeSTDatasetBuilder, + "lib_city": LibCityDatasetBuilder, + "others": OthersLOTSADatasetBuilder, + "proenfo": ProEnFoDatasetBuilder, + "subseasonal": SubseasonalDatasetBuilder, +}[args.builder] + +datasets = set(args.datasets or Builder.dataset_list) +found = {directory.stem for directory in args.storage_path.iterdir()} +overlap = datasets & found + +if len(overlap) > 0: + print(f"Found datasets already present in storage path: {overlap}") + +if not args.overwrite: + datasets = datasets - found + if len(overlap) > 0: + print(f"Skipping processed datasets, building: {list(datasets)}") + print("To overwrite existing datasets, use the `--overwrite` flag") +else: + print(f"Overwriting existing datasets, building: {datasets}") + +failed = {} +for dataset in datasets: + try: + print(f"Building: {dataset}") + Builder( + datasets=list(datasets), + storage_path=args.storage_path, + ).build_dataset(dataset=dataset) + print(f"Successfully built {dataset}") + except Exception as e: + print(f"Failed to build {dataset}") + failed[dataset] = traceback.format_exc() + +if len(failed) > 0: + print(f"Failed: {list(failed.keys())}") + for k, v in failed.items(): + print(f"{k}: {v}") diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/_base.py b/OATS/models/tsfm/data/builder/lotsa_v1/_base.py new file mode 100644 index 0000000..8435d57 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/_base.py @@ -0,0 +1,129 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc +import datasets +from collections.abc import Callable +from pathlib import Path +from typing import Optional, Any + +from datasets import load_from_disk +from torch.utils.data import ConcatDataset, Dataset + +from tsfm.common.core import abstract_class_property +from tsfm.common.env import env +from tsfm.data.builder._base import DatasetBuilder +from tsfm.data.dataset import SampleTimeSeriesType, TimeSeriesDataset, TimeSeriesDatasetWithIndex +from tsfm.data.indexer import HuggingFaceDatasetIndexer +from tsfm.transform import Identity, Transformation + +import logging +log = logging.getLogger(__name__) + +@abstract_class_property("dataset_list", "dataset_type_map", "dataset_load_func_map") +class LOTSADatasetBuilder(DatasetBuilder, abc.ABC): + dataset_list: list[str] = NotImplemented + dataset_type_map: dict[str, type[TimeSeriesDataset]] = NotImplemented + dataset_load_func_map: dict[str, Callable[..., TimeSeriesDataset]] = NotImplemented + uniform: bool = False + + # Use the enhanced dataset with index tracking by default + default_dataset_class: type[TimeSeriesDataset] = TimeSeriesDatasetWithIndex + + def __init__( + self, + datasets: list[str], + weight_map: Optional[dict[str, float]] = None, + sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE, + storage_path: Path = env.LOTSA_V1_PATH, + ): + assert all( + dataset in self.dataset_list for dataset in datasets + ), f"Invalid datasets {set(datasets).difference(self.dataset_list)}, must be one of {self.dataset_list}" + weight_map = weight_map or dict() + self.datasets = datasets + self.weights = [weight_map.get(dataset, 1.0) for dataset in datasets] + self.sample_time_series = sample_time_series + self.storage_path = storage_path + + def load_dataset( + self, transform_map: dict[str | type, Callable[..., Transformation]] + ) -> Dataset: + datasets = [ + self.dataset_load_func_map[dataset]( + HuggingFaceDatasetIndexer( + add_column( + load_from_disk(str(self.storage_path / dataset)), + 'dataset_name', + dataset), + uniform=self.uniform, + ), + self._get_transform(transform_map, dataset), + sample_time_series=self.sample_time_series, + dataset_weight=weight, + ) + for dataset, weight in zip(self.datasets, self.weights) + ] + + return datasets[0] if len(datasets) == 1 else ConcatDataset(datasets) + + def safeload_dataset( + self, transform_map: dict[str | type, Callable[..., Transformation]] + ) -> Dataset: + datasets = [] + for dataset, weight in zip(self.datasets, self.weights): + try: + indexer = HuggingFaceDatasetIndexer( + load_from_disk(str(self.storage_path / dataset)), + uniform=self.uniform, + ) + datasets.append( + self.dataset_load_func_map[dataset]( + indexer, + self._get_transform(transform_map, dataset), + sample_time_series=self.sample_time_series, + dataset_weight=weight, + ) + ) + except Exception as e: + log.error(f'Error loading dataset from {dataset}: {e}') + continue + + return datasets[0] if len(datasets) == 1 else ConcatDataset(datasets) + + def _get_transform( + self, + transform_map: 
dict[str | type, Callable[..., Transformation]],
+        dataset: str,
+    ) -> Transformation:
+
+        if dataset in transform_map:
+            transform = transform_map[dataset]
+        elif (dataset_type := self.dataset_type_map[dataset]) in transform_map:
+            transform = transform_map[dataset_type]
+        else:
+            try:  # a defaultdict transform_map can still supply a transform here
+                transform = transform_map[dataset]
+            except KeyError:
+                transform = transform_map.get("default", Identity)
+        return transform()
+
+def add_column(hfdataset: datasets.arrow_dataset.Dataset, column_name: str, column_value: Any):
+    new_column = [column_value] * len(hfdataset)
+    hfdataset = hfdataset.add_column(column_name, new_column)
+    return hfdataset
\ No newline at end of file
diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/buildings_bench.py b/OATS/models/tsfm/data/builder/lotsa_v1/buildings_bench.py
new file mode 100644
index 0000000..b85b4cf
--- /dev/null
+++ b/OATS/models/tsfm/data/builder/lotsa_v1/buildings_bench.py
@@ -0,0 +1,185 @@
+# Copyright (c), Salesforce 2024,
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from collections import defaultdict
+from functools import partial
+from itertools import product
+from pathlib import Path
+from typing import Any, Generator
+
+import datasets
+import pyarrow.parquet as pq
+from datasets import Features, Sequence, Value
+
+try:
+    from buildings_bench.data import load_pandas_dataset
+except ImportError:
+    import traceback
+
+    e = traceback.format_exc()
+
+    def load_pandas_dataset(*args, **kwargs):
+        raise ImportError(e)
+
+
+from tsfm.common.env import env
+from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex, TimeSeriesDataset, TimeSeriesDatasetWithIndex
+
+from ._base import LOTSADatasetBuilder
+
+MULTI_SAMPLE_DATASETS = [
+    "bdg-2_panther",
+]
+
+
+class BuildingsBenchDatasetBuilder(LOTSADatasetBuilder):
+    dataset_list = [
+        "sceaux",
+        "borealis",
+        "ideal",
+        "bdg-2_panther",
+        "bdg-2_fox",
+        "bdg-2_rat",
+        "bdg-2_bear",
+        "smart",
+        "lcl",
+    ]
+    dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) | {
+        dataset: MultiSampleTimeSeriesDatasetWithIndex for dataset in MULTI_SAMPLE_DATASETS
+    }
+    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) | {
+        dataset: partial(
+            MultiSampleTimeSeriesDatasetWithIndex,
+            max_ts=128,
+            combine_fields=("target", "past_feat_dynamic_real"),
+        )
+        for dataset in MULTI_SAMPLE_DATASETS
+    }
+
+    def build_dataset(self, dataset: str):
+        def gen_func() -> Generator[dict[str, Any], None, None]:
+            if dataset.startswith("bdg"):
+                pd_dataset = load_pandas_dataset(dataset.replace("_", ":"))
+            else:
+                pd_dataset = load_pandas_dataset(dataset)
+
+            for item_id, df in pd_dataset:
+                freq = df.index.freqstr
+                if freq is None:
+                    from pandas.tseries.frequencies import to_offset
+
+                    freq = to_offset(df.index[1] - df.index[0]).freqstr
+                    df = df.asfreq(freq)
+                # Skip series that are more than half missing.
+                if df.power.isnull().sum() /
len(df) > 0.5: + continue + yield dict( + item_id=item_id, + start=df.index[0], + target=df.power, + freq=freq, + ) + + hf_dataset = datasets.Dataset.from_generator( + generator=gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ), + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(dataset_path=env.LOTSA_V1_PATH / dataset) + + +class Buildings900KDatasetBuilder(LOTSADatasetBuilder): + dataset_list: list[str] = ["buildings_900k"] + dataset_type_map = dict(buildings_900k=TimeSeriesDatasetWithIndex) + dataset_load_func_map = dict( + buildings_900k=partial(TimeSeriesDatasetWithIndex), + ) + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + all_jobs = [] + building_type_and_years = [ + "comstock_amy2018", + "comstock_tmy3", + "resstock_amy2018", + "resstock_tmy3", + ] + regions = ["midwest", "northeast", "south", "west"] + for building_type_and_year, region in product(building_type_and_years, regions): + for building_dir in [ + "Buildings-900K/end-use-load-profiles-for-us-building-stock", + "Buildings-900K-test", + ]: + pumas = ( + Path(os.getenv("BUILDINGS_BENCH")) + / f"{building_dir}/2021/{building_type_and_year}_release_1/" + f"timeseries_individual_buildings/by_puma_{region}/upgrade=0/" + ).glob("puma=*") + + pumas = [p.stem[len("puma=") :] for p in pumas] + + for puma in pumas: + all_jobs.append( + (building_dir, building_type_and_year, region, puma) + ) + + def gen_func(job_ids: list[int]) -> Generator[dict[str, Any], None, None]: + for idx in job_ids: + building_dir, building_type_and_year, region, puma = all_jobs[idx] + tab = pq.read_table( + Path(os.getenv("BUILDINGS_BENCH")) + / f"{building_dir}/2021/{building_type_and_year}_release_1/" + f"timeseries_individual_buildings/by_puma_{region}/upgrade=0/puma={puma}" + ) + tab = tab.sort_by("timestamp") + for building_num, col_num in zip( + tab.column_names[1:], range(1, tab.num_columns) + ): + yield dict( + item_id=f"{building_type_and_year}_{region}_{puma}_{building_num}", + start=tab.column(0) + .slice(0, 1) + .to_numpy()[0] + .astype("datetime64"), + target=tab.column(col_num).to_numpy(), + freq="H", + ) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ), + gen_kwargs={"job_ids": [i for i in range(len(all_jobs))]}, + num_proc=num_proc, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(self.storage_path / dataset, num_proc=num_proc) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/cloudops_tsf.py b/OATS/models/tsfm/data/builder/lotsa_v1/cloudops_tsf.py new file mode 100644 index 0000000..0ebbf63 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/cloudops_tsf.py @@ -0,0 +1,99 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from collections import defaultdict +from functools import partial +from typing import Any, Generator + +import datasets +from datasets import Features, Sequence, Value, load_dataset, load_dataset_builder +from gluonts.dataset.common import ProcessDataEntry +from gluonts.dataset.split import DateSplitter + +from tsfm.common.env import env +from tsfm.data.dataset import TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + + +class CloudOpsTSFDatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "azure_vm_traces_2017", + "borg_cluster_data_2011", + "alibaba_cluster_trace_2018", + ] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + cloudops_dataset = load_dataset( + path="Salesforce/cloudops_tsf", name=dataset, split="pretrain" + ) + cfg = load_dataset_builder( + path="Salesforce/cloudops_tsf", + name=dataset, + ).config + pde = ProcessDataEntry( + freq=cfg.freq, one_dim_target=cfg.univariate, use_timestamp=False + ) + splitter = DateSplitter(cfg.test_split_date) + + def process(entry): + return next(iter(splitter.split([pde(entry)])[0])) + + def gen_func(ids: list[int]) -> Generator[dict[str, Any], None, None]: + for item in cloudops_dataset.select(ids): + item = process(item) + yield dict( + item_id=item["item_id"], + start=item["start"].to_timestamp(), + freq=cfg.freq, + target=item["target"], + past_feat_dynamic_real=item["past_feat_dynamic_real"], + ) + + target_feature = ( + Sequence(Value("float32")) + if cfg.target_dim == 1 + else Sequence(Sequence(Value("float32")), length=cfg.target_dim) + ) + past_feat_dynamic_real_feature = Sequence( + Sequence(Value("float32")), length=cfg.past_feat_dynamic_real_dim + ) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=target_feature, + past_feat_dynamic_real=past_feat_dynamic_real_feature, + ) + ), + gen_kwargs={"ids": [i for i in range(len(cloudops_dataset))]}, + num_proc=num_proc, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk( + self.storage_path / dataset, + num_proc=10, + ) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/cmip6.py b/OATS/models/tsfm/data/builder/lotsa_v1/cmip6.py new file mode 100644 index 0000000..d333b0f --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/cmip6.py @@ -0,0 +1,130 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
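As a usage note for the builders defined so far (using `CloudOpsTSFDatasetBuilder` from the file above as the example): once `build_dataset` has written a dataset under `LOTSA_V1_PATH`, loading it might look like the sketch below. The empty transform map is valid because `_get_transform` falls back to `Identity`.

```python
from tsfm.data.builder.lotsa_v1 import CloudOpsTSFDatasetBuilder

# Assumes LOTSA_V1_PATH is set (via .env) and the dataset was built beforehand.
builder = CloudOpsTSFDatasetBuilder(
    datasets=["azure_vm_traces_2017"],
    weight_map={"azure_vm_traces_2017": 2.0},  # optional per-dataset sampling weight
)
dataset = builder.load_dataset(transform_map={})  # no transforms: falls back to Identity
print(len(dataset))
```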
+ +import itertools +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Generator + +import datasets +import numpy as np +import pandas as pd +from datasets import Features, Sequence, Value + +from tsfm.common.env import env +from tsfm.data.dataset import TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + +DEFAULT_PRESSURE_LEVELS = [ + 50, + 100, + 150, + 200, + 250, + 300, + 400, + 500, + 600, + 700, + 850, + 925, + 1000, +] + + +CMIP6_VARIABLES = [ + f"{var}_{level}" + for var, level in itertools.product( + [ + "geopotential", + "specific_humidity", + "temperature", + ], + DEFAULT_PRESSURE_LEVELS, + ) +] + [ + f"{var}_{level}" + for var, level in itertools.product( + [ + "u_component_of_wind", + "v_component_of_wind", + ], + [50, 250, 500, 600, 700, 850, 925], + ) +] + + +class CMIP6DatasetBuilder(LOTSADatasetBuilder): + dataset_list = [f"cmip6_{year}" for year in range(1850, 2015, 5)] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) + uniform = True + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + cmip6_path = Path(os.getenv("CMIP6_PATH")) + + year = int(dataset.split("_")[-1]) + all_jobs = [(x, y) for x, y in itertools.product(range(64), range(128))] + + all_vars = {var: [] for var in CMIP6_VARIABLES} + for shard in range(10): + np_file = np.load( + str(cmip6_path / f"train/{year}01010600-{year + 5}01010000_{shard}.npz") + ) + for var in CMIP6_VARIABLES: + all_vars[var].append(np_file[var][:, 0, :, :]) + + targets = np.stack( + [np.concatenate(all_vars[var]) for var in CMIP6_VARIABLES], axis=0 + ) + + def gen_func( + jobs: list[tuple[int, int]] + ) -> Generator[dict[str, Any], None, None]: + for x, y in jobs: + yield dict( + item_id=f"{year}_{x}_{y}", + start=pd.Timestamp(f"{year}-01-01 06:00"), + target=targets[:, :, x, y], + freq="6H", + ) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence( + Sequence(Value("float32")), length=len(CMIP6_VARIABLES) + ), + ) + ), + gen_kwargs=dict(jobs=all_jobs), + num_proc=num_proc, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk( + self.storage_path / dataset, + num_proc=num_proc, + ) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/era5.py b/OATS/models/tsfm/data/builder/lotsa_v1/era5.py new file mode 100644 index 0000000..513931a --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/era5.py @@ -0,0 +1,109 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import itertools +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Generator + +import datasets +import numpy as np +import pandas as pd +from datasets import Features, Sequence, Value + +from tsfm.common.env import env +from tsfm.data.dataset import TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + +ERA5_VARIABLES = [ + "2m_temperature", + "10m_u_component_of_wind", + "10m_v_component_of_wind", +] + [ + f"{var}_{level}" + for var, level in itertools.product( + [ + "geopotential", + "relative_humidity", + "specific_humidity", + "temperature", + "u_component_of_wind", + "v_component_of_wind", + ], + [50, 250, 500, 600, 700, 850, 925], + ) +] + + +class ERA5DatasetBuilder(LOTSADatasetBuilder): + dataset_list = [f"era5_{year}" for year in range(1989, 2018 + 1)] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) + uniform = True + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + era5_path = Path(os.getenv("ERA5_PATH")) + + year = dataset.split("_")[-1] + all_jobs = [(x, y) for x, y in itertools.product(range(64), range(128))] + + all_vars = {var: [] for var in ERA5_VARIABLES} + for shard in range(16): + np_file = np.load(str(era5_path / f"train/{year}_{shard}.npz")) + for var in ERA5_VARIABLES: + all_vars[var].append(np_file[var][:, 0, :, :]) + + targets = np.stack( + [np.concatenate(all_vars[var]) for var in ERA5_VARIABLES], axis=0 + ) + + def gen_func( + jobs: list[tuple[int, int]] + ) -> Generator[dict[str, Any], None, None]: + for x, y in jobs: + yield dict( + item_id=f"{year}_{x}_{y}", + start=pd.Timestamp(f"{year}-01-01"), + target=targets[:, :, x, y], + freq="H", + ) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence( + Sequence(Value("float32")), length=len(ERA5_VARIABLES) + ), + ) + ), + gen_kwargs=dict(jobs=all_jobs), + num_proc=num_proc, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk( + self.storage_path / dataset, + num_proc=num_proc, + ) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/gluonts.py b/OATS/models/tsfm/data/builder/lotsa_v1/gluonts.py new file mode 100644 index 0000000..8bc6031 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/gluonts.py @@ -0,0 +1,452 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
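Both gridded climate builders above (`CMIP6DatasetBuilder` and `ERA5DatasetBuilder`) flatten a 64x128 lat/lon grid into one multivariate record per cell. A sketch of the record layout, with a stand-in array for the `targets` tensor built inside `build_dataset` (shapes inferred from the code):

```python
import numpy as np
import pandas as pd

# Stand-in for the stacked ERA5 array: (num_variables, num_steps, 64, 128).
# ERA5_VARIABLES has 3 surface + 6 x 7 pressure-level variables = 45 entries,
# and one year of hourly data is 8760 steps.
targets = np.zeros((45, 8760, 64, 128), dtype=np.float32)

# One record as yielded by gen_func for grid cell (x=0, y=0):
record = dict(
    item_id="1989_0_0",          # "{year}_{x}_{y}"
    start=pd.Timestamp("1989-01-01"),
    target=targets[:, :, 0, 0],  # a (45, 8760) multivariate series
    freq="H",
)
```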
+
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Generator, Optional
+from zipfile import ZipFile
+
+import datasets
+import gluonts
+import numpy as np
+from datasets import Features, Sequence, Value
+from gluonts.dataset import DatasetWriter
+from gluonts.dataset.common import MetaData, TrainDatasets
+from gluonts.dataset.field_names import FieldName
+from gluonts.dataset.repository._tsf_datasets import Dataset as MonashDataset
+from gluonts.dataset.repository._tsf_datasets import TSFReader, convert_data
+from gluonts.dataset.repository._tsf_reader import frequency_converter
+from gluonts.dataset.repository._util import metadata
+from gluonts.dataset.repository.datasets import get_dataset
+from pandas.tseries.frequencies import to_offset
+
+from tsfm.common.env import env
+from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex, TimeSeriesDataset, TimeSeriesDatasetWithIndex
+
+from ._base import LOTSADatasetBuilder
+
+
+def default_prediction_length_from_frequency(freq: str) -> int:
+    prediction_length_map = {
+        "T": 60 * 24 * 7,
+        "H": 24 * 7,
+        "D": 30,
+        "W-SUN": 8,
+        "M": 12,
+        "Y": 4,
+        "S": 60 * 60 * 24 * 7,
+    }
+    try:
+        freq = to_offset(freq).name
+        return prediction_length_map[freq]
+    except KeyError as err:
+        raise ValueError(
+            f"Cannot obtain default prediction length from frequency `{freq}`."
+        ) from err
+
+
+gluonts.dataset.repository._tsf_datasets.default_prediction_length_from_frequency = (
+    default_prediction_length_from_frequency
+)
+
+
+def generate_forecasting_dataset(
+    dataset_path: Path,
+    dataset_name: str,
+    dataset_writer: DatasetWriter,
+    prediction_length: Optional[int] = None,
+):
+    dataset = gluonts.dataset.repository._tsf_datasets.datasets[dataset_name]
+    dataset_path.mkdir(exist_ok=True)
+
+    with TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        with ZipFile(dataset.download(temp_path)) as archive:
+            archive.extractall(path=temp_path)
+
+        # Only one file is expected in the archive.
+        reader = TSFReader(temp_path / archive.namelist()[0])
+        meta, data = reader.read()
+
+    if dataset_name.startswith("cif_2016") and len(dataset_name) > len("cif_2016"):
+        horizon = int(dataset_name[len("cif_2016_") :])
+        data = list(filter(lambda x: x["horizon"] == horizon, data))
+        meta.forecast_horizon = horizon
+
+    if dataset_name.startswith("monash_m3_other"):
+        meta.frequency = "quarterly"
+
+    freq = frequency_converter(meta.frequency)
+    if prediction_length is None:
+        if hasattr(meta, "forecast_horizon"):
+            prediction_length = int(meta.forecast_horizon)
+        else:
+            prediction_length = default_prediction_length_from_frequency(freq)
+
+    # Impute missing start dates with unix epoch and remove time series whose
+    # length is less than or equal to the prediction length
+    data = [
+        {**d, "start_timestamp": d.get("start_timestamp", "1970-01-01")}
+        for d in data
+        if len(d[FieldName.TARGET]) > prediction_length
+    ]
+    train_data, test_data = convert_data(data, prediction_length)
+
+    meta = MetaData(
+        **metadata(
+            cardinality=len(data),
+            freq=freq,
+            prediction_length=prediction_length,
+        )
+    )
+
+    dataset = TrainDatasets(metadata=meta, train=train_data, test=test_data)
+    dataset.save(path_str=str(dataset_path), writer=dataset_writer, overwrite=True)
+
+
+gluonts.dataset.repository._tsf_datasets.generate_forecasting_dataset = (
+    generate_forecasting_dataset
+)
+
+
+additional_datasets = {
+    "bitcoin":
MonashDataset( + file_name="bitcoin_dataset_without_missing_values.zip", + record="5122101", + ROOT="https://zenodo.org/record", + ), + "wind_power": MonashDataset( + file_name="wind_4_seconds_dataset.zip", + record="4656032", + ROOT="https://zenodo.org/record", + ), + "us_births": MonashDataset( + file_name="us_births_dataset.zip", + record="4656049", + ROOT="https://zenodo.org/record", + ), + "traffic_hourly": MonashDataset( + file_name="traffic_hourly_dataset.zip", + record="4656132", + ROOT="https://zenodo.org/record", + ), + "traffic_weekly": MonashDataset( + file_name="traffic_weekly_dataset.zip", + record="4656135", + ROOT="https://zenodo.org/record", + ), + "solar_power": MonashDataset( + file_name="solar_4_seconds_dataset.zip", + record="4656027", + ROOT="https://zenodo.org/record", + ), + "oikolab_weather": MonashDataset( + file_name="oikolab_weather_dataset.zip", + record="5184708", + ROOT="https://zenodo.org/record", + ), + "elecdemand": MonashDataset( + file_name="elecdemand_dataset.zip", + record="4656069", + ROOT="https://zenodo.org/record", + ), + "covid_mobility": MonashDataset( + file_name="covid_mobility_dataset_with_missing_values.zip", + record="4663762", + ROOT="https://zenodo.org/record", + ), + "extended_web_traffic_with_missing": MonashDataset( + file_name="web_traffic_extended_dataset_with_missing_values.zip", + record="7370977", + ROOT="https://zenodo.org/record", + ), + "monash_m3_monthly": MonashDataset( + file_name="m3_monthly_dataset.zip", + record="4656298", + ROOT="https://zenodo.org/record", + ), + "monash_m3_quarterly": MonashDataset( + file_name="m3_quarterly_dataset.zip", + record="4656262", + ROOT="https://zenodo.org/record", + ), + "monash_m3_yearly": MonashDataset( + file_name="m3_yearly_dataset.zip", + record="4656222", + ROOT="https://zenodo.org/record", + ), + "monash_m3_other": MonashDataset( + file_name="m3_other_dataset.zip", + record="4656335", + ROOT="https://zenodo.org/record", + ), + "cif_2016_12": MonashDataset( + file_name="cif_2016_dataset.zip", + record="4656042", + ROOT="https://zenodo.org/record", + ), + "cif_2016_6": MonashDataset( + file_name="cif_2016_dataset.zip", + record="4656042", + ROOT="https://zenodo.org/record", + ), + "sunspot_with_missing": MonashDataset( + file_name="sunspot_dataset_with_missing_values.zip", + record="4654773", + ROOT="https://zenodo.org/record", + ), + "temperature_rain_with_missing": MonashDataset( + file_name="temperature_rain_dataset_with_missing_values.zip", + record="5129073", + ROOT="https://zenodo.org/record", + ), + "rideshare_with_missing": MonashDataset( + file_name="rideshare_dataset_with_missing_values.zip", + record="5122114", + ROOT="https://zenodo.org/record", + ), + "car_parts_with_missing": MonashDataset( + file_name="car_parts_dataset_with_missing_values.zip", + record="4656022", + ROOT="https://zenodo.org/record", + ), + "kdd_cup_2018_with_missing": MonashDataset( + file_name="kdd_cup_2018_dataset_with_missing_values.zip", + record="4656719", + ROOT="https://zenodo.org/record", + ), + "vehicle_trips_with_missing": MonashDataset( + file_name="vehicle_trips_dataset_with_missing_values.zip", + record="5122535", + ROOT="https://zenodo.org/record", + ), + "bitcoin_with_missing": MonashDataset( + file_name="bitcoin_dataset_with_missing_values.zip", + record="5121965", + ROOT="https://zenodo.org/record", + ), + "london_smart_meters_with_missing": MonashDataset( + file_name="london_smart_meters_dataset_with_missing_values.zip", + record="4656072", + ROOT="https://zenodo.org/record", + ), + 
"wind_farms_with_missing": MonashDataset( + file_name="wind_farms_minutely_dataset_with_missing_values.zip", + record="4654909", + ROOT="https://zenodo.org/record", + ), + "nn5_daily_with_missing": MonashDataset( + file_name="nn5_daily_dataset_with_missing_values.zip", + record="4656110", + ROOT="https://zenodo.org/record", + ), +} + +gluonts.dataset.repository._tsf_datasets.datasets |= additional_datasets +gluonts.dataset.repository.datasets.dataset_recipes |= { + k: partial( + generate_forecasting_dataset, + dataset_name=k, + ) + for k in additional_datasets.keys() +} + +PRETRAIN_GROUP = [ + "taxi_30min", + "uber_tlc_daily", + "uber_tlc_hourly", + "wiki-rolling_nips", + "london_smart_meters_with_missing", + "wind_farms_with_missing", + "wind_power", + "solar_power", + "oikolab_weather", + "elecdemand", + "covid_mobility", + "kaggle_web_traffic_weekly", + "extended_web_traffic_with_missing", + "m5", + "m4_yearly", + "m1_yearly", + "m1_quarterly", + "monash_m3_yearly", + "monash_m3_quarterly", + "tourism_yearly", +] + +TRAIN_TEST_GROUP = { + "m4_hourly": None, + "m4_daily": None, + "m4_weekly": None, + "m4_monthly": None, + "m4_quarterly": None, + "m1_monthly": None, + "monash_m3_monthly": None, + "monash_m3_other": None, + "nn5_daily_with_missing": None, + "nn5_weekly": 8, + "tourism_monthly": None, + "tourism_quarterly": None, + "cif_2016_6": None, + "cif_2016_12": None, + "traffic_hourly": 168, + "traffic_weekly": 8, + "australian_electricity_demand": 336, + "rideshare_with_missing": 168, + "saugeenday": 30, + "sunspot_with_missing": 30, + "temperature_rain_with_missing": 30, + "vehicle_trips_with_missing": 30, + "weather": 30, + "car_parts_with_missing": 12, + "fred_md": 12, + "pedestrian_counts": 24, + "hospital": 12, + "covid_deaths": 30, + "kdd_cup_2018_with_missing": 168, + "bitcoin_with_missing": 30, + "us_births": 30, +} + + +MULTI_SAMPLE_DATASETS = [ + "oikolab_weather", + "kaggle_web_traffic_weekly", + "extended_web_traffic_with_missing", + "m5", + "nn5_daily_with_missing", + "nn5_weekly", + "traffic_hourly", + "traffic_weekly", + "rideshare_with_missing", + "temperature_rain_with_missing", + "car_parts_with_missing", + "fred_md", + "hospital", + "covid_deaths", +] + + +class GluonTSDatasetBuilder(LOTSADatasetBuilder): + dataset_list = PRETRAIN_GROUP + list(TRAIN_TEST_GROUP) + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) | { + dataset: MultiSampleTimeSeriesDatasetWithIndex for dataset in MULTI_SAMPLE_DATASETS + } + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) | { + dataset: partial( + MultiSampleTimeSeriesDatasetWithIndex, + max_ts=128, + combine_fields=("target", "past_feat_dynamic_real"), + ) + for dataset in MULTI_SAMPLE_DATASETS + } + + def build_dataset(self, dataset): + if dataset in TRAIN_TEST_GROUP: + gluonts_dataset = get_dataset( + dataset, + prediction_length=TRAIN_TEST_GROUP[dataset], + regenerate=True, + ).train + elif dataset in PRETRAIN_GROUP: + gluonts_dataset = get_dataset(dataset, regenerate=True).test + else: + raise ValueError(f"Unrecognized dataset: {dataset}") + + def gen_func() -> Generator[dict[str, Any], None, None]: + for item in gluonts_dataset: + if dataset == "covid_mobility": + if ( + len(item["target"]) < 100 + or np.isnan(item["target"]).sum() / len(item["target"]) > 0.25 + ): + continue + if len(item["target"]) < 20: + continue + + freq = item["start"].freqstr + if freq is None or freq == "": + raise ValueError + item["freq"] = freq + item["start"] = item["start"].to_timestamp() + del 
item["feat_static_cat"] + yield item + + hf_dataset = datasets.Dataset.from_generator( + generator=gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ), + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(dataset_path=self.storage_path / dataset) + + def build_val_dataset(self, dataset): + if dataset in TRAIN_TEST_GROUP: + gluonts_dataset = get_dataset( + dataset, + prediction_length=TRAIN_TEST_GROUP[dataset], + regenerate=True, + ).test + elif dataset in PRETRAIN_GROUP: + gluonts_dataset = get_dataset(dataset, regenerate=True).test + else: + raise ValueError(f"Unrecognized dataset: {dataset}") + + def gen_func() -> Generator[dict[str, Any], None, None]: + for item in gluonts_dataset: + if dataset == "covid_mobility": + if ( + len(item["target"]) < 100 + or np.isnan(item["target"]).sum() / len(item["target"]) > 0.25 + ): + continue + if len(item["target"]) < 20: + continue + + freq = item["start"].freqstr + if freq is None or freq == "": + raise ValueError + item["freq"] = freq + item["start"] = item["start"].to_timestamp() + del item["feat_static_cat"] + yield item + + hf_dataset = datasets.Dataset.from_generator( + generator=gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ), + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(dataset_path=f"{self.storage_path}/{dataset}") \ No newline at end of file diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/largest.py b/OATS/models/tsfm/data/builder/lotsa_v1/largest.py new file mode 100644 index 0000000..3db67dd --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/largest.py @@ -0,0 +1,76 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Generator + +import datasets +import pandas as pd +from datasets import Features, Sequence, Value + +from tsfm.common.env import env +from tsfm.data.dataset import TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + + +class LargeSTDatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "largest_2017", + "largest_2018", + "largest_2019", + "largest_2020", + "largest_2021", + ] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + year = dataset.split("_")[-1] + df = pd.read_hdf(Path(os.getenv("LARGEST_PATH")) / f"ca_his_raw_{year}.h5") + + def gen_func(cols: list[int]) -> Generator[dict[str, Any], None, None]: + for col in cols: + if df[col].isnull().all(): + continue + yield dict( + item_id=f"{col}", + start=df.index[0], + target=df[col], + freq="5T", + ) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ), + num_proc=num_proc, + gen_kwargs={"cols": list(df.columns)}, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(self.storage_path / dataset) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/lib_city.py b/OATS/models/tsfm/data/builder/lotsa_v1/lib_city.py new file mode 100644 index 0000000..cec8926 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/lib_city.py @@ -0,0 +1,171 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
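One detail of `largest.py` above worth flagging: `datasets.Dataset.from_generator` shards any list-valued entry in `gen_kwargs` across `num_proc` worker processes, so passing the full column list lets each worker convert a disjoint subset of sensor columns in parallel. A self-contained toy version of the pattern, with a small random frame standing in for the LargeST HDF5 data:

```python
# Toy reproduction of the sharded-generation pattern in largest.py;
# the random frame stands in for pd.read_hdf(...) on the real traffic data.
import datasets
import numpy as np
import pandas as pd
from datasets import Features, Sequence, Value

df = pd.DataFrame(
    np.random.rand(288, 4).astype("float32"),
    index=pd.date_range("2021-01-01", periods=288, freq="5T"),
)


def gen_func(cols):
    # `cols` arrives pre-sharded: each worker sees only its slice of the list.
    for col in cols:
        yield dict(
            item_id=f"{col}",
            start=df.index[0],
            target=df[col].to_numpy(),
            freq="5T",
        )


hf_dataset = datasets.Dataset.from_generator(
    gen_func,
    features=Features(
        dict(
            item_id=Value("string"),
            start=Value("timestamp[s]"),
            freq=Value("string"),
            target=Sequence(Value("float32")),
        )
    ),
    num_proc=2,
    gen_kwargs={"cols": list(df.columns)},
)
```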
+ +import json +import os +from collections import defaultdict +from functools import partial +from pathlib import Path + +import datasets +import numpy as np +import pandas as pd +from datasets import Features, Sequence, Value +from pandas.tseries.frequencies import to_offset + +from tsfm.common.env import env +from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + + +class LibCityDatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "BEIJING_SUBWAY_30MIN", + "HZMETRO", + "LOOP_SEATTLE", + "LOS_LOOP", + "M_DENSE", + "PEMS03", + "PEMS04", + "PEMS07", + "PEMS08", + "PEMS_BAY", + "Q-TRAFFIC", + "SHMETRO", + "SZ_TAXI", + ] + dataset_type_map = defaultdict(lambda: MultiSampleTimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict( + lambda: partial( + MultiSampleTimeSeriesDatasetWithIndex, + max_ts=128, + combine_fields=("target", "past_feat_dynamic_real"), + ) + ) + + def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): + lib_city_path = Path(os.getenv("LIB_CITY_PATH")) + engine = "pyarrow" if dataset == "Q-TRAFFIC" else None + + if dataset in ("PEMS03", "PEMS04", "PEMS07", "PEMS08"): + raw_dataset_name = dataset.replace("0", "D") + else: + raw_dataset_name = dataset + + with open(lib_city_path / f"{raw_dataset_name}/config.json") as f: + config = json.load(f) + data_col = ( + data_col + if len(data_col := config["info"]["data_col"]) == 1 + else config["info"]["data_col"] + ) + freq = to_offset( + pd.to_timedelta(f'{config["info"]["time_intervals"]}S') + ).freqstr + + try: + df = pd.read_csv( + lib_city_path / f"{raw_dataset_name}/{raw_dataset_name}_new.dyna", + engine=engine, + ) + except FileNotFoundError: + df = pd.read_csv( + lib_city_path / f"{raw_dataset_name}/{raw_dataset_name}.dyna", + engine=engine, + ) + + df["time"] = pd.to_datetime(df["time"]) + df = df.set_index("time") + + past_feat_dynamic_real_dict = {} + if (lib_city_path / f"{dataset}/{dataset}.ext").exists(): + ext_df = pd.read_csv(lib_city_path / f"{dataset}/{dataset}.ext") + ext_df["time"] = pd.to_datetime(ext_df["time"]) + ext_df = ext_df.set_index("time") + ext_df = ext_df[config["info"]["ext_col"]] + if pd.infer_freq(ext_df.index) is None: + ext_df = ext_df.reindex( + pd.date_range(ext_df.index[0], ext_df.index[-1], freq=freq) + ) + cov = ext_df.to_numpy() + if cov.shape[-1] == 1: + cov = cov.squeeze(-1) + else: + cov = cov.T + past_feat_dynamic_real_dict["past_feat_dynamic_real"] = cov + else: + cov = None + + entity_df = df[df.entity_id == next(iter(df.entity_id.unique()))] + if entity_df[data_col].to_numpy().ndim == 2: + target_dim = entity_df[data_col].to_numpy().astype(np.float32).T.shape[0] + target_feature = ( + Sequence(Value("float32")) + if target_dim == 1 + else Sequence(Sequence(Value("float32")), length=target_dim) + ) + else: + target_feature = Sequence(Value("float32")) + + past_feat_dynamic_real_feature_dict = {} + if cov is not None: + past_feat_dynamic_real_feature_dict["past_feat_dynamic_real"] = ( + Sequence(Value("float")) + if cov.ndim == 1 + else Sequence(Sequence(Value("float32")), length=cov.shape[0]) + ) + + def gen_func(): + for idx in df.entity_id.unique(): + entity_df = df.query(f"entity_id == {idx}") + inferred_freq = pd.infer_freq(entity_df.index) + if inferred_freq is None: + entity_df = entity_df.reindex( + pd.date_range( + entity_df.index[0], entity_df.index[-1], freq=freq + ) + ) + target = entity_df[data_col].to_numpy().astype(np.float32) + if target.ndim == 2: + if 
target.shape[-1] == 1: + target = target.squeeze(-1) + else: + target = target.T + yield dict( + item_id=f"{idx}", + start=entity_df.index[0], + target=target, + freq=freq, + ) | past_feat_dynamic_real_dict + + hf_datasets = datasets.Dataset.from_generator( + gen_func, + features=Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=target_feature, + ) + | past_feat_dynamic_real_feature_dict + ), + cache_dir=env.HF_CACHE_PATH, + ) + hf_datasets.info.dataset_name = dataset + hf_datasets.save_to_disk(self.storage_path / dataset) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/others.py b/OATS/models/tsfm/data/builder/lotsa_v1/others.py new file mode 100644 index 0000000..88df14b --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/others.py @@ -0,0 +1,834 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Generator + +import datasets +import numpy as np +import pandas as pd +from datasets import Features, Sequence, Value + +try: + from pyreadr import read_r +except ImportError as e: + + def read_r(*args, **kwargs): + raise e + + +from tsfm.common.env import env +from tsfm.common.typing import GenFunc +from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex, TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + +MULTI_SAMPLE_DATASETS = [ + "godaddy", + "hierarchical_sales", + "beijing_air_quality", +] + + +def _get_kdd_2022_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_csv( + data_path / "wtbdata_245days.csv", header=0, skiprows=lambda x: x == 1 + ) + df["timestamp"] = pd.to_datetime("2020-01-01 " + df["Tmstamp"]) + ( + df["Day"] - 1 + ) * pd.Timedelta("1D") + + def gen_func() -> Generator[dict[str, Any], None, None]: + for idx in df["TurbID"].unique(): + id_df = df[df["TurbID"] == idx].set_index("timestamp") + yield dict( + item_id=f"{idx}", + start=id_df.index[0], + target=id_df["Patv"], + past_feat_dynamic_real=id_df[ + [ + "Wspd", + "Wdir", + "Etmp", + "Itmp", + "Ndir", + "Pab1", + "Pab2", + "Pab3", + "Prtv", + ] + ] + .to_numpy() + .T, + freq="10T", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Sequence(Value("float32")), length=9), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_godaddy_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + train = pd.read_csv(data_path / "train.csv") + test = pd.read_csv(data_path / "revealed_test.csv") + df = pd.concat([train, test]) + df["first_day_of_month"] = pd.to_datetime(df["first_day_of_month"]) + + def gen_func() -> Generator[dict[str, Any], 
None, None]: + for idx in df.cfips.unique(): + id_df = df.query(f"cfips == {idx}") + id_df = id_df.set_index("first_day_of_month").sort_index() + yield dict( + item_id=f"{idx}", + start=id_df.index[0], + target=id_df[["microbusiness_density", "active"]].to_numpy().T, + freq="MS", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=2), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_favorita_sales_gen_func( + data_path: Path, length_threshold: int = 250, missing_threshold: float = 0.5 +) -> tuple[GenFunc, Features]: + train = pd.read_csv( + data_path / "train.csv", + dtype=dict( + id=int, store_nbr=int, item_nbr=int, unit_sales=int, onpromotion=bool + ), + parse_dates=["date"], + engine="pyarrow", + ) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for store_nbr in train.store_nbr.unique(): + store = train.query(f"store_nbr == {store_nbr}") + for item_nbr in store.item_nbr.unique(): + item = ( + store.query(f"item_nbr == {item_nbr}") + .set_index("date") + .sort_index() + ) + item = item.reindex( + pd.date_range(start=item.index[0], end=item.index[-1], freq="1D") + ) + + missing_pct = item.unit_sales.isnull().sum() / len(item) + + if len(item) < length_threshold or missing_pct > missing_threshold: + continue + + yield dict( + item_id=f"{store_nbr}_{item_nbr}", + start=item.index[0], + target=item.unit_sales, + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_favorita_transactions_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + transactions = pd.read_csv(data_path / "transactions.csv") + transactions["date"] = pd.to_datetime(transactions["date"]) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for store_nbr in transactions.store_nbr.unique(): + store = ( + transactions.query(f"store_nbr == {store_nbr}") + .set_index("date") + .sort_index() + ) + store = store.reindex( + pd.date_range(start=store.index[0], end=store.index[-1], freq="1D") + ) + yield dict( + item_id=f"{store_nbr}", + start=store.index[0], + target=store.transactions, + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_restaurant_gen_func( + data_path: Path, missing_threshold: float = 0.5 +) -> tuple[GenFunc, Features]: + air_visit_data = pd.read_csv(data_path / "air_visit_data.csv") + air_visit_data["visit_date"] = pd.to_datetime(air_visit_data["visit_date"]) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for air_store_id in air_visit_data.air_store_id.unique(): + air_store = ( + air_visit_data.query(f'air_store_id == "{air_store_id}"') + .set_index("visit_date") + .sort_index() + ) + air_store = air_store.reindex( + pd.date_range( + start=air_store.index[0], end=air_store.index[-1], freq="1D" + ) + ) + missing_pct = air_store.visitors.isnull().sum() / len(air_store) + if missing_pct > missing_threshold: + continue + yield dict( + item_id=f"{air_store_id}", + start=air_store.index[0], + target=air_store.visitors, + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + 
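The per-entity loaders in this file keep repeating one normalization idiom: select a single entity, sort it on a datetime index, reindex onto a regular `pd.date_range` so gaps surface as explicit NaNs, then drop series that fail a length or missing-rate threshold. A condensed, self-contained illustration on synthetic data; the threshold values mirror the `_get_favorita_sales_gen_func` defaults above but are otherwise illustrative:

```python
import numpy as np
import pandas as pd

# One entity's observations on a gappy daily grid: 250 of 300 days present.
kept_days = np.sort(np.random.choice(300, size=250, replace=False))
obs = pd.Series(
    np.random.rand(250).astype("float32"),
    index=pd.date_range("2020-01-01", periods=300, freq="D")[kept_days],
)

# Reindex onto the regular grid so the missing days become NaNs.
full = obs.reindex(pd.date_range(obs.index[0], obs.index[-1], freq="D"))

# Filter with the same style of thresholds the loaders apply.
length_threshold, missing_threshold = 250, 0.5
missing_pct = full.isnull().sum() / len(full)
if len(full) >= length_threshold and missing_pct <= missing_threshold:
    record = dict(item_id="entity_0", start=full.index[0], target=full, freq="D")
```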
+def _get_hierarchical_sales_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + sales = pd.read_csv(data_path / "hierarchical_sales_data.csv") + sales["DATE"] = pd.to_datetime(sales["DATE"]) + sales = sales.set_index("DATE").sort_index() + + def gen_func() -> Generator[dict[str, Any], None, None]: + for id in sales.columns: + if "QTY" in id: + yield dict( + item_id=id, + start=sales.index[0], + target=sales[id].astype(np.float32), + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_china_air_quality_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + airquality = pd.read_csv(data_path / "airquality.csv") + airquality["time"] = pd.to_datetime(airquality["time"]) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for station_id in airquality.station_id.unique(): + station = ( + airquality.query(f"station_id == {station_id}") + .set_index("time") + .sort_index() + ) + station = station.reindex( + pd.date_range( + start=station.index[0], + end=station.index[-1], + freq="H", + ) + ) + yield dict( + item_id=f"{station_id}", + start=station.index[0], + target=station[ + [ + "PM25_Concentration", + "PM10_Concentration", + "NO2_Concentration", + "CO_Concentration", + "O3_Concentration", + "SO2_Concentration", + ] + ] + .to_numpy() + .T, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=6), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_beijing_air_quality_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + files = data_path.glob("*.csv") + dfs = [(file, pd.read_csv(file)) for file in files] + + def gen_func() -> Generator[dict[str, Any], None, None]: + for file, data in dfs: + data["date"] = pd.to_datetime( + data.year.astype(str) + + "-" + + data.month.astype(str) + + "-" + + data.day.astype(str) + + " " + + data.hour.astype(str) + + ":00" + ) + data = data.set_index("date").sort_index() + data = data.reindex( + pd.date_range( + start=data.index[0], + end=data.index[-1], + freq="H", + ) + ) + yield dict( + item_id=file.stem, + start=data.index[0], + target=data[ + [ + "PM2.5", + "PM10", + "SO2", + "NO2", + "CO", + "O3", + "TEMP", + "PRES", + "DEWP", + "RAIN", + "WSPM", + ] + ] + .to_numpy() + .T, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=11), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_residential_load_power_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + units = [ + file.stem + for file in (data_path / "anonymous_public_load_power_data_per_unit").glob( + "*.rds" + ) + ] + + def gen_func() -> Generator[dict[str, Any], None, None]: + for unit in units: + load_power = read_r( + data_path / f"anonymous_public_load_power_data_per_unit/{unit}.rds" + )[None] + load_power = ( + load_power.drop_duplicates(subset="utc", keep="last") + .set_index("utc") + .sort_index() + ) + load_power = load_power.reindex( + pd.date_range( + start=load_power.index[0], + end=load_power.index[-1], + freq="T", + ) + ) + target = load_power[["sum", "min", "max"]].to_numpy() + if target.shape[0] < 16: + continue + + yield dict( + item_id=f"{unit}", + start=load_power.index[0], + target=target.T, + freq="T", + ) + + features = Features( + dict( 
+ item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=3), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_residential_pv_power_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + units = [ + file.stem + for file in (data_path / "anonymous_public_pv_power_data_per_unit").glob( + "*.rds" + ) + ] + + def gen_func() -> Generator[dict[str, Any], None, None]: + for unit in units: + pv_power = read_r( + data_path / f"anonymous_public_pv_power_data_per_unit/{unit}.rds" + )[None] + pv_power = ( + pv_power.drop_duplicates(subset="utc", keep="last") + .set_index("utc") + .sort_index() + ) + pv_power = pv_power.reindex( + pd.date_range( + start=pv_power.index[0], + end=pv_power.index[-1], + freq="T", + ) + ) + yield dict( + item_id=f"{unit}", + start=pv_power.index[0], + target=pv_power[["sum", "min", "max"]].to_numpy().T, + freq="T", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=3), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_cdc_fluview_ilinet_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + national = pd.read_csv(data_path / "National/ILINet.csv", skiprows=1) + hhs = pd.read_csv(data_path / "HHS/ILINet.csv", skiprows=1) + census = pd.read_csv(data_path / "Census/ILINet.csv", skiprows=1) + state = pd.read_csv(data_path / "State/ILINet.csv", skiprows=1) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for dataset in (national, hhs, census, state): + dataset["date"] = pd.to_datetime( + (dataset.YEAR * 100 + dataset.WEEK).astype(str) + "0", format="%Y%W%w" + ) + for region in dataset.REGION.unique(): + region_ds = ( + dataset.query(f'REGION == "{region}"') + .set_index("date") + .sort_index() + ) + + if region_ds["REGION TYPE"].iloc[0] == "National": + item_id = "national" + elif region_ds["REGION TYPE"].iloc[0] == "HHS Regions": + item_id = f"hhs_{region}" + elif region_ds["REGION TYPE"].iloc[0] == "Census Regions": + item_id = f"census_{region}" + elif region_ds["REGION TYPE"].iloc[0] == "States": + item_id = f"states_{region}" + else: + raise ValueError + + yield dict( + item_id=item_id, + start=region_ds.index[0], + target=region_ds[ + [ + "% WEIGHTED ILI", + "%UNWEIGHTED ILI", + "ILITOTAL", + "NUM. 
OF PROVIDERS", + "TOTAL PATIENTS", + ] + ] + .replace("X", np.nan) + .to_numpy() + .astype(float) + .T, + freq="W", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=5), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_cdc_fluview_who_nrevss_gen_func(data_path: Path) -> tuple[GenFunc, Features]: + national_prior = pd.read_csv( + data_path / "National/WHO_NREVSS_Combined_prior_to_2015_16.csv", skiprows=1 + ) + national_public_health = pd.read_csv( + data_path / "National/WHO_NREVSS_Public_Health_Labs.csv", skiprows=1 + ) + national_clinical_labs = pd.read_csv( + data_path / "National/WHO_NREVSS_Clinical_Labs.csv", skiprows=1 + ) + hhs_prior = pd.read_csv( + data_path / "HHS/WHO_NREVSS_Combined_prior_to_2015_16.csv", skiprows=1 + ) + hhs_public_health = pd.read_csv( + data_path / "HHS/WHO_NREVSS_Public_Health_Labs.csv", skiprows=1 + ) + hhs_clinical_labs = pd.read_csv( + data_path / "HHS/WHO_NREVSS_Clinical_Labs.csv", skiprows=1 + ) + census_prior = pd.read_csv( + data_path / "Census/WHO_NREVSS_Combined_prior_to_2015_16.csv", skiprows=1 + ) + census_public_health = pd.read_csv( + data_path / "Census/WHO_NREVSS_Public_Health_Labs.csv", skiprows=1 + ) + census_clinical_labs = pd.read_csv( + data_path / "Census/WHO_NREVSS_Clinical_Labs.csv", skiprows=1 + ) + state_prior = pd.read_csv( + data_path / "State/WHO_NREVSS_Combined_prior_to_2015_16.csv", skiprows=1 + ) + state_public_health = pd.read_csv( + data_path / "State/WHO_NREVSS_Public_Health_Labs.csv", skiprows=1 + ) + state_clinical_labs = pd.read_csv( + data_path / "State/WHO_NREVSS_Clinical_Labs.csv", skiprows=1 + ) + + state_public_health["YEAR"] = ( + state_public_health["SEASON_DESCRIPTION"] + .apply(lambda x: x[len("Season ") : len("Season 2015")]) + .astype(int) + ) + state_public_health["WEEK"] = ( + state_public_health["SEASON_DESCRIPTION"] + .apply(lambda x: x[len("Season 2015-") :]) + .astype(int) + ) + + def gen_func() -> Generator[dict[str, Any], None, None]: + for prior, public_health, clinical_labs in [ + (national_prior, national_public_health, national_clinical_labs), + (hhs_prior, hhs_public_health, hhs_clinical_labs), + (census_prior, census_public_health, census_clinical_labs), + (state_prior, state_public_health, state_clinical_labs), + ]: + for col in [ + "TOTAL SPECIMENS", + "A (2009 H1N1)", + "A (H1)", + "A (H3)", + "A (Subtyping not Performed)", + "A (Unable to Subtype)", + "B", + "H3N2v", + ]: + prior[col] = prior[col].replace("X", 0).astype(int) + + for col in [ + "TOTAL SPECIMENS", + "A (2009 H1N1)", + "A (H3)", + "A (Subtyping not Performed)", + "B", + "BVic", + "BYam", + "H3N2v", + ]: + public_health[col] = public_health[col].replace("X", 0).astype(int) + + for col in ["TOTAL SPECIMENS", "TOTAL A", "TOTAL B"]: + clinical_labs[col] = clinical_labs[col].replace("X", 0).astype(int) + + prior.loc[:, "A"] = ( + prior["A (2009 H1N1)"] + + prior["A (H1)"] + + prior["A (H3)"] + + prior["A (Subtyping not Performed)"] + + prior["A (Unable to Subtype)"] + ) + public_health.loc[:, "A"] = ( + public_health["A (2009 H1N1)"] + + public_health["A (H3)"] + + public_health["A (Subtyping not Performed)"] + ) + public_health.loc[:, "B"] = ( + public_health["B"] + public_health["BVic"] + public_health["BYam"] + ) + + prior = prior[ + [ + "TOTAL SPECIMENS", + "A", + "B", + "H3N2v", + "YEAR", + "WEEK", + "REGION", + "REGION TYPE", + ] + ] + post = public_health[ + [ + "TOTAL SPECIMENS", + "A", + "B", + "H3N2v", 
+ "YEAR", + "WEEK", + "REGION", + "REGION TYPE", + ] + ] + post.loc[:, "TOTAL SPECIMENS"] = ( + post["TOTAL SPECIMENS"] + clinical_labs["TOTAL SPECIMENS"] + ) + post.loc[:, "A"] = post["A"] + clinical_labs["TOTAL A"] + post.loc[:, "B"] = post["B"] + clinical_labs["TOTAL B"] + + combined = pd.concat([prior, post]) + combined["date"] = pd.to_datetime( + (combined.YEAR * 100 + combined.WEEK).astype(str) + "0", format="%Y%W%w" + ) + + for region in combined.REGION.unique(): + region_ds = ( + combined.query(f'REGION == "{region}"') + .set_index("date") + .sort_index() + ) + + if region_ds["REGION TYPE"].iloc[0] == "National": + item_id = "national" + elif region_ds["REGION TYPE"].iloc[0] == "HHS Regions": + item_id = f"hhs_{region}" + elif region_ds["REGION TYPE"].iloc[0] == "Census Regions": + item_id = f"census_{region}" + elif region_ds["REGION TYPE"].iloc[0] == "States": + item_id = f"states_{region}" + else: + raise ValueError + + target = ( + region_ds[["TOTAL SPECIMENS", "A", "B", "H3N2v"]] + .to_numpy() + .astype(np.float32) + ) + + if target.shape[0] < 16: + continue + + yield dict( + item_id=item_id, + start=region_ds.index[0], + target=target.T, + freq="W", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Sequence(Value("float32")), length=4), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_project_tycho_gen_func( + data_path: Path, + length_threshold: int = 100, + missing_threshold: float = 0.25, +) -> tuple[GenFunc, Features]: + data = pd.read_csv(data_path / "ProjectTycho_Level2_v1.1.0.csv", engine="pyarrow") + data = data.rename(columns={" event": "event"}) + ts = data[["state", "loc", "loc_type", "disease", "event"]].value_counts() + ts = ts[ts > 0].index + + def gen_func() -> Generator[dict[str, Any], None, None]: + for state, loc, loc_type, disease, event in ts: + item = data.query( + f'state == "{state}" and loc == "{loc}" and loc_type == "{loc_type}" and ' + f'disease == "{disease}" and event == "{event}"' + ) + item.loc[:, "from_date"] = pd.to_datetime(item["from_date"]) + item = item.drop_duplicates("from_date").set_index("from_date").sort_index() + + item = item.reindex( + pd.date_range( + start=item.index[0], + end=item.index[-1], + freq="W", + ) + ) + + missing_pct = item.number.isnull().sum() / len(item) + if len(item) < length_threshold or missing_pct > missing_threshold: + continue + + yield dict( + item_id=f"{state}_{loc}_{loc_type}_{disease}_{event}", + start=item.index[0], + target=item.number, + freq="W", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +class OthersLOTSADatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "kdd2022", + "godaddy", + "favorita_sales", + "favorita_transactions", + "restaurant", + "hierarchical_sales", + "china_air_quality", + "beijing_air_quality", + "residential_load_power", + "residential_pv_power", + "cdc_fluview_ilinet", + "cdc_fluview_who_nrevss", + "project_tycho", + "AtrialFibrillation", "BIDMC32HR", "IEEEPPG", "MotorImagery", "PigArtPressure", "PigCVP", "SelfRegulationSCP1", "SelfRegulationSCP2", "TDBrain" + ] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) | { + dataset: MultiSampleTimeSeriesDatasetWithIndex for dataset in MULTI_SAMPLE_DATASETS + } + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) | { + dataset: partial( + 
MultiSampleTimeSeriesDatasetWithIndex, + max_ts=128, + combine_fields=("target", "past_feat_dynamic_real"), + ) + for dataset in MULTI_SAMPLE_DATASETS + } + + def build_dataset(self, dataset: str): + data_path = ( + Path(os.getenv("OTHERS_PATH")) + / { + "kdd2022": "kdd2022", + "godaddy": "godaddy", + "favorita_sales": "favorita", + "favorita_transactions": "favorita", + "restaurant": "restaurant", + "hierarchical_sales": "hierarchical_sales", + "china_air_quality": "china_air_quality", + "beijing_air_quality": "beijing_air_quality", + "residential_load_power": "residential_power", + "residential_pv_power": "residential_power", + "cdc_fluview_ilinet": "CDCFluView", + "cdc_fluview_who_nrevss": "CDCFluView", + "project_tycho": "ProjectTycho", + }[dataset] + ) + gen_func, features = { + "kdd2022": _get_kdd_2022_gen_func, + "godaddy": _get_godaddy_gen_func, + "favorita_sales": _get_favorita_sales_gen_func, + "favorita_transactions": _get_favorita_transactions_gen_func, + "restaurant": _get_restaurant_gen_func, + "hierarchical_sales": _get_hierarchical_sales_gen_func, + "china_air_quality": _get_china_air_quality_gen_func, + "beijing_air_quality": _get_beijing_air_quality_gen_func, + "residential_load_power": _get_residential_load_power_gen_func, + "residential_pv_power": _get_residential_pv_power_gen_func, + "cdc_fluview_ilinet": _get_cdc_fluview_ilinet_gen_func, + "cdc_fluview_who_nrevss": _get_cdc_fluview_who_nrevss_gen_func, + "project_tycho": _get_project_tycho_gen_func, + }[dataset](data_path) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=features, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(self.storage_path / dataset) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/proenfo.py b/OATS/models/tsfm/data/builder/lotsa_v1/proenfo.py new file mode 100644 index 0000000..1fac96b --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/proenfo.py @@ -0,0 +1,371 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
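Structurally, `OthersLOTSADatasetBuilder` above and the `ProEnFoDatasetBuilder` below share one skeleton: a dict maps each dataset name to a factory returning a `(gen_func, features)` pair, and a single `from_generator`/`save_to_disk` tail materializes the result. A stripped-down sketch of that skeleton, with a hypothetical `demo` factory in place of the real loaders:

```python
# Skeleton shared by the others.py and proenfo.py builders (names hypothetical).
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Generator

import datasets
from datasets import Features, Sequence, Value

GenFunc = Callable[[], Generator[dict[str, Any], None, None]]


def _get_demo_gen_func(root: Path) -> tuple[GenFunc, Features]:
    # Hypothetical factory; the real ones read pickles/CSVs under `root`.
    def gen_func():
        yield dict(item_id="demo", start=datetime(2020, 1, 1), freq="H", target=[1.0, 2.0, 3.0])

    features = Features(
        dict(
            item_id=Value("string"),
            start=Value("timestamp[s]"),
            freq=Value("string"),
            target=Sequence(Value("float32")),
        )
    )
    return gen_func, features


FACTORIES: dict[str, Callable[[Path], tuple[GenFunc, Features]]] = {
    "demo": _get_demo_gen_func,
}


def build(dataset: str, root: Path, storage_path: Path) -> None:
    gen_func, features = FACTORIES[dataset](root)
    hf_dataset = datasets.Dataset.from_generator(gen_func, features=features)
    hf_dataset.info.dataset_name = dataset
    hf_dataset.save_to_disk(str(storage_path / dataset))
```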
+ +import os +from collections import defaultdict +from functools import partial +from pathlib import Path + +import datasets +import numpy as np +import pandas as pd +from datasets import Features, Sequence, Value + +from tsfm.common.env import env +from tsfm.common.typing import GenFunc +from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex, TimeSeriesDataset, TimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + +MULTI_SAMPLE_DATASETS = [ + "gfc12_load", + "gfc17_load", + "bull", + "hog", +] + + +def _get_gfc12_load_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "GFC12_load/new_load_with_weather.pkl") + + def gen_func(): + for col in range(1, 21): + yield dict( + item_id=f"GFC12_load_{col}", + start=df.index[0], + target=df[str(col)], + past_feat_dynamic_real=df.airTemperature, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_gfc14_load_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "GFC14_load/load_with_weather.pkl") + + def gen_func(): + yield dict( + item_id="GFC14_load", + start=df.index[0], + target=df["load"], + past_feat_dynamic_real=df.airTemperature, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_gfc17_load_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "GFC17_load/load_with_weather.pkl") + + def gen_func(): + for col in ["CT", "ME", "NH", "RI", "VT", "NEMASSBOST", "SEMASS", "WCMASS"]: + yield dict( + item_id=f"GFC17_load_{col}", + start=df[col].index[0], + target=df[col]["load"].astype(np.float32), + past_feat_dynamic_real=df[col]["airTemperature"].astype(np.float32), + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_spain_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "Spain/new_load_with_weather.pkl") + + def gen_func(): + yield dict( + item_id="Spain", + start=df.index[0], + target=df["load"], + past_feat_dynamic_real=df.airTemperature, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_pdb_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "PDB/load_with_weather.pkl") + + def gen_func(): + yield dict( + item_id="PDB", + start=df.index[0], + target=df["load"].astype(np.float32), + past_feat_dynamic_real=df.airTemperature.astype(np.float32), + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_elf_gen_func(proenfo_path: Path): + df = 
pd.read_pickle(proenfo_path / "ELF/load_with_weather.pkl") + + def gen_func(): + yield dict( + item_id="ELF", + start=df.index[0], + target=df["load"], + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_bull_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "Bull/load_with_weather.pkl") + + def gen_func(): + for col in [col for col in df if "Bull" in col]: + yield dict( + item_id=f"Bull_{col}", + start=df.index[0], + target=df[col], + past_feat_dynamic_real=df[ + ["airTemperature", "dewTemperature", "seaLvlPressure"] + ] + .to_numpy() + .T, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Sequence(Value("float32")), length=3), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_cockatoo_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "Cockatoo/load_with_weather.pkl") + + def gen_func(): + for col in [col for col in df if "Cockatoo" in col]: + yield dict( + item_id=f"Cockatoo_{col}", + start=df.index[0], + target=df[col], + past_feat_dynamic_real=df[ + [ + "airTemperature", + "dewTemperature", + "seaLvlPressure", + "windDirection", + "windSpeed", + ] + ] + .to_numpy() + .T, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Sequence(Value("float32")), length=5), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_hog_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "Hog/load_with_weather.pkl") + + def gen_func(): + for col in [col for col in df if "Hog" in col]: + yield dict( + item_id=f"Hog_{col}", + start=df.index[0], + target=df[col], + past_feat_dynamic_real=df[ + [ + "airTemperature", + "dewTemperature", + "seaLvlPressure", + "windDirection", + "windSpeed", + ] + ] + .to_numpy() + .T, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence(Sequence(Value("float32")), length=5), + freq=Value("string"), + ) + ) + + return gen_func, features + + +def _get_covid19_energy_gen_func(proenfo_path: Path) -> tuple[GenFunc, Features]: + df = pd.read_pickle(proenfo_path / "Covid19/load_with_weather.pkl") + cols = list(df.columns) + cols.remove("load") + target = df["load"] + past_feat_dynamic_real = df[cols].to_numpy().T + + def gen_func(): + yield dict( + item_id="covid19_energy", + start=df.index[0], + target=target, + past_feat_dynamic_real=past_feat_dynamic_real, + freq="H", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + target=Sequence(Value("float32")), + past_feat_dynamic_real=Sequence( + Sequence(Value("float32")), length=len(past_feat_dynamic_real) + ), + freq=Value("string"), + ) + ) + + return gen_func, features + + +class ProEnFoDatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "gfc12_load", + "gfc14_load", + "gfc17_load", + "spain", + "pdb", + "elf", + "bull", + "cockatoo", + "hog", + "covid19_energy", + ] + dataset_type_map = defaultdict(lambda: TimeSeriesDatasetWithIndex) | { + dataset: 
MultiSampleTimeSeriesDatasetWithIndex for dataset in MULTI_SAMPLE_DATASETS + } + dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDatasetWithIndex)) | { + dataset: partial( + MultiSampleTimeSeriesDatasetWithIndex, + max_ts=128, + combine_fields=("target", "past_feat_dynamic_real"), + ) + for dataset in MULTI_SAMPLE_DATASETS + } + + def build_dataset(self, dataset: str): + proenfo_path = Path(os.getenv("PROENFO_PATH")) + gen_func, features = { + "gfc12_load": _get_gfc12_load_gen_func, + "gfc14_load": _get_gfc14_load_gen_func, + "gfc17_load": _get_gfc17_load_gen_func, + "spain": _get_spain_gen_func, + "pdb": _get_pdb_gen_func, + "elf": _get_elf_gen_func, + "bull": _get_bull_gen_func, + "cockatoo": _get_cockatoo_gen_func, + "hog": _get_hog_gen_func, + "covid19_energy": _get_covid19_energy_gen_func, + }[dataset](proenfo_path) + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=features, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(self.storage_path / dataset) diff --git a/OATS/models/tsfm/data/builder/lotsa_v1/subseasonal.py b/OATS/models/tsfm/data/builder/lotsa_v1/subseasonal.py new file mode 100644 index 0000000..74e99b5 --- /dev/null +++ b/OATS/models/tsfm/data/builder/lotsa_v1/subseasonal.py @@ -0,0 +1,154 @@ +# Copyright (c), Salesforce 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
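A convention that recurs in every multivariate loader here, and in the subseasonal builder that follows: series are stored channels-first, as a `(num_channels, time)` array declared with `Sequence(Sequence(Value("float32")), length=num_channels)`, which is why the loaders transpose with `.T` before yielding. A minimal round-trip check of that layout, independent of the repo:

```python
# Channels-first storage convention used for multivariate targets/covariates.
import datasets
import numpy as np
from datasets import Features, Sequence, Value

num_channels, length = 3, 48
target = np.random.rand(length, num_channels).astype("float32").T  # (3, 48)


def gen_func():
    yield {"item_id": "demo", "target": target}


ds = datasets.Dataset.from_generator(
    gen_func,
    features=Features(
        dict(
            item_id=Value("string"),
            target=Sequence(Sequence(Value("float32")), length=num_channels),
        )
    ),
)
assert np.asarray(ds[0]["target"]).shape == (num_channels, length)
```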
+ +from collections import defaultdict +from functools import partial +from typing import Any, Generator + +import datasets +import numpy as np +from datasets import Features, Sequence, Value + +try: + from subseasonal_data import data_loaders +except ImportError as e: + + def data_loaders(*args, **kwargs): + raise e + + +from tsfm.common.env import env +from tsfm.common.typing import GenFunc +from tsfm.data.dataset import MultiSampleTimeSeriesDataset, MultiSampleTimeSeriesDatasetWithIndex + +from ._base import LOTSADatasetBuilder + + +def _get_subseasonal_precip_gen_func() -> tuple[GenFunc, Features]: + precip = data_loaders.get_ground_truth("us_precip") + lat_lon = precip[["lat", "lon"]].value_counts().index + + def gen_func() -> Generator[dict[str, Any], None, None]: + for lat, lon in lat_lon: + lat_lon_precip = ( + precip.query(f"lat == {lat} and lon == {lon}") + .set_index("start_date") + .sort_index() + ) + yield dict( + item_id=f"{lat}_{lon}", + start=lat_lon_precip.index[0], + target=lat_lon_precip["precip"][:11323].to_numpy(), + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ) + + return gen_func, features + + +def _get_subseasonal_gen_func() -> tuple[GenFunc, Features]: + precip = data_loaders.get_ground_truth("us_precip") + tmp2m = data_loaders.get_ground_truth("us_tmp2m") + tmin = data_loaders.get_ground_truth("us_tmin") + tmax = data_loaders.get_ground_truth("us_tmax") + + lat_lon = precip[["lat", "lon"]].value_counts().index + + def gen_func() -> Generator[dict[str, Any], None, None]: + for lat, lon in lat_lon: + lat_lon_precip = ( + precip.query(f"lat == {lat} and lon == {lon}") + .set_index("start_date") + .sort_index() + ) + lat_lon_tmp2m = ( + tmp2m.query(f"lat == {lat} and lon == {lon}") + .set_index("start_date") + .sort_index() + ) + lat_lon_tmin = ( + tmin.query(f"lat == {lat} and lon == {lon}") + .set_index("start_date") + .sort_index() + ) + lat_lon_tmax = ( + tmax.query(f"lat == {lat} and lon == {lon}") + .set_index("start_date") + .sort_index() + ) + + yield dict( + item_id=f"{lat}_{lon}", + start=lat_lon_precip[11323:].index[0], + target=np.stack( + [ + lat_lon_precip[11323:]["precip"].to_numpy(), + lat_lon_tmp2m["tmp2m"].to_numpy(), + lat_lon_tmin["tmin"].to_numpy(), + lat_lon_tmax["tmax"].to_numpy(), + ], + axis=0, + ), + freq="D", + ) + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Sequence(Value("float32")), length=4), + ) + ) + + return gen_func, features + + +class SubseasonalDatasetBuilder(LOTSADatasetBuilder): + dataset_list = [ + "subseasonal", + "subseasonal_precip", + ] + dataset_type_map = defaultdict(lambda: MultiSampleTimeSeriesDatasetWithIndex) + dataset_load_func_map = defaultdict( + lambda: partial( + MultiSampleTimeSeriesDatasetWithIndex, max_ts=128, combine_fields=("target",) + ) + ) + + def build_dataset(self, dataset: str): + gen_func, features = { + "subseasonal": _get_subseasonal_gen_func, + "subseasonal_precip": _get_subseasonal_precip_gen_func, + }[dataset]() + + hf_dataset = datasets.Dataset.from_generator( + gen_func, + features=features, + cache_dir=env.HF_CACHE_PATH, + ) + hf_dataset.info.dataset_name = dataset + hf_dataset.save_to_disk(self.storage_path / dataset) diff --git a/OATS/models/tsfm/data/builder/simple.py b/OATS/models/tsfm/data/builder/simple.py new file mode 100644 index 0000000..1dfd19a --- /dev/null +++ 
b/OATS/models/tsfm/data/builder/simple.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +from dataclasses import dataclass +from itertools import product +from pathlib import Path +from typing import Any, Callable, Generator, Optional + +import datasets +import pandas as pd +from datasets import Features, Sequence, Value +from torch.utils.data import Dataset + +from tsfm.common.env import env +from tsfm.common.typing import GenFunc +from tsfm.data.dataset import EvalDataset, EvalDatasetWithIndex, SampleTimeSeriesType, TimeSeriesDataset, TimeSeriesDatasetWithIndex +from tsfm.data.indexer import HuggingFaceDatasetIndexer +from tsfm.transform import Transformation + +from ._base import DatasetBuilder + + +def _from_long_dataframe( + df: pd.DataFrame, + offset: Optional[int] = None, + date_offset: Optional[pd.Timestamp] = None, +) -> tuple[GenFunc, Features]: + items = df.item_id.unique() + + def example_gen_func() -> Generator[dict[str, Any], None, None]: + for item_id in items: + item_df = df.query(f'item_id == "{item_id}"').drop("item_id", axis=1) + if offset is not None: + item_df = item_df.iloc[:offset] + elif date_offset is not None: + item_df = item_df[item_df.index <= date_offset] + yield { + "target": item_df.to_numpy(), + "start": item_df.index[0], + "freq": pd.infer_freq(item_df.index), + "item_id": item_id, + } + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ) + + return example_gen_func, features + + +def _from_wide_dataframe( + df: pd.DataFrame, + offset: Optional[int] = None, + date_offset: Optional[pd.Timestamp] = None, +) -> tuple[GenFunc, Features]: + if offset is not None: + df = df.iloc[:offset] + elif date_offset is not None: + df = df[df.index <= date_offset] + + print(df) + + def example_gen_func() -> Generator[dict[str, Any], None, None]: + for i in range(len(df.columns)): + yield { + "target": df.iloc[:, i].to_numpy(), + "start": df.index[0], + "freq": pd.infer_freq(df.index), + "item_id": f"item_{i}", + } + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Value("float32")), + ) + ) + + return example_gen_func, features + + +def _from_wide_dataframe_multivariate( + df: pd.DataFrame, + offset: Optional[int] = None, + date_offset: Optional[pd.Timestamp] = None, +) -> tuple[GenFunc, Features]: + if offset is not None: + df = df.iloc[:offset] + elif date_offset is not None: + df = df[df.index <= date_offset] + + def example_gen_func() -> Generator[dict[str, Any], None, None]: + yield { + "target": df.to_numpy().T, + "start": df.index[0], + "freq": pd.infer_freq(df.index), + "item_id": "item_0", + } + + features = Features( + dict( + item_id=Value("string"), + start=Value("timestamp[s]"), + freq=Value("string"), + target=Sequence(Sequence(Value("float32")), length=len(df.columns)), + ) + ) + + return example_gen_func, features + + +@dataclass +class SimpleDatasetBuilder(DatasetBuilder): + dataset: str + weight: float = 1.0 + sample_time_series: Optional[SampleTimeSeriesType] = SampleTimeSeriesType.NONE + storage_path: Path = env.CUSTOM_DATA_PATH + + def __post_init__(self): + self.storage_path = Path(self.storage_path) + + def build_dataset( + self, + file: Path, + dataset_type: str, + offset: Optional[int] = None, + date_offset: Optional[pd.Timestamp] = None, + ): + assert offset is None or date_offset is None, ( + "One 
or neither offset and date_offset must be specified, but not both. " + f"Got offset: {offset}, date_offset: {date_offset}" + ) + + df = pd.read_csv(file, index_col=0, parse_dates=True) + + if dataset_type == "long": + _from_dataframe = _from_long_dataframe + elif dataset_type == "wide": + _from_dataframe = _from_wide_dataframe + elif dataset_type == "wide_multivariate": + _from_dataframe = _from_wide_dataframe_multivariate + else: + raise ValueError( + f"Unrecognized dataset_type, {dataset_type}." + " Valid options are 'long', 'wide', and 'wide_multivariate'." + ) + + example_gen_func, features = _from_dataframe( + df, offset=offset, date_offset=date_offset + ) + hf_dataset = datasets.Dataset.from_generator( + example_gen_func, features=features + ) + hf_dataset.info.dataset_name = self.dataset + hf_dataset.save_to_disk(str(self.storage_path / self.dataset)) + + def load_dataset( + self, transform_map: dict[str, Callable[..., Transformation]] + ) -> Dataset: + return TimeSeriesDatasetWithIndex( + HuggingFaceDatasetIndexer( + datasets.load_from_disk( + str(self.storage_path / self.dataset), + ) + ), + transform=transform_map[self.dataset](), + dataset_weight=self.weight, + sample_time_series=self.sample_time_series, + ) + + +@dataclass +class SimpleEvalDatasetBuilder(DatasetBuilder): + dataset: str + offset: Optional[int] + windows: Optional[int] + distance: Optional[int] + prediction_length: Optional[int] + context_length: Optional[int] + patch_size: Optional[int] + storage_path: Path = env.CUSTOM_DATA_PATH + sample_time_series: Optional[SampleTimeSeriesType] = SampleTimeSeriesType.NONE + + def __post_init__(self): + self.storage_path = Path(self.storage_path) + + def build_dataset(self, file: Path, dataset_type: str): + df = pd.read_csv(file, index_col=0, parse_dates=True) + + if dataset_type == "long": + _from_dataframe = _from_long_dataframe + elif dataset_type == "wide": + _from_dataframe = _from_wide_dataframe + elif dataset_type == "wide_multivariate": + _from_dataframe = _from_wide_dataframe_multivariate + else: + raise ValueError( + f"Unrecognized dataset_type, {dataset_type}." + " Valid options are 'long', 'wide', and 'wide_multivariate'." 
+ ) + + example_gen_func, features = _from_dataframe(df) + hf_dataset = datasets.Dataset.from_generator( + example_gen_func, features=features + ) + hf_dataset.info.dataset_name = self.dataset + hf_dataset.save_to_disk(self.storage_path / self.dataset) + + def load_dataset( + self, transform_map: dict[str, Callable[..., Transformation]] + ) -> Dataset: + return EvalDatasetWithIndex( + self.windows, + HuggingFaceDatasetIndexer( + datasets.load_from_disk( + str(self.storage_path / self.dataset), + ) + ), + transform=transform_map[self.dataset]( + offset=self.offset, + distance=self.distance, + prediction_length=self.prediction_length, + context_length=self.context_length, + patch_size=self.patch_size, + ), + sample_time_series=self.sample_time_series, + ) + + +def generate_eval_builders( + dataset: str, + offset: int, + eval_length: int, + prediction_lengths: list[int], + context_lengths: list[int], + patch_sizes: list[int], + storage_path: Path = env.CUSTOM_DATA_PATH, +) -> list[SimpleEvalDatasetBuilder]: + return [ + SimpleEvalDatasetBuilder( + dataset=dataset, + offset=offset, + windows=eval_length // pred, + distance=pred, + prediction_length=pred, + context_length=ctx, + patch_size=psz, + storage_path=storage_path, + ) + for pred, ctx, psz in product(prediction_lengths, context_lengths, patch_sizes) + ] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset_name", type=str) + parser.add_argument("file_path", type=str) + parser.add_argument( + "--dataset_type", + type=str, + choices=["wide", "long", "wide_multivariate"], + default="wide", + ) + parser.add_argument( + "--offset", + type=int, + default=None, + ) + parser.add_argument( + "--date_offset", + type=str, + default=None, + ) + parser.add_argument( + "--storage_path", + type=str, + default=env.CUSTOM_DATA_PATH, + ) + + args = parser.parse_args() + + SimpleDatasetBuilder(dataset=args.dataset_name, storage_path=args.storage_path).build_dataset( + file=Path(args.file_path), + dataset_type=args.dataset_type, + offset=args.offset, + date_offset=pd.Timestamp(args.date_offset) if args.date_offset else None, + ) + + if args.offset is not None or args.date_offset is not None: + SimpleEvalDatasetBuilder( + f"{args.dataset_name}_eval", + offset=None, + windows=None, + distance=None, + prediction_length=None, + context_length=None, + patch_size=None, + storage_path=args.storage_path + ).build_dataset(file=Path(args.file_path), dataset_type=args.dataset_type) diff --git a/OATS/models/tsfm/data/dataset.py b/OATS/models/tsfm/data/dataset.py new file mode 100644 index 0000000..3e6299a --- /dev/null +++ b/OATS/models/tsfm/data/dataset.py @@ -0,0 +1,302 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
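The `__main__` block closing `simple.py` above exposes both builders as a small CLI. Assuming the module resolves as `tsfm.data.builder.simple` (inferred from the file path in this diff, not verified), a run that builds a train split truncated at an offset might look like:

```bash
python -m tsfm.data.builder.simple my_dataset ./my_dataset.csv \
    --dataset_type wide \
    --offset 1000 \
    --storage_path ./custom_data
```

Per the argument handling above, supplying `--offset` (or `--date_offset`) additionally triggers the `SimpleEvalDatasetBuilder` pass that saves `my_dataset_eval` from the full, untruncated file.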
+ +from enum import Enum +from typing import Any + +import numpy as np +from torch.utils.data import Dataset + +from tsfm.common.sampler import Sampler, get_sampler +from tsfm.common.typing import ( + BatchedData, + BatchedDateTime, + BatchedString, + Data, + FlattenedData, + MultivarTimeSeries, + UnivarTimeSeries, +) +from tsfm.data.indexer import Indexer +from tsfm.transform import Transformation + + +class SampleTimeSeriesType(Enum): + NONE = "none" + UNIFORM = "uniform" + PROPORTIONAL = "proportional" + + +class TimeSeriesDataset(Dataset): + def __init__( + self, + indexer: Indexer[dict[str, Any]], + transform: Transformation, + sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE, + dataset_weight: float = 1.0, + ): + self.indexer = indexer + self.transform = transform + self.sample_time_series = sample_time_series + self.dataset_weight = dataset_weight + + if sample_time_series == SampleTimeSeriesType.NONE: + self.probabilities = None + elif sample_time_series == SampleTimeSeriesType.UNIFORM: + self.probabilities = indexer.get_uniform_probabilities() + elif sample_time_series == SampleTimeSeriesType.PROPORTIONAL: + self.probabilities = indexer.get_proportional_probabilities() + else: + raise ValueError(f"Unknown sample type {sample_time_series}") + + def __getitem__(self, idx: int) -> dict[str, FlattenedData]: + if idx < 0 or idx >= len(self): + raise IndexError( + f"Index {idx} out of range for dataset of length {len(self)}" + ) + + if self.sample_time_series != SampleTimeSeriesType.NONE: + idx = np.random.choice(len(self.probabilities), p=self.probabilities) + + return self.transform(self._flatten_data(self._get_data(idx))) + + @property + def num_ts(self) -> int: + return len(self.indexer) + + def __len__(self) -> int: + return int(np.ceil(self.num_ts * self.dataset_weight)) + + def _get_data(self, idx: int) -> dict[str, Data | BatchedData]: + return self.indexer[idx % self.num_ts] + + @staticmethod + def _flatten_data(data: dict[str, Data]) -> dict[str, FlattenedData]: + return { + k: ( + [v] + if isinstance(v, UnivarTimeSeries) + else list(v) if isinstance(v, MultivarTimeSeries) else v + ) + for k, v in data.items() + } + + +class TimeSeriesDatasetWithIndex(TimeSeriesDataset): + """Enhanced TimeSeriesDataset that preserves original dataset indices for influence function computation.""" + + def __init__( + self, + indexer, + transform, + sample_time_series=SampleTimeSeriesType.NONE, + dataset_weight=1.0, + global_offset=0, + ): + super().__init__(indexer, transform, sample_time_series, dataset_weight) + self.global_offset = global_offset + + def set_global_offset(self, offset: int): + """Set the global offset for this dataset within a concatenated dataset.""" + self.global_offset = offset + + def __getitem__(self, idx: int) -> dict[str, FlattenedData]: + if idx < 0 or idx >= len(self): + raise IndexError( + f"Index {idx} out of range for dataset of length {len(self)}" + ) + + original_idx = idx # Store the original index before any transformations + + if self.sample_time_series != SampleTimeSeriesType.NONE: + # If sampling is used, we still want to track the original sampled index + sampled_idx = np.random.choice(len(self.probabilities), p=self.probabilities) + original_idx = sampled_idx # Track the actual sampled index + idx = sampled_idx + + # Get the data and add the GLOBAL dataset index metadata + data = self._get_data(idx) + # For weighted datasets, use the actual data index (bounded by num_ts) not the virtual index + # This ensures global_idx never exceeds 
the allocated range based on num_ts + actual_data_idx = idx % self.num_ts # This matches what _get_data uses + global_idx = self.global_offset + actual_data_idx + if global_idx >= self.global_offset + self.num_ts: + print(f"Global idx {global_idx} exceeds allocated range [{self.global_offset} - {self.global_offset + self.num_ts}]") + data['_dataset_idx'] = global_idx + # print(f"Original idx: {original_idx}, Actual data idx: {actual_data_idx}, Global offset: {self.global_offset}, Global idx: {global_idx}") + return self.transform(self._flatten_data(data)) + + +class MultiSampleTimeSeriesDataset(TimeSeriesDataset): + def __init__( + self, + indexer: Indexer[dict[str, Any]], + transform: Transformation, + max_ts: int, + combine_fields: tuple[str, ...], + sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE, + dataset_weight: float = 1.0, + sampler: Sampler = get_sampler("beta_binomial", a=2, b=5), + ): + super().__init__(indexer, transform, sample_time_series, dataset_weight) + self.max_ts = max_ts + self.combine_fields = combine_fields + self.sampler = sampler + + def _get_data(self, idx: int) -> dict[str, BatchedData]: + n_series = self.sampler(min(self.num_ts, self.max_ts)) + choices = np.concatenate([np.arange(idx), np.arange(idx + 1, self.num_ts)]) + others = np.random.choice(choices, n_series - 1, replace=False) + samples = self.indexer[np.concatenate([[idx], others])] + return samples + + def _flatten_data( + self, samples: dict[str, BatchedData] + ) -> dict[str, FlattenedData]: + for field in samples.keys(): + if field in self.combine_fields: + item = samples[field] + if isinstance(item, list) and isinstance(item[0], MultivarTimeSeries): + samples[field] = [ + univar for sample in samples[field] for univar in sample + ] + elif isinstance(samples[field], BatchedDateTime): + samples[field] = np.asarray(samples[field][0]) + elif isinstance(samples[field], BatchedString): + samples[field] = samples[field][0] + else: + raise AssertionError( + f"Field {field} not accounted for in {self.indexer} MultiSampleTimeSeriesDataset" + ) + return samples + + +class MultiSampleTimeSeriesDatasetWithIndex(MultiSampleTimeSeriesDataset): + """Enhanced MultiSampleTimeSeriesDataset that preserves original dataset indices for influence function computation.""" + + def __init__( + self, + indexer: Indexer[dict[str, Any]], + transform: Transformation, + max_ts: int, + combine_fields: tuple[str, ...], + sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE, + dataset_weight: float = 1.0, + sampler: Sampler = get_sampler("beta_binomial", a=2, b=5), + global_offset: int = 0, + ): + super().__init__(indexer, transform, max_ts, combine_fields, sample_time_series, dataset_weight, sampler) + self.global_offset = global_offset + + def set_global_offset(self, offset: int): + """Set the global offset for this dataset within a concatenated dataset.""" + self.global_offset = offset + + def __getitem__(self, idx: int) -> dict[str, FlattenedData]: + if idx < 0 or idx >= len(self): + raise IndexError( + f"Index {idx} out of range for dataset of length {len(self)}" + ) + + original_idx = idx # Store the original index before any transformations + + if self.sample_time_series != SampleTimeSeriesType.NONE: + # If sampling is used, we still want to track the original sampled index + sampled_idx = np.random.choice(len(self.probabilities), p=self.probabilities) + original_idx = sampled_idx # Track the actual sampled index + idx = sampled_idx + + # Get the data and add the GLOBAL dataset index 
metadata
+        data = self._get_data(idx)
+        # For weighted datasets, use the actual data index (bounded by num_ts) not the virtual index
+        # This ensures global_idx never exceeds the allocated range based on num_ts
+        # (idx % num_ts guarantees the bound, so no runtime range check is needed)
+        actual_data_idx = idx % self.num_ts  # This matches what _get_data uses
+        global_idx = self.global_offset + actual_data_idx
+        data['_dataset_idx'] = global_idx
+        return self.transform(self._flatten_data(data))
+
+    def _flatten_data(
+        self, samples: dict[str, BatchedData]
+    ) -> dict[str, FlattenedData]:
+        """Override to handle the _dataset_idx field for influence function tracking."""
+        # Remove the special _dataset_idx field before flattening
+        dataset_idx = None
+        if '_dataset_idx' in samples:
+            dataset_idx = samples.pop('_dataset_idx')
+
+        # Process the normal fields using the parent method
+        flattened = super()._flatten_data(samples)
+
+        # Add back the dataset index field
+        if dataset_idx is not None:
+            flattened['_dataset_idx'] = dataset_idx
+
+        return flattened
+
+
+class EvalDataset(TimeSeriesDataset):
+    def __init__(
+        self,
+        windows: int,
+        indexer: Indexer[dict[str, Any]],
+        transform: Transformation,
+        sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE,
+    ):
+        super().__init__(
+            indexer,
+            transform,
+            sample_time_series,
+            dataset_weight=windows,
+        )
+
+    def _get_data(self, idx: int) -> dict[str, Data]:
+        window, idx = divmod(idx, self.num_ts)
+        item = self.indexer[idx]
+        item["window"] = window
+        return item
+
+
+class EvalDatasetWithIndex(EvalDataset):
+    """Enhanced EvalDataset that preserves original dataset indices for influence function computation."""
+
+    def __init__(
+        self,
+        windows: int,
+        indexer: Indexer[dict[str, Any]],
+        transform: Transformation,
+        sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE,
+        global_offset: int = 0,
+    ):
+        super().__init__(windows, indexer, transform, sample_time_series)
+        self.global_offset = global_offset
+
+    def set_global_offset(self, offset: int):
+        """Set the global offset for this dataset within a concatenated dataset."""
+        self.global_offset = offset
+
+    def __getitem__(self, idx: int) -> dict[str, FlattenedData]:
+        if idx < 0 or idx >= len(self):
+            raise IndexError(
+                f"Index {idx} out of range for dataset of length {len(self)}"
+            )
+
+        original_idx = idx  # Store the original index before any transformations
+
+        if self.sample_time_series != SampleTimeSeriesType.NONE:
+            # If sampling is used, we still want to track the original sampled index
+            sampled_idx = np.random.choice(len(self.probabilities), p=self.probabilities)
+            original_idx = sampled_idx  # Track the actual sampled index
+            idx = sampled_idx
+
+        # Get the data and add the GLOBAL dataset index metadata
+        data = self._get_data(idx)
+        # For EvalDataset, use the actual time series index (from divmod) not the virtual window index
+        # This ensures global_idx never exceeds the allocated range based on num_ts
+        window, actual_ts_idx = divmod(idx, self.num_ts)
+        global_idx = self.global_offset + actual_ts_idx
+        data['_dataset_idx'] = global_idx
+        return 
self.transform(self._flatten_data(data)) diff --git a/OATS/models/tsfm/data/indexer/__init__.py b/OATS/models/tsfm/data/indexer/__init__.py new file mode 100644 index 0000000..ab1c2ed --- /dev/null +++ b/OATS/models/tsfm/data/indexer/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from ._base import Indexer +from .hf_dataset_indexer import HuggingFaceDatasetIndexer + +__all__ = ["Indexer", "HuggingFaceDatasetIndexer"] diff --git a/OATS/models/tsfm/data/indexer/_base.py b/OATS/models/tsfm/data/indexer/_base.py new file mode 100644 index 0000000..34396da --- /dev/null +++ b/OATS/models/tsfm/data/indexer/_base.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc +from collections.abc import Iterable, Sequence + +import numpy as np + +from tsfm.common.typing import BatchedData, Data + + +class Indexer(abc.ABC, Sequence): + def __init__(self, uniform: bool = False): + self.uniform = uniform + + def check_index(self, idx: int | slice | Iterable[int]): + if isinstance(idx, int): + if idx < 0 or idx >= len(self): + raise IndexError(f"Index {idx} out of bounds for length {len(self)}") + elif isinstance(idx, slice): + if idx.start is not None and idx.start < 0: + raise IndexError( + f"Index {idx.start} out of bounds for length {len(self)}" + ) + if idx.stop is not None and idx.stop >= len(self): + raise IndexError( + f"Index {idx.stop} out of bounds for length {len(self)}" + ) + elif isinstance(idx, Iterable): + idx = np.fromiter(idx, np.int64) + if np.logical_or(idx < 0, idx >= len(self)).any(): + raise IndexError(f"Index out of bounds for length {len(self)}") + else: + raise NotImplementedError(f"Unable to index on type: {type(idx)}") + + def __getitem__( + self, idx: int | slice | Iterable[int] + ) -> dict[str, Data | BatchedData]: + self.check_index(idx) + + if isinstance(idx, int): + item = self._getitem_int(idx) + elif isinstance(idx, slice): + item = self._getitem_slice(idx) + elif isinstance(idx, Iterable): + item = self._getitem_iterable(idx) + else: + raise NotImplementedError(f"Unable to index on type: {type(idx)}") + + return {k: v for k, v in item.items()} + + def _getitem_slice(self, idx: slice) -> dict[str, BatchedData]: + indices = list(range(len(self))[idx]) + return self._getitem_iterable(indices) + + @abc.abstractmethod + def _getitem_int(self, idx: int) -> dict[str, Data]: ... + + @abc.abstractmethod + def _getitem_iterable(self, idx: Iterable[int]) -> dict[str, BatchedData]: ... + + def get_uniform_probabilities(self) -> np.ndarray: + return np.ones(len(self)) / len(self) + + def get_proportional_probabilities(self, field: str = "target") -> np.ndarray: + if self.uniform: + return self.get_uniform_probabilities() + + lengths = np.asarray([sample[field].shape[-1] for sample in self]) + probs = lengths / lengths.sum() + return probs diff --git a/OATS/models/tsfm/data/indexer/hf_dataset_indexer.py b/OATS/models/tsfm/data/indexer/hf_dataset_indexer.py new file mode 100644 index 0000000..a15aa13 --- /dev/null +++ b/OATS/models/tsfm/data/indexer/hf_dataset_indexer.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
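+"""Indexer backed by a Hugging Face ``datasets.Dataset``.
+
+Sequence columns (``datasets.features.Sequence``) are read straight from the
+underlying Arrow table so that variable-length series are not materialised
+through the ``datasets`` formatting layer; all other columns use the numpy
+format. A minimal usage sketch (the toy dataset below is illustrative and not
+part of the repository):
+
+    from datasets import Dataset, Features, Sequence, Value
+
+    features = Features(
+        {"target": Sequence(Value("float32")), "item_id": Value("string")}
+    )
+    ds = Dataset.from_dict(
+        {"target": [[1.0, 2.0, 3.0], [4.0, 5.0]], "item_id": ["a", "b"]},
+        features=features,
+    )
+    indexer = HuggingFaceDatasetIndexer(ds)
+    sample = indexer[0]  # {"item_id": "a", "target": array([1., 2., 3.])}
+    probs = indexer.get_proportional_probabilities()  # length-weighted
+"""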
+ +from collections.abc import Iterable + +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +from datasets import Dataset +from datasets.features import Sequence +from datasets.formatting import query_table + +from tsfm.common.typing import BatchedData, Data, MultivarTimeSeries, UnivarTimeSeries + +from ._base import Indexer + + +class HuggingFaceDatasetIndexer(Indexer): + def __init__(self, dataset: Dataset, uniform: bool = False): + super().__init__(uniform=uniform) + self.dataset = dataset + self.features = dict(self.dataset.features) + self.non_seq_cols = [ + name + for name, feat in self.features.items() + if not isinstance(feat, Sequence) + ] + self.seq_cols = [ + name for name, feat in self.features.items() if isinstance(feat, Sequence) + ] + self.dataset.set_format("numpy", columns=self.non_seq_cols) + + def __len__(self) -> int: + return len(self.dataset) + + def _getitem_int(self, idx: int) -> dict[str, Data]: + non_seqs = self.dataset[idx] + pa_subtable = query_table(self.dataset.data, idx, indices=self.dataset._indices) + seqs = { + col: self._pa_column_to_numpy(pa_subtable, col)[0] for col in self.seq_cols + } + return non_seqs | seqs + + def _getitem_iterable(self, idx: Iterable[int]) -> dict[str, BatchedData]: + non_seqs = self.dataset[idx] + pa_subtable = query_table(self.dataset.data, idx, indices=self.dataset._indices) + seqs = { + col: self._pa_column_to_numpy(pa_subtable, col) for col in self.seq_cols + } + return non_seqs | seqs + + def _getitem_slice(self, idx: slice) -> dict[str, BatchedData]: + non_seqs = self.dataset[idx] + pa_subtable = query_table(self.dataset.data, idx, indices=self.dataset._indices) + seqs = { + col: self._pa_column_to_numpy(pa_subtable, col) for col in self.seq_cols + } + return non_seqs | seqs + + def _pa_column_to_numpy( + self, pa_table: pa.Table, column_name: str + ) -> list[UnivarTimeSeries] | list[MultivarTimeSeries]: + pa_array: pa.Array = pa_table.column(column_name) + feature = self.features[column_name] + + if isinstance(pa_array, pa.ChunkedArray): + if isinstance(feature.feature, Sequence): + array = [ + flat_slice.flatten().to_numpy(False).reshape(feat_length, -1) + for chunk in pa_array.chunks + for i in range(len(chunk)) + if (flat_slice := chunk.slice(i, 1).flatten()) + and ( + feat_length := ( + feature.length if feature.length != -1 else len(flat_slice) + ) + ) + ] + else: + array = [ + chunk.slice(i, 1).flatten().to_numpy(False) + for chunk in pa_array.chunks + for i in range(len(chunk)) + ] + elif isinstance(pa_array, pa.ListArray): + if isinstance(feature.feature, Sequence): + flat_slice = pa_array.flatten() + feat_length = ( + feature.length if feature.length != -1 else len(flat_slice) + ) + array = [flat_slice.flatten().to_numpy(False).reshape(feat_length, -1)] + else: + array = [pa_array.flatten().to_numpy(False)] + else: + raise NotImplementedError + + return array + + def get_proportional_probabilities(self, field: str = "target") -> np.ndarray: + if self.uniform: + return self.get_uniform_probabilities() + + if self[0]["target"].ndim > 1: + lengths = pc.list_value_length( + pc.list_flatten(pc.list_slice(self.dataset.data.column(field), 0, 1)) + ) + else: + lengths = pc.list_value_length(self.dataset.data.column(field)) + lengths = lengths.to_numpy() + probs = lengths / lengths.sum() + return probs diff --git a/OATS/models/tsfm/data/loader.py b/OATS/models/tsfm/data/loader.py new file mode 100644 index 0000000..7ffacee --- /dev/null +++ b/OATS/models/tsfm/data/loader.py @@ -0,0 +1,507 @@ +# 
Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import itertools +from collections import defaultdict, deque +from collections.abc import Callable, Iterator, Sequence +from dataclasses import dataclass, field +from typing import NamedTuple, Optional + +import numpy as np +import torch +from jaxtyping import Bool, Int +from torch.utils.data import DataLoader as TorchDataLoader +from torch.utils.data import Dataset, Sampler, default_collate, default_convert + +from tsfm.common.typing import BatchedSample, Sample + + +@dataclass +class Collate: + max_length: Optional[int] + seq_fields: tuple[str, ...] + pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] = field( + default_factory=dict + ) + target_field: str = "target" + + def __post_init__(self): + self.pad_func_map = defaultdict(self._default_pad_func) | self.pad_func_map + + @staticmethod + def _default_pad_func() -> Callable[[Sequence[int], np.dtype], np.ndarray]: + return np.zeros + + def __call__(self, batch: list[Sample]) -> BatchedSample: + raise NotImplementedError + + +class PadCollate(Collate): + def __call__(self, batch: list[Sample]) -> BatchedSample: + assert all( + [ + len(sample[self.target_field]) == len(sample[key]) + for sample in batch + for key in self.seq_fields + ] + ), "All fields must have the same length." + assert all( + [len(sample[self.target_field]) <= self.max_length for sample in batch] + ), f"Sample length must be less than or equal to max_length ({self.max_length})" + + sample_id = self.get_sample_id(batch) + padded_batch = self.pad_samples(batch) + merged_batch = padded_batch | dict(sample_id=sample_id) + return merged_batch + + def pad_samples(self, batch: list[Sample]) -> BatchedSample: + for sample in batch: + length = len(sample[self.target_field]) + for key in self.seq_fields: + sample[key] = torch.cat( + [ + default_convert(sample[key]), + default_convert( + self.pad_func_map[key]( + (self.max_length - length,) + sample[key].shape[1:], + sample[key].dtype, + ) + ), + ] + ) + return default_collate(batch) + + def get_sample_id(self, batch: list[Sample]) -> Int[torch.Tensor, "batch seq"]: + sample_id = torch.stack( + [ + torch.cat([torch.ones(length), torch.zeros(self.max_length - length)]) + for sample in batch + if (length := len(sample[self.target_field])) + ] + ).to(torch.long) + return sample_id + + +class PadCollateWithDatasetIndex(PadCollate): + """Enhanced PadCollate that preserves original dataset indices for influence function computation.""" + + def __call__(self, batch: list[Sample]) -> BatchedSample: + assert all( + [ + len(sample[self.target_field]) == len(sample[key]) + for sample in batch + for key in self.seq_fields + ] + ), "All fields must have the same length." 
+ assert all( + [len(sample[self.target_field]) <= self.max_length for sample in batch] + ), f"Sample length must be less than or equal to max_length ({self.max_length})" + + sample_id = self.get_sample_id(batch) + dataset_index = self.get_dataset_index(batch) + padded_batch = self.pad_samples(batch) + merged_batch = padded_batch | dict(sample_id=sample_id, dataset_index=dataset_index) + return merged_batch + + def get_dataset_index(self, batch: list[Sample]) -> Int[torch.Tensor, "batch seq"]: + """Get original dataset indices for each sample.""" + dataset_indices = [] + + for i, sample in enumerate(batch): + length = len(sample[self.target_field]) + # Get the original dataset index from sample metadata + # This field should always be present now that all dataset types support it + original_idx = sample['_dataset_idx'] + + # Create tensor: original index for real data, -1 for padding + sample_indices = torch.cat([ + torch.full((length,), original_idx, dtype=torch.long), + torch.full((self.max_length - length,), -1, dtype=torch.long) # -1 for padding + ]) + dataset_indices.append(sample_indices) + + return torch.stack(dataset_indices) + + def get_sample_id(self, batch: list[Sample]) -> Int[torch.Tensor, "batch seq"]: + """Keep the original sample_id logic for backward compatibility.""" + sample_id = torch.stack( + [ + torch.cat([torch.ones(length), torch.zeros(self.max_length - length)]) + for sample in batch + if (length := len(sample[self.target_field])) + ] + ).to(torch.long) + return sample_id + + +class PackCollate(Collate): + def __call__(self, batch: list[Sample]) -> BatchedSample: + assert all( + [ + len(sample[self.target_field]) == len(sample[key]) + for sample in batch + for key in self.seq_fields + ] + ), "All fields must have the same length." 
+ assert all( + [len(sample[self.target_field]) <= self.max_length for sample in batch] + ), f"Sample length must be less than or equal to max_length ({self.max_length})" + + packed_batch, bin_spaces = self.first_fit_decreasing_bin_packing(batch) + sample_id = self.get_sample_id(packed_batch, bin_spaces) + merged_batch = self.merge_batch(packed_batch, bin_spaces) | dict( + sample_id=sample_id + ) + return merged_batch + + def first_fit_decreasing_bin_packing( + self, + batch: list[Sample], + ) -> tuple[list[list[Sample]], Int[np.ndarray, "batch"]]: + batch = sorted( + batch, key=lambda sample: len(sample[self.target_field]), reverse=True + ) + bin_spaces: Int[np.ndarray, "batch"] = np.full(len(batch), self.max_length) + packed_batch: list[list[Sample]] = [[]] + + for sample in batch: + length = len(sample[self.target_field]) + criterion: Bool[np.ndarray, "batch"] = bin_spaces - length >= 0 + bin_id: int = criterion.argmax() + if len(packed_batch) <= bin_id: + if len(packed_batch) != bin_id: + raise ValueError + packed_batch.append([]) + + packed_batch[bin_id].append(sample) + bin_spaces[bin_id] -= length + + return packed_batch, bin_spaces[: len(packed_batch)] + + def get_sample_id( + self, batch: list[list[Sample]], bin_spaces: Int[np.ndarray, "batch"] + ) -> Int[torch.Tensor, "batch seq"]: + sample_id = torch.stack( + [ + torch.cat( + [ + torch.ones(len(sample[self.target_field])) * (idx + 1) + for idx, sample in enumerate(bin_) + ] + + [torch.zeros(space)], # padding + ) + for bin_, space in zip(batch, bin_spaces) + ] + ).to(torch.long) + return sample_id + + def merge_batch( + self, batch: list[list[Sample]], bin_spaces: Int[np.ndarray, "batch"] + ) -> BatchedSample: + batch = { + key: torch.stack( + [ + torch.cat( + [default_convert(sample[key]) for sample in bin_] + + [ + default_convert( + self.pad_func_map[key]( + (space,) + bin_[0][key].shape[1:], + bin_[0][key].dtype, + ) + ) + ] + ) + for bin_, space in zip(batch, bin_spaces) + ], + ) + for key in self.seq_fields + } + return batch + + +@dataclass +class SliceableBatchedSample: + data: BatchedSample + + def __post_init__(self): + assert all( + [ + len(self.data[key]) == len(self.data[next(iter(self.data))]) + for key in self.data.keys() + ] + ) + + def __len__(self) -> int: + return len(self.data[next(iter(self.data))]) + + def __getitem__(self, item: slice) -> "SliceableBatchedSample": + return SliceableBatchedSample( + {key: self.data[key][item] for key in self.data.keys()} + ) + + +class Metadata(NamedTuple): + shape: tuple[int, ...] 
+ dtype: torch.dtype + + +@dataclass +class BatchedSampleQueue: + container: deque[SliceableBatchedSample] = field(default_factory=deque) + schema: Optional[dict[str, Metadata]] = None + + def _check_schema(self, batch: SliceableBatchedSample): + if self.schema is None: + self.schema = { + key: Metadata( + shape=tuple(batch.data[key].shape[1:]), dtype=batch.data[key].dtype + ) + for key in batch.data.keys() + } + else: + assert all( + [ + (key in batch.data) + and (metadata.shape == tuple(batch.data[key].shape[1:])) + and (metadata.dtype == batch.data[key].dtype) + for key, metadata in self.schema.items() + ] + ), "batch must have the same schema as the first batch" + + def append(self, batch: SliceableBatchedSample | BatchedSample): + if not isinstance(batch, SliceableBatchedSample): + batch = SliceableBatchedSample(batch) + self._check_schema(batch) + self.container.append(batch) + + def appendleft(self, batch: SliceableBatchedSample | BatchedSample): + if not isinstance(batch, SliceableBatchedSample): + batch = SliceableBatchedSample(batch) + self._check_schema(batch) + self.container.appendleft(batch) + + def popleft(self, size: int) -> BatchedSample: + if size > len(self): + raise ValueError( + f"pop size ({size}) must be less than or equal to queue size ({len(self)})" + ) + + out = BatchedSampleQueue() + while len(out) < size: + curr = self.container.popleft() + if len(out) + len(curr) > size: + self.appendleft(curr[size - len(out) :]) + curr = curr[: size - len(out)] + out.append(curr) + return out.as_batched_data() + + def as_batched_data(self) -> BatchedSample: + return { + key: torch.cat([batch.data[key] for batch in self.container], dim=0) + for key in self.schema.keys() + } + + def __len__(self) -> int: + return sum(len(batch) for batch in self.container) + + +@dataclass +class _BatchedSampleIterator: + dataloader_iter: Iterator[BatchedSample] + batch_size: int + drop_last: bool + fill_last: bool + pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] + + def __post_init__(self): + self.queue = BatchedSampleQueue() + + def __iter__(self): + return self + + def __next__(self) -> BatchedSample: + while (data := self._next_batch()) is None: + continue + return data + + def _next_batch(self) -> Optional[BatchedSample]: + if len(self.queue) < self.batch_size: + try: + data = next(self.dataloader_iter) + self.queue.append(data) + return None + except StopIteration: + if self.drop_last or len(self.queue) == 0: + raise StopIteration + elif self.fill_last: + self._pad_queue(self.batch_size - len(self.queue)) + + batch = self.queue.popleft(min(self.batch_size, len(self.queue))) + return batch + + def _pad_queue(self, size: int): + if self.queue.schema is None: + raise ValueError("schema must be set before padding") + padding = { + key: default_convert( + self.pad_func_map[key]((size,) + metadata.shape, np.dtype(np.float32)) + ).to(metadata.dtype) + for key, metadata in self.queue.schema.items() + } + self.queue.append(padding) + + def has_next(self) -> bool: + if len(self.queue) < self.batch_size: + try: + next_batch = next(self) + self.queue.appendleft(next_batch) + except StopIteration: + return False + return True + + +class DataLoader: + def __init__( + self, + dataset: Dataset, + batch_size: int, + batch_size_factor: float = 1.0, + cycle: bool = False, + num_batches_per_epoch: Optional[int] = None, + shuffle: bool = False, + sampler: Optional[Sampler] = None, + num_workers: int = 0, + collate_fn: Optional[Collate] = None, + pin_memory: bool = False, + drop_last: 
bool = True, + fill_last: bool = False, + worker_init_fn: Optional[Callable[[int], None]] = None, + prefetch_factor: int = 2, + persistent_workers: bool = False, + ): + if num_batches_per_epoch is not None: + assert cycle, "can only set 'num_batches_per_epoch' when 'cycle=True'" + + self.dataloader = TorchDataLoader( + dataset=dataset, + batch_size=int(batch_size * batch_size_factor), + shuffle=shuffle, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=False, + worker_init_fn=worker_init_fn, + prefetch_factor=prefetch_factor if num_workers > 0 else None, + persistent_workers=persistent_workers and num_workers > 0, + ) + self.batch_size = batch_size + self.cycle = cycle + self.num_batches_per_epoch = num_batches_per_epoch + self.collate_fn = collate_fn + self.drop_last = drop_last + self.fill_last = fill_last + self.iterator: Optional[_BatchedSampleIterator] = None + + def __iter__(self) -> Iterator: + if self.iterator is None or not self.iterator.has_next(): + dataloader_iter = ( + iter(self.dataloader) + if not self.cycle + else itertools.chain.from_iterable(itertools.repeat(self.dataloader)) + ) + self.iterator = _BatchedSampleIterator( + dataloader_iter=dataloader_iter, + batch_size=self.batch_size, + drop_last=self.drop_last, + fill_last=self.fill_last, + pad_func_map=self.collate_fn.pad_func_map, + ) + return itertools.islice(self.iterator, self.num_batches_per_epoch) + + @property + def worker_init_fn(self) -> Optional[Callable[[int], None]]: + return self.dataloader.worker_init_fn + + @worker_init_fn.setter + def worker_init_fn(self, worker_init_fn: Optional[Callable[[int], None]]): + self.dataloader.worker_init_fn = worker_init_fn + + +class EvalPadCollate(Collate): + max_context_length = 36 + max_prediction_length = 12 + + def __init__(self, max_context_length: int, max_prediction_length: int, **kwargs): + self.max_context_length = max_context_length + self.max_prediction_length = max_prediction_length + super().__init__(max_length=max_context_length+max_prediction_length, **kwargs) + + def __call__(self, batch: list[Sample]) -> BatchedSample: + assert all( + [ + len(sample[self.target_field]) == len(sample[key]) + for sample in batch + for key in self.seq_fields + ] + ), "All fields must have the same length." 
+ assert all( + [len(sample[self.target_field]) <= self.max_length for sample in batch] + ), f"Sample length must be less than or equal to max_length ({self.max_length})" + + sample_id = self.get_sample_id(batch) + padded_batch = self.pad_samples(batch) + merged_batch = padded_batch | dict(sample_id=sample_id) + return merged_batch + + def pad_samples(self, batch: list[Sample]) -> BatchedSample: + for sample in batch: + length = len(sample[self.target_field]) + prediction_length = np.sum(sample['prediction_mask']) + context_length = length - prediction_length + left_pad = self.max_context_length - context_length + right_pad = self.max_prediction_length - prediction_length + + for key in self.seq_fields: + sample[key] = torch.cat( + [ + default_convert( + self.pad_func_map[key]( + (left_pad,) + sample[key].shape[1:], + sample[key].dtype, + ) + ), + default_convert(sample[key]), + default_convert( + self.pad_func_map[key]( + (right_pad,) + sample[key].shape[1:], + sample[key].dtype, + ) + ), + ] + ) + return default_collate(batch) + + def get_sample_id(self, batch: list[Sample]) -> Int[torch.Tensor, "batch seq"]: + sample_id = [] + + for sample in batch: + length = len(sample[self.target_field]) + prediction_length = np.sum(sample['prediction_mask']) + context_length = length - prediction_length + left_pad = self.max_context_length - context_length + right_pad = self.max_prediction_length - prediction_length + + sample_id.append( + torch.cat( + [ + torch.zeros(left_pad), + torch.ones(length), + torch.zeros(right_pad) + ] + ) + ) + + sample_id = torch.stack(sample_id).to(torch.long) + return sample_id \ No newline at end of file diff --git a/OATS/models/tsfm/distribution/__init__.py b/OATS/models/tsfm/distribution/__init__.py new file mode 100644 index 0000000..bb289e4 --- /dev/null +++ b/OATS/models/tsfm/distribution/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from ._base import DistributionOutput, DistrParamProj +from .laplace import LaplaceFixedScaleOutput, LaplaceOutput +from .log_normal import LogNormalOutput +from .mixture import MixtureOutput +from .negative_binomial import NegativeBinomialOutput +from .normal import NormalFixedScaleOutput, NormalOutput +from .pareto import ParetoFixedAlphaOutput, ParetoOutput +from .student_t import StudentTOutput + +DISTRIBUTION_OUTPUTS = [ + "LaplaceFixedScaleOutput", + "LaplaceOutput", + "LogNormalOutput", + "MixtureOutput", + "NegativeBinomialOutput", + "NormalFixedScaleOutput", + "NormalOutput", + "ParetoFixedAlphaOutput", + "StudentTOutput", +] + +__all__ = ["DistrParamProj", "DistributionOutput"] + DISTRIBUTION_OUTPUTS diff --git a/OATS/models/tsfm/distribution/_base.py b/OATS/models/tsfm/distribution/_base.py new file mode 100644 index 0000000..0039d70 --- /dev/null +++ b/OATS/models/tsfm/distribution/_base.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
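+"""Base classes for parametric output distributions.
+
+``DistributionOutput`` maps a PyTree of unconstrained network outputs onto the
+constrained parameter space of a ``torch.distributions.Distribution`` via its
+``domain_map``, and ``DistrParamProj`` builds the matching projection layers.
+A minimal sketch of the intended flow, using the ``StudentTOutput`` subclass
+defined elsewhere in this package (shapes are illustrative):
+
+    import torch
+
+    out = StudentTOutput()                    # args_dim = dict(df=1, loc=1, scale=1)
+    proj = out.get_param_proj(in_features=64, out_features=32)
+    params = proj(torch.randn(8, 64))         # PyTree of [8, 32] tensors
+    distr = out.distribution(params)          # StudentT with batch shape [8, 32]
+    nll = -distr.log_prob(torch.randn(8, 32)).mean()
+"""
+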
+import abc +from collections.abc import Callable +from typing import Any, Optional + +import torch +from einops import rearrange +from jaxtyping import Float, PyTree +from torch import nn +from torch.distributions import AffineTransform, Distribution, TransformedDistribution +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from tsfm.common.core import abstract_class_property +from tsfm.module.ts_embed import MultiOutSizeLinear + + +# TODO: Replace with tree_map when multiple trees supported +def tree_map_multi( + func: Callable, tree: PyTree[Any, "T"], *other: PyTree[Any, "T"] +) -> PyTree[Any, "T"]: + leaves, treespec = tree_flatten(tree) + other_leaves = [tree_flatten(o)[0] for o in other] + return_leaves = [func(*leaf) for leaf in zip(leaves, *other_leaves)] + return tree_unflatten(return_leaves, treespec) + + +def convert_to_module(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]: + if isinstance(tree, dict): + return nn.ModuleDict( + {key: convert_to_module(child) for key, child in tree.items()} + ) + if isinstance(tree, (list, tuple)): + return nn.ModuleList([convert_to_module(child) for child in tree]) + return tree + + +def convert_to_container(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]: + if isinstance(tree, nn.ModuleDict): + return {key: convert_to_container(child) for key, child in tree.items()} + if isinstance(tree, nn.ModuleList): + return [convert_to_container(child) for child in tree] + return tree + + +class DistrParamProj(nn.Module): + def __init__( + self, + in_features: int, + out_features: int | tuple[int, ...] | list[int], + args_dim: PyTree[int, "T"], + domain_map: PyTree[Callable[[torch.Tensor], torch.Tensor], "T"], + proj_layer: Callable[..., nn.Module] = nn.Linear, + **kwargs: Any, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.args_dim = args_dim + self.domain_map = domain_map + self.proj = convert_to_module( + tree_map( + lambda dim: ( + proj_layer(in_features, dim * out_features) + if isinstance(out_features, int) + else proj_layer( + in_features, + tuple(dim * of for of in out_features), + dim=dim, + **kwargs, + ) + ), + args_dim, + ) + ) + self.out_size = ( + out_features if isinstance(out_features, int) else max(out_features) + ) + + def forward(self, *args) -> PyTree[Float[torch.Tensor, "*batch out dim"], "T"]: + params_unbounded = tree_map( + lambda proj: rearrange( + proj(*args), + "... (dim out_size) -> ... 
out_size dim", + out_size=self.out_size, + ), + convert_to_container(self.proj), + ) + params = tree_map_multi( + lambda func, inp: func(inp), self.domain_map, params_unbounded + ) + return params + + +class AffineTransformed(TransformedDistribution): + def __init__( + self, + base_dist: Distribution, + loc: Optional[torch.Tensor | float] = None, + scale: Optional[torch.Tensor | float] = None, + validate_args: Optional[bool] = None, + ): + self.loc = loc if loc is not None else 0.0 + self.scale = scale if scale is not None else 1.0 + super().__init__( + base_dist, + [AffineTransform(loc=self.loc, scale=self.scale)], + validate_args=validate_args, + ) + + @property + def mean(self) -> torch.Tensor: + return self.base_dist.mean * self.scale + self.loc + + @property + def variance(self) -> torch.Tensor: + return self.base_dist.variance * self.scale**2 + + +@abstract_class_property("distr_cls") +class DistributionOutput: + distr_cls: type[Distribution] = NotImplemented + + def distribution( + self, + distr_params: PyTree[torch.Tensor, "T"], + loc: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + validate_args: Optional[bool] = None, + ) -> Distribution: + distr = self._distribution(distr_params, validate_args=validate_args) + if loc is not None or scale is not None: + distr = AffineTransformed(distr, loc=loc, scale=scale) + return distr + + def _distribution( + self, + distr_params: PyTree[torch.Tensor, "T"], + validate_args: Optional[bool] = None, + ) -> Distribution: + return self.distr_cls(**distr_params, validate_args=validate_args) + + @property + @abc.abstractmethod + def args_dim(self) -> PyTree[int, "T"]: ... + + @property + @abc.abstractmethod + def domain_map(self) -> PyTree[Callable[[torch.Tensor], torch.Tensor], "T"]: ... + + def get_param_proj( + self, + in_features: int, + out_features: int | tuple[int, ...] | list[int], + proj_layer: Callable[..., nn.Module] = nn.Linear, + **kwargs: Any, + ) -> nn.Module: + return DistrParamProj( + in_features=in_features, + out_features=out_features, + args_dim=self.args_dim, + domain_map=self.domain_map, + proj_layer=proj_layer, + **kwargs, + ) diff --git a/OATS/models/tsfm/distribution/laplace.py b/OATS/models/tsfm/distribution/laplace.py new file mode 100644 index 0000000..f10822d --- /dev/null +++ b/OATS/models/tsfm/distribution/laplace.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
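+"""Laplace distribution outputs.
+
+``LaplaceOutput`` learns both ``loc`` and ``scale`` (softplus keeps the scale
+strictly positive), while ``LaplaceFixedScaleOutput`` learns only ``loc`` and
+pins the scale, making the negative log-likelihood behave like an L1 loss up
+to a constant. A small sketch of the domain map (shapes are illustrative):
+
+    import torch
+
+    out = LaplaceOutput()
+    raw = torch.randn(4, 16, 1)            # unconstrained network output
+    loc = out.domain_map["loc"](raw)       # squeeze -> [4, 16]
+    scale = out.domain_map["scale"](raw)   # softplus -> positive [4, 16]
+    distr = out.distribution(dict(loc=loc, scale=scale))
+"""
+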
+from typing import Callable, Optional + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import Laplace +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class LaplaceOutput(DistributionOutput): + distr_cls = Laplace + args_dim = dict(loc=1, scale=1) + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(loc=self._loc, scale=self._scale) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + @staticmethod + def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) + + +class LaplaceFixedScaleOutput(DistributionOutput): + distr_cls = Laplace + args_dim = dict(loc=1) + + def __init__(self, scale: float = 1e-3): + self.scale = scale + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(loc=self._loc) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + def _distribution( + self, + distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], + validate_args: Optional[bool] = None, + ) -> Laplace: + loc = distr_params["loc"] + distr_params["scale"] = torch.as_tensor( + self.scale, dtype=loc.dtype, device=loc.device + ) + return self.distr_cls(**distr_params, validate_args=validate_args) diff --git a/OATS/models/tsfm/distribution/log_normal.py b/OATS/models/tsfm/distribution/log_normal.py new file mode 100644 index 0000000..ac6d43c --- /dev/null +++ b/OATS/models/tsfm/distribution/log_normal.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from typing import Callable + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import LogNormal +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class LogNormalOutput(DistributionOutput): + distr_cls = LogNormal + args_dim = dict(loc=1, scale=1) + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(loc=self._loc, scale=self._scale) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + @staticmethod + def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) diff --git a/OATS/models/tsfm/distribution/mixture.py b/OATS/models/tsfm/distribution/mixture.py new file mode 100644 index 0000000..f17558a --- /dev/null +++ b/OATS/models/tsfm/distribution/mixture.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
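+"""Mixture distribution over possibly heterogeneous component families.
+
+``Mixture`` combines per-sample ``Categorical`` weights with a list of
+component distributions that may come from different families, and
+``MixtureOutput`` wires the corresponding ``DistributionOutput`` objects
+together. A minimal sketch with two plain torch distributions (shapes are
+illustrative):
+
+    import torch
+    from torch.distributions import Categorical, Normal, StudentT
+
+    weights = Categorical(logits=torch.zeros(8, 2))  # batch of 8, 2 components
+    components = [
+        Normal(torch.zeros(8), torch.ones(8)),
+        StudentT(3.0 * torch.ones(8), torch.zeros(8), torch.ones(8)),
+    ]
+    mix = Mixture(weights, components)
+    x = mix.sample()       # [8]
+    lp = mix.log_prob(x)   # log-sum-exp over both components
+"""
+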
+from functools import reduce +from typing import Callable, Optional + +import torch +from jaxtyping import PyTree +from torch.distributions import Categorical, Distribution, constraints + +from tsfm.common.torch_util import unsqueeze_trailing_dims + +from ._base import DistributionOutput + + +class Mixture(Distribution): + arg_constraints = dict() + has_rsample = False + + def __init__( + self, + weights: Categorical, + components: list[Distribution], + validate_args: Optional[bool] = None, + ): + for comp in components: + comp._validate_args = False + + self.weights = weights + self.components = components + + if not isinstance(weights, Categorical): + raise TypeError("weights must be a Categorical distribution") + + if not all(isinstance(comp, Distribution) for comp in components): + raise TypeError("components must all be instances of Distribution") + + batch_shape = weights.batch_shape + event_shape = components[0].event_shape + if validate_args: + if not all(comp.batch_shape == batch_shape for comp in components): + raise ValueError("components must have the same batch_shape as weights") + if not all(comp.event_shape == event_shape for comp in components): + raise ValueError("components must have the same event_shape") + if weights.logits.shape[-1] != len(components): + raise ValueError( + "number of logits of weights must be equal to number of components, " + f"got {weights.logits.shape[-1]} and {len(components)} respectively" + ) + super().__init__( + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=validate_args, + ) + + def expand(self, batch_shape: torch.Size, _instance=None) -> "Mixture": + new = self._get_checked_instance(Mixture, _instance) + batch_shape = torch.Size(batch_shape) + new.weights = self.weights.expand(batch_shape) + new.components = [comp.expand(batch_shape) for comp in self.components] + super(Mixture, new).__init__( + batch_shape=batch_shape, + event_shape=new.components[0].event_shape, + validate_args=False, + ) + new._validate_args = self._validate_args + return new + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + if self._validate_args: + self._validate_sample(value) + + # Check at least in 1 support + valid = reduce( + torch.logical_or, + (comp.support.check(value) for comp in self.components), + ) + if not valid.all(): + raise ValueError( + "Expected value argument " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"to be within the support (one of {[repr(comp.support) for comp in self.components]}) " + f"of the distribution {repr(self)}, " + f"but found invalid values:\n{value}" + ) + + weights_log_probs = self.weights.logits.expand( + value.shape + (len(self.components),) + ) + weights_log_probs = torch.stack(weights_log_probs.unbind(dim=-1)) + components_log_probs = torch.stack( + [ + torch.where( + comp.support.check(value), + comp.log_prob( + torch.where( + comp.support.check(value), + value, + comp.sample(), + ) + ), + float("-inf"), + ) + for comp in self.components + ] + ) + weights_log_probs = torch.where( + torch.isinf(components_log_probs), + 0.0, + weights_log_probs, + ) + return (weights_log_probs + components_log_probs).logsumexp(dim=0) + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + with torch.no_grad(): + components_samples = torch.stack( + [comp.sample(sample_shape) for comp in self.components], dim=-1 + ) + weights_sample = unsqueeze_trailing_dims( + self.weights.sample(sample_shape), components_samples.shape + ) + samples = torch.gather( + 
components_samples, + dim=-1, + index=weights_sample, + ).squeeze(-1) + return samples + + @constraints.dependent_property + def support(self) -> constraints.Constraint: + return constraints.real + + @property + def mean(self) -> torch.Tensor: + weights_probs = torch.stack(self.weights.probs.unbind(dim=-1)) + components_means = torch.stack([comp.mean for comp in self.components]) + return (weights_probs * components_means).sum(dim=0) + + @property + def variance(self) -> torch.Tensor: + # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) + weights_probs = torch.stack(self.weights.probs.unbind(dim=-1)) + components_var = torch.stack([comp.variance for comp in self.components]) + expected_cond_var = (weights_probs * components_var).sum(dim=0) + components_means = torch.stack([comp.mean for comp in self.components]) + var_cond_expectation = (weights_probs * components_means.pow(2.0)).sum( + dim=0 + ) - self.mean.pow(2.0) + return expected_cond_var + var_cond_expectation + + def cdf(self, value: torch.Tensor) -> torch.Tensor: + weights_prob = self.weights.probs + components_cdf = torch.stack([comp.cdf(value) for comp in self.components]) + return (weights_prob * components_cdf).sum(dim=0) + + +class MixtureOutput(DistributionOutput): + distr_cls = Mixture + + def __init__(self, components: list[DistributionOutput]): + self.components = components + + def _distribution( + self, + distr_params: PyTree[torch.Tensor, "T"], + validate_args: Optional[bool] = None, + ) -> Distribution: + return self.distr_cls( + weights=Categorical( + logits=distr_params["weights_logits"], validate_args=validate_args + ), + components=[ + component._distribution(comp_params, validate_args=validate_args) + for component, comp_params in zip( + self.components, distr_params["components"] + ) + ], + validate_args=validate_args, + ) + + @property + def args_dim(self) -> PyTree[int, "T"]: + return dict( + weights_logits=len(self.components), + components=[comp.args_dim for comp in self.components], + ) + + @property + def domain_map(self) -> PyTree[Callable[[torch.Tensor], torch.Tensor], "T"]: + return dict( + weights_logits=lambda x: x, + components=[comp.domain_map for comp in self.components], + ) diff --git a/OATS/models/tsfm/distribution/negative_binomial.py b/OATS/models/tsfm/distribution/negative_binomial.py new file mode 100644 index 0000000..c05f765 --- /dev/null +++ b/OATS/models/tsfm/distribution/negative_binomial.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
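+"""Negative binomial output for non-negative count data.
+
+Sampling uses the Gamma-Poisson mixture representation: a Gamma with
+concentration ``total_count`` and rate ``exp(-logits)`` drives a Poisson,
+which is equivalent to a negative binomial with success probability
+``sigmoid(logits)``. A small sketch (values are illustrative):
+
+    import torch
+
+    nb = NegativeBinomial(total_count=torch.tensor(5.0), logits=torch.tensor(0.0))
+    x = nb.sample(torch.Size([1000]))    # non-negative counts
+    mean = nb.mean                       # total_count * exp(logits) = 5.0
+    lp = nb.log_prob(torch.tensor(3.0))
+"""
+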
+from typing import Callable, Optional + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import Distribution, Gamma, constraints +from torch.distributions.utils import broadcast_all, lazy_property, logits_to_probs +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class NegativeBinomial(Distribution): + arg_constraints = { + "total_count": constraints.positive, + "logits": constraints.real, + } + support = constraints.nonnegative + has_rsample = False + + def __init__( + self, + total_count: float | torch.Tensor, + logits: float | torch.Tensor, + validate_args: Optional[bool] = None, + ): + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + batch_shape = self.logits.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape: torch.Size, _instance=None) -> "NegativeBinomial": + new = self._get_checked_instance(NegativeBinomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + new.logits = self.logits.expand(batch_shape) + super(NegativeBinomial, new).__init__( + batch_shape=batch_shape, + validate_args=False, + ) + new._validate_args = self._validate_args + return new + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + with torch.no_grad(): + sample = torch.poisson( + Gamma( + concentration=self.total_count, + rate=torch.exp(-self.logits), + validate_args=False, + ).sample(sample_shape), + ) + return sample + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + if self._validate_args: + self._validate_sample(value) + log_unnormalized_prob = ( + self.total_count * F.logsigmoid(-self.logits) + + F.logsigmoid(self.logits) * value + ) + log_normalization = self._lbeta(1 + value, self.total_count) + torch.log( + self.total_count + value + ) + return log_unnormalized_prob - log_normalization + + def _lbeta(self, x, y): + return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) + + @property + def mean(self) -> torch.Tensor: + return self.total_count * torch.exp(self.logits) + + @property + def variance(self) -> torch.Tensor: + return self.mean / torch.sigmoid(-self.logits) + + +class NegativeBinomialOutput(DistributionOutput): + distr_cls = NegativeBinomial + args_dim = dict(total_count=1, logits=1) + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(total_count=self._total_count, logits=self._logits) + + @staticmethod + def _total_count( + total_count: Float[torch.Tensor, "*batch 1"] + ) -> Float[torch.Tensor, "*batch"]: + return F.softplus(total_count).squeeze(-1) + + @staticmethod + def _logits( + logits: Float[torch.Tensor, "*batch 1"] + ) -> Float[torch.Tensor, "*batch"]: + return logits.squeeze(-1) diff --git a/OATS/models/tsfm/distribution/normal.py b/OATS/models/tsfm/distribution/normal.py new file mode 100644 index 0000000..67ad0e0 --- /dev/null +++ b/OATS/models/tsfm/distribution/normal.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
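+"""Normal distribution outputs.
+
+``NormalOutput`` learns ``loc`` and ``scale``; ``NormalFixedScaleOutput``
+learns only ``loc`` with a small fixed scale, so its negative log-likelihood
+behaves like a scaled squared error. A small sketch (shapes are illustrative):
+
+    import torch
+
+    out = NormalFixedScaleOutput(scale=1e-3)
+    raw = torch.randn(4, 16, 1)
+    loc = out.domain_map["loc"](raw)         # squeeze -> [4, 16]
+    distr = out.distribution(dict(loc=loc))  # Normal(loc, 1e-3)
+"""
+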
+from typing import Callable, Optional + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import Normal +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class NormalOutput(DistributionOutput): + distr_cls = Normal + args_dim = dict(loc=1, scale=1) + + @property + def domain_map(self) -> PyTree[Callable, "T"]: + return dict( + loc=self._loc, + scale=self._scale, + ) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + @staticmethod + def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) + + +class NormalFixedScaleOutput(DistributionOutput): + distr_cls = Normal + args_dim = dict(loc=1) + + def __init__(self, scale: float = 1e-3): + self.scale = scale + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(loc=self._loc) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + def _distribution( + self, + distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], + validate_args: Optional[bool] = None, + ) -> Normal: + loc = distr_params["loc"] + distr_params["scale"] = torch.as_tensor( + self.scale, dtype=loc.dtype, device=loc.device + ) + return self.distr_cls(**distr_params, validate_args=validate_args) diff --git a/OATS/models/tsfm/distribution/pareto.py b/OATS/models/tsfm/distribution/pareto.py new file mode 100644 index 0000000..5f4f26b --- /dev/null +++ b/OATS/models/tsfm/distribution/pareto.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
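+"""Pareto distribution outputs for heavy-tailed targets.
+
+``ParetoOutput`` maps raw outputs to a positive ``scale`` and shifts ``alpha``
+above 2 so that the variance stays finite; ``ParetoFixedAlphaOutput`` pins the
+tail index instead. A small sketch (shapes are illustrative):
+
+    import torch
+
+    out = ParetoFixedAlphaOutput(alpha=3.0)
+    raw = torch.randn(4, 16, 1)
+    scale = out.domain_map["scale"](raw)         # softplus -> positive
+    distr = out.distribution(dict(scale=scale))  # Pareto(scale, alpha=3)
+"""
+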
+from typing import Callable, Optional + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import Pareto +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class ParetoOutput(DistributionOutput): + distr_cls = Pareto + args_dim = dict(scale=1, alpha=1) + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(scale=self._scale, alpha=self._alpha) + + def _scale( + self, scale: Float[torch.Tensor, "*batch 1"] + ) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) + + def _alpha( + self, alpha: Float[torch.Tensor, "*batch 1"] + ) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(alpha.dtype).eps + return (2.0 + F.softplus(alpha).clamp_min(epsilon)).squeeze(-1) + + +class ParetoFixedAlphaOutput(DistributionOutput): + distr_cls = Pareto + args_dim = dict(scale=1) + + def __init__(self, alpha: float = 3.0): + assert alpha > 0.0 + self.alpha = alpha + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(scale=self._scale) + + def _scale( + self, scale: Float[torch.Tensor, "*batch 1"] + ) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) + + def _distribution( + self, + distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], + validate_args: Optional[bool] = None, + ) -> Pareto: + scale = distr_params["scale"] + distr_params["alpha"] = torch.as_tensor( + self.alpha, dtype=scale.dtype, device=scale.device + ) + return self.distr_cls(**distr_params, validate_args=validate_args) diff --git a/OATS/models/tsfm/distribution/student_t.py b/OATS/models/tsfm/distribution/student_t.py new file mode 100644 index 0000000..f2fac79 --- /dev/null +++ b/OATS/models/tsfm/distribution/student_t.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from typing import Callable + +import torch +from jaxtyping import Float, PyTree +from torch.distributions import StudentT +from torch.nn import functional as F + +from ._base import DistributionOutput + + +class StudentTOutput(DistributionOutput): + distr_cls = StudentT + args_dim = dict(df=1, loc=1, scale=1) + + @property + def domain_map( + self, + ) -> PyTree[ + Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" + ]: + return dict(df=self._df, loc=self._loc, scale=self._scale) + + @staticmethod + def _df(df: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return (2.0 + F.softplus(df)).squeeze(-1) + + @staticmethod + def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + return loc.squeeze(-1) + + @staticmethod + def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: + epsilon = torch.finfo(scale.dtype).eps + return F.softplus(scale).clamp_min(epsilon).squeeze(-1) diff --git a/OATS/models/tsfm/eval_util/__init__.py b/OATS/models/tsfm/eval_util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/OATS/models/tsfm/eval_util/_hf_dataset.py b/OATS/models/tsfm/eval_util/_hf_dataset.py new file mode 100644 index 0000000..1709b57 --- /dev/null +++ b/OATS/models/tsfm/eval_util/_hf_dataset.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
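+"""Thin iterable wrapper around an on-disk Hugging Face evaluation dataset.
+
+Loads a dataset saved with ``datasets.Dataset.save_to_disk`` and yields
+gluonts-style dicts with a ``target`` array and a scalar ``start``. A minimal
+sketch, assuming a dataset was previously saved under
+``CUSTOM_DATA_PATH/etth1_eval`` (the name is illustrative):
+
+    ds = HFDataset("etth1_eval")
+    print(ds.freq, ds.target_dim)
+    for entry in ds:
+        print(entry["target"].shape, entry["start"])
+        break
+"""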
+from pathlib import Path + +import datasets + +from tsfm.common.env import env + + +class HFDataset: + def __init__(self, dataset_name: str, storage_path: Path = env.CUSTOM_DATA_PATH): + self.hf_dataset = datasets.load_from_disk( + str(storage_path / dataset_name) + ).with_format("numpy") + self.freq = self.hf_dataset[0]["freq"] + self.target_dim = ( + target.shape[-1] + if len((target := self.hf_dataset[0]["target"]).shape) > 1 + else 1 + ) + + def __iter__(self): + for sample in self.hf_dataset: + sample["start"] = sample["start"].item() + yield sample diff --git a/OATS/models/tsfm/eval_util/_lsf_dataset.py b/OATS/models/tsfm/eval_util/_lsf_dataset.py new file mode 100644 index 0000000..bac5e03 --- /dev/null +++ b/OATS/models/tsfm/eval_util/_lsf_dataset.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import os + +import numpy as np +import pandas as pd + +from tsfm.common.env import env + + +class LSFDataset: + def __init__( + self, + dataset_name: str, + mode: str = "S", + split: str = "test", + ): + self.dataset_name = dataset_name + self.mode = mode + self.split = split + + if dataset_name in ["ETTh1", "ETTh2"]: + self._load_etth() + elif dataset_name in ["ETTm1", "ETTm2"]: + self._load_ettm() + elif dataset_name == "METR_LA": + self._load_metr_la() + elif dataset_name == "solar": + self._load_solar() + elif dataset_name == "walmart": + self._load_walmart() + elif dataset_name == "electricity": + self._load_custom("electricity/electricity.csv", "h") + elif dataset_name == "weather": + self._load_custom("weather/weather.csv", "10T") + else: + raise ValueError(f"Unknown dataset name: {dataset_name}") + + if mode == "S": + self.target_dim = 1 + self.past_feat_dynamic_real_dim = 0 + elif mode == "M": + self.target_dim = self.data.shape[-1] + self.past_feat_dynamic_real_dim = 0 + elif mode == "MS": + self.target_dim = 1 + self.past_feat_dynamic_real_dim = self.data.shape[-1] - 1 + else: + raise ValueError(f"Unknown mode: {mode}") + + def __iter__(self): + if self.mode == "S": + for i in range(self.data.shape[-1]): + yield { + "target": self.data[:, i], + "start": self.start, + } + elif self.mode == "M": + yield { + "target": self.data.transpose(1, 0), + "start": self.start, + } + elif self.mode == "MS": + for i in range(self.data.shape[-1]): + yield { + "target": self.data[:, i], + "past_feat_dynamic_real": np.concatenate( + [self.data[:, :i], self.data[:, i + 1 :]], axis=1 + ).transpose(1, 0), + "start": self.start, + } + + def scale(self, data, start, end): + train = data[start:end] + mean = train.mean(axis=0) + std = train.std(axis=0) + return (data - mean) / std + + def _load_etth(self): + df = pd.read_csv( + os.path.join(env.LSF_PATH, f"ETT-small/{self.dataset_name}.csv") + ) + + train_length = 8640 + val_length = 2880 + test_length = 2880 + data = self.scale(df[df.columns[1:]], 0, train_length).to_numpy() + if self.split == "train": + self.data = data[:train_length] + self.length = train_length + elif self.split == "val": + self.data = data[: train_length + val_length] + self.length = val_length + elif self.split == "test": + self.data = data[: train_length + val_length + test_length] + self.length = test_length + self.start = pd.to_datetime(df[["date"]].iloc[0].item()) + self.freq = "h" + + def _load_ettm(self): + df = pd.read_csv( + os.path.join(env.LSF_PATH, f"ETT-small/{self.dataset_name}.csv") + ) + + train_length = 34560 + val_length = 11520 + test_length = 11520 + data = self.scale(df[df.columns[1:]], 0, 
train_length).to_numpy()
+        if self.split == "train":
+            self.data = data[:train_length]
+            self.length = train_length
+        elif self.split == "val":
+            self.data = data[: train_length + val_length]
+            self.length = val_length
+        elif self.split == "test":
+            self.data = data[: train_length + val_length + test_length]
+            self.length = test_length
+        self.start = pd.to_datetime(df[["date"]].iloc[0].item())
+        self.freq = "15T"
+
+    def _load_solar(self):
+        df = pd.read_csv(os.path.join(env.LSF_PATH, "Solar/solar_AL.txt"), header=None)
+        data = df.to_numpy().reshape(8760, 6, 137).sum(1)
+
+        train_length = int(len(data) * 0.7)
+        val_length = int(len(data) * 0.1)
+        test_length = int(len(data) * 0.2)
+
+        # `data` is already an ndarray, so `scale` returns an ndarray directly
+        data = self.scale(data, 0, train_length)
+        if self.split == "train":
+            self.data = data[:train_length]
+            self.length = train_length
+        elif self.split == "val":
+            self.data = data[: train_length + val_length]
+            self.length = val_length
+        elif self.split == "test":
+            self.data = data[: train_length + val_length + test_length]
+            self.length = test_length
+        self.start = pd.to_datetime("2006-01-01")
+        self.freq = "h"
+
+    def _load_metr_la(self):
+        df = pd.read_csv(os.path.join(env.LSF_PATH, "METR_LA/METR_LA.dyna"))
+        data = []
+        for id in df.entity_id.unique():
+            id_df = df[df.entity_id == id]
+            data.append(id_df.traffic_speed.to_numpy())
+        data = np.stack(data, 1)
+
+        train_length = int(len(data) * 0.7)
+        val_length = int(len(data) * 0.1)
+        test_length = int(len(data) * 0.2)
+
+        # `data` is already an ndarray, so `scale` returns an ndarray directly
+        data = self.scale(data, 0, train_length)
+        if self.split == "train":
+            self.data = data[:train_length]
+            self.length = train_length
+        elif self.split == "val":
+            self.data = data[: train_length + val_length]
+            self.length = val_length
+        elif self.split == "test":
+            self.data = data[: train_length + val_length + test_length]
+            self.length = test_length
+        self.start = pd.to_datetime("2012-03-01")
+        self.freq = "5T"
+
+    def _load_walmart(self):
+        df = pd.read_csv(
+            os.path.join(
+                env.LSF_PATH, "walmart-recruiting-store-sales-forecasting/train.csv"
+            )
+        )
+
+        data = []
+        for id, row in df[["Store", "Dept"]].drop_duplicates().iterrows():
+            row_df = df.query(f"Store == {row.Store} and Dept == {row.Dept}")
+            if len(row_df) != 143:
+                continue
+            data.append(row_df.Weekly_Sales.to_numpy())
+        data = np.stack(data, 1)
+
+        train_length = 143 - 28 - 14
+        val_length = 14
+        test_length = 28
+        data = self.scale(data, 0, train_length)
+        if self.split == "train":
+            self.data = data[:train_length]
+            self.length = train_length
+        elif self.split == "val":
+            self.data = data[: train_length + val_length]
+            self.length = val_length
+        elif self.split == "test":
+            self.data = data[: train_length + val_length + test_length]
+            self.length = test_length
+        self.start = pd.to_datetime("2010-02-05")
+        self.freq = "W"
+
+    def _load_custom(self, data_path: str, freq: str):
+        df = pd.read_csv(os.path.join(env.LSF_PATH, data_path))
+        cols = list(df.columns)
+        cols.remove("OT")
+        cols.remove("date")
+        df = df[["date"] + cols + ["OT"]]
+        data = df[df.columns[1:]]
+
+        train_length = int(len(data) * 0.7)
+        val_length = int(len(data) * 0.1)
+        test_length = int(len(data) * 0.2)
+        data = self.scale(data, 0, train_length).to_numpy()
+        if self.split == "train":
+            self.data = data[:train_length]
+            self.length = train_length
+        elif self.split == "val":
+            self.data = data[: train_length + val_length]
+            self.length = val_length
+        elif self.split == "test":
+            self.data = data[: train_length + val_length + test_length]
+            self.length = test_length
+        self.start = 
pd.to_datetime(df[["date"]].iloc[0].item()) + self.freq = freq diff --git a/OATS/models/tsfm/eval_util/_pf_dataset.py b/OATS/models/tsfm/eval_util/_pf_dataset.py new file mode 100644 index 0000000..e4c5970 --- /dev/null +++ b/OATS/models/tsfm/eval_util/_pf_dataset.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +from gluonts.dataset import DatasetWriter +from gluonts.dataset.common import CategoricalFeatureInfo, MetaData, TrainDatasets + +from tsfm.common.env import env + + +def _load_etth(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv(os.path.join(env.LSF_PATH, f"ETT-small/{dataset_name}.csv")) + data = df[df.columns[1:]].values + + start = pd.to_datetime(df[["date"]].iloc[0].item()) + freq = "h" + prediction_length = prediction_length or 24 + rolling_evaluations = 7 + return data, start, freq, prediction_length, rolling_evaluations + + +def _load_ettm(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv(os.path.join(env.LSF_PATH, f"ETT-small/{dataset_name}.csv")) + data = df[df.columns[1:]].values + + start = pd.to_datetime(df[["date"]].iloc[0].item()) + freq = "15T" + prediction_length = prediction_length or 24 * 4 + rolling_evaluations = 7 + return data, start, freq, prediction_length, rolling_evaluations + + +def _load_metr_la(dataset_name, prediction_length: Optional[int] = None): + df = pd.read_csv(os.path.join(env.LSF_PATH, "METR_LA/METR_LA.dyna")) + data = [] + for id in df.entity_id.unique(): + id_df = df[df.entity_id == id] + data.append(id_df.traffic_speed.to_numpy()) + data = np.stack(data, 1) + + start = pd.to_datetime("2012-03-01") + freq = "5T" + prediction_length = prediction_length or 12 * 24 + rolling_evaluations = 7 + return data, start, freq, prediction_length, rolling_evaluations + + +def _load_walmart(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv( + os.path.join( + env.LSF_PATH, "walmart-recruiting-store-sales-forecasting/train.csv" + ) + ) + + data = [] + for id, row in df[["Store", "Dept"]].drop_duplicates().iterrows(): + row_df = df.query(f"Store == {row.Store} and Dept == {row.Dept}") + if len(row_df) != 143: + continue + data.append(row_df.Weekly_Sales.to_numpy()) + data = np.stack(data, 1) + + start = pd.to_datetime("2010-02-05") + freq = "W" + prediction_length = prediction_length or 8 + rolling_evaluations = 4 + return data, start, freq, prediction_length, rolling_evaluations + + +def _load_jena_weather(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv(os.path.join(env.LSF_PATH, "weather/weather.csv")) + cols = list(df.columns) + cols.remove("OT") + cols.remove("date") + df = df[["date"] + cols + ["OT"]] + data = df[df.columns[1:]].to_numpy() + + start = pd.to_datetime(df[["date"]].iloc[0].item()) + freq = "10T" + prediction_length = prediction_length or 6 * 24 + rolling_evaluations = 7 + return data, start, freq, prediction_length, rolling_evaluations + + +def _load_istanbul_traffic(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv( + os.path.join(env.LSF_PATH, "istanbul-traffic-index/istanbul_traffic.csv") + ) + df.datetime = pd.to_datetime(df.datetime) + df = df.set_index("datetime") + df = df.resample("h").mean() + + data = df.values + start = df.index[0] + freq = "h" + prediction_length = prediction_length or 24 + rolling_evaluations = 7 + return 
data, start, freq, prediction_length, rolling_evaluations + + +def _load_turkey_power(dataset_name: str, prediction_length: Optional[int] = None): + df = pd.read_csv( + os.path.join( + env.LSF_PATH, + "electrical-power-demand-in-turkey/power Generation and consumption.csv", + ) + ) + df.Date_Time = pd.to_datetime(df.Date_Time, format="%d.%m.%Y %H:%M") + df = df.set_index("Date_Time") + + data = df.values + start = df.index[0] + freq = "h" + prediction_length = prediction_length or 24 + rolling_evaluations = 7 + return data, start, freq, prediction_length, rolling_evaluations + + +pf_load_func_map = { + "ETTh1": _load_etth, + "ETTh2": _load_etth, + "ETTm1": _load_ettm, + "ETTm2": _load_ettm, + "METR_LA": _load_metr_la, + "walmart": _load_walmart, + "jena_weather": _load_jena_weather, + "istanbul_traffic": _load_istanbul_traffic, + "turkey_power": _load_turkey_power, +} + + +def generate_pf_dataset( + dataset_path: Path, + dataset_name: str, + dataset_writer: DatasetWriter, + prediction_length: Optional[int] = None, +): + load_func = pf_load_func_map[dataset_name] + data, start, freq, prediction_length, rolling_evaluations = load_func( + dataset_name, prediction_length + ) + + train_ts = [] + for cat in range(data.shape[-1]): + sliced_ts = data[: -prediction_length * rolling_evaluations, cat] + train_ts.append( + { + "target": sliced_ts, + "start": start, + "feat_static_cat": [cat], + "item_id": cat, + } + ) + + test_ts = [] + for window in range(rolling_evaluations - 1, -1, -1): + for cat in range(data.shape[-1]): + sliced_ts = data[: len(data) - prediction_length * window, cat] + test_ts.append( + { + "target": sliced_ts, + "start": start, + "feat_static_cat": [cat], + "item_id": cat, + } + ) + + meta = MetaData( + freq=freq, + feat_static_cat=[ + CategoricalFeatureInfo(name="feat_static_cat_0", cardinality=data.shape[-1]) + ], + prediction_length=prediction_length, + ) + dataset = TrainDatasets(metadata=meta, train=train_ts, test=test_ts) + dataset.save(path_str=str(dataset_path), writer=dataset_writer, overwrite=True) diff --git a/OATS/models/tsfm/eval_util/data.py b/OATS/models/tsfm/eval_util/data.py new file mode 100644 index 0000000..31ddf2b --- /dev/null +++ b/OATS/models/tsfm/eval_util/data.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
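For reference, a minimal sketch of the rolling-evaluation layout that `generate_pf_dataset` above constructs; the sizes are illustrative, not taken from any real dataset:

```python
import numpy as np

# Hypothetical sizes: 100 time steps, 1 variate.
data = np.arange(100).reshape(100, 1)
prediction_length, rolling_evaluations = 8, 4

# Train series stop before the rolling-evaluation region.
train_target = data[: -prediction_length * rolling_evaluations, 0]
assert len(train_target) == 100 - 8 * 4  # 68 steps

# Test series grow by one horizon per window, oldest window first.
test_lengths = [
    len(data[: len(data) - prediction_length * window, 0])
    for window in range(rolling_evaluations - 1, -1, -1)
]
assert test_lengths == [76, 84, 92, 100]
```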
+from functools import partial +from typing import NamedTuple + +import gluonts +from gluonts.dataset.common import _FileDataset +from gluonts.dataset.split import TestData, split + +from tsfm.data.builder.lotsa_v1.gluonts import get_dataset + +from ._hf_dataset import HFDataset +from ._lsf_dataset import LSFDataset +from ._pf_dataset import generate_pf_dataset, pf_load_func_map + +gluonts.dataset.repository.dataset_recipes |= { + k: partial(generate_pf_dataset, dataset_name=k) for k in pf_load_func_map.keys() +} + + +class MetaData(NamedTuple): + freq: str + target_dim: int + prediction_length: int + feat_dynamic_real_dim: int = 0 + past_feat_dynamic_real_dim: int = 0 + split: str = "test" + + +def get_gluonts_val_dataset( + dataset_name: str, + prediction_length: int = None, + mode: str = None, + regenerate: bool = False, +) -> tuple[TestData, MetaData]: + default_prediction_lengths = { + "australian_electricity_demand": 336, + "pedestrian_counts": 24, + } + if prediction_length is None and dataset_name in default_prediction_lengths: + prediction_length = default_prediction_lengths[dataset_name] + + dataset = get_dataset( + dataset_name, prediction_length=prediction_length, regenerate=regenerate + ) + + prediction_length = prediction_length or dataset.metadata.prediction_length + _, test_template = split(dataset.train, offset=-prediction_length) + test_data = test_template.generate_instances(prediction_length) + metadata = MetaData( + freq=dataset.metadata.freq, + target_dim=1, + prediction_length=prediction_length, + split="val", + ) + return test_data, metadata + + +def get_gluonts_test_dataset( + dataset_name: str, + prediction_length: int = None, + mode: str = None, + regenerate: bool = False, +) -> tuple[TestData, MetaData]: + default_prediction_lengths = { + "australian_electricity_demand": 336, + "pedestrian_counts": 24, + } + if prediction_length is None and dataset_name in default_prediction_lengths: + prediction_length = default_prediction_lengths[dataset_name] + + dataset = get_dataset( + dataset_name, prediction_length=prediction_length, regenerate=regenerate + ) + + prediction_length = prediction_length or dataset.metadata.prediction_length + _, test_template = split(dataset.test, offset=-prediction_length) + test_data = test_template.generate_instances(prediction_length) + metadata = MetaData( + freq=dataset.metadata.freq, + target_dim=1, + prediction_length=prediction_length, + split="test", + ) + return test_data, metadata + + +def get_lsf_val_dataset( + dataset_name: str, + prediction_length: int = 96, + mode: str = "S", +) -> tuple[TestData, MetaData]: + lsf_dataset = LSFDataset(dataset_name, mode=mode, split="val") + dataset = _FileDataset( + lsf_dataset, freq=lsf_dataset.freq, one_dim_target=lsf_dataset.target_dim == 1 + ) + _, test_template = split(dataset, offset=-lsf_dataset.length) + test_data = test_template.generate_instances( + prediction_length, + windows=lsf_dataset.length - prediction_length + 1, + distance=1, + ) + metadata = MetaData( + freq=lsf_dataset.freq, + target_dim=lsf_dataset.target_dim, + prediction_length=prediction_length, + past_feat_dynamic_real_dim=lsf_dataset.past_feat_dynamic_real_dim, + split="val", + ) + return test_data, metadata + + +def get_lsf_test_dataset( + dataset_name: str, + prediction_length: int = 96, + mode: str = "S", +) -> tuple[TestData, MetaData]: + lsf_dataset = LSFDataset(dataset_name, mode=mode, split="test") + dataset = _FileDataset( + lsf_dataset, freq=lsf_dataset.freq, one_dim_target=lsf_dataset.target_dim == 1 + ) + 
_, test_template = split(dataset, offset=-lsf_dataset.length) + test_data = test_template.generate_instances( + prediction_length, + windows=lsf_dataset.length - prediction_length + 1, + distance=1, + ) + metadata = MetaData( + freq=lsf_dataset.freq, + target_dim=lsf_dataset.target_dim, + prediction_length=prediction_length, + past_feat_dynamic_real_dim=lsf_dataset.past_feat_dynamic_real_dim, + split="test", + ) + return test_data, metadata + + +def get_custom_eval_dataset( + dataset_name: str, + offset: int, + windows: int, + distance: int, + prediction_length: int, + mode: None = None, +) -> tuple[TestData, MetaData]: + hf_dataset = HFDataset(dataset_name) + dataset = _FileDataset( + hf_dataset, freq=hf_dataset.freq, one_dim_target=hf_dataset.target_dim == 1 + ) + _, test_template = split(dataset, offset=offset) + test_data = test_template.generate_instances( + prediction_length, + windows=windows, + distance=distance, + ) + metadata = MetaData( + freq=hf_dataset.freq, + target_dim=hf_dataset.target_dim, + prediction_length=prediction_length, + split="test", + ) + return test_data, metadata diff --git a/OATS/models/tsfm/eval_util/evaluation.py b/OATS/models/tsfm/eval_util/evaluation.py new file mode 100644 index 0000000..f373561 --- /dev/null +++ b/OATS/models/tsfm/eval_util/evaluation.py @@ -0,0 +1,267 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +# Modifications Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from collections import ChainMap +from dataclasses import dataclass +from typing import Iterable, List, Optional, Union + +import numpy as np +import pandas as pd +from gluonts.dataset import DataEntry +from gluonts.dataset.split import TestData +from gluonts.ev.ts_stats import seasonal_error +from gluonts.itertools import batcher, prod +from gluonts.model import Forecast, Predictor +from gluonts.time_feature import get_seasonality +from toolz import first, valmap +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +@dataclass +class BatchForecast: + """ + Wrapper around ``Forecast`` objects, that adds a batch dimension + to arrays returned by ``__getitem__``, for compatibility with + ``gluonts.ev``. + """ + + forecasts: List[Forecast] + allow_nan: bool = False + + def __getitem__(self, name): + values = [forecast[name].T for forecast in self.forecasts] + res = np.stack(values, axis=0) + + if np.isnan(res).any(): + if not self.allow_nan: + raise ValueError("Forecast contains NaN values") + + logger.warning("Forecast contains NaN values. 
Metrics may be incorrect.") + + return res + + +def _get_data_batch( + input_batch: List[DataEntry], + label_batch: List[DataEntry], + forecast_batch: List[Forecast], + seasonality: Optional[int] = None, + mask_invalid_label: bool = True, + allow_nan_forecast: bool = False, +) -> ChainMap: + label_target = np.stack([label["target"] for label in label_batch], axis=0) + if mask_invalid_label: + label_target = np.ma.masked_invalid(label_target) + + other_data = { + "label": label_target, + } + + seasonal_error_values = [] + for input_ in input_batch: + seasonality_entry = seasonality + if seasonality_entry is None: + seasonality_entry = get_seasonality(input_["start"].freqstr) + input_target = input_["target"] + if mask_invalid_label: + input_target = np.ma.masked_invalid(input_target) + seasonal_error_values.append( + seasonal_error( + input_target, + seasonality=seasonality_entry, + time_axis=-1, + ) + ) + other_data["seasonal_error"] = np.array(seasonal_error_values) + + return ChainMap( + other_data, BatchForecast(forecast_batch, allow_nan=allow_nan_forecast) # type: ignore + ) + + +def evaluate_forecasts_raw( + forecasts: Iterable[Forecast], + *, + test_data: TestData, + metrics, + axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, + mask_invalid_label: bool = True, + allow_nan_forecast: bool = False, + seasonality: Optional[int] = None, +) -> dict: + """ + Evaluate ``forecasts`` by comparing them with ``test_data``, according + to ``metrics``. + + .. note:: This feature is experimental and may be subject to changes. + + The optional ``axis`` arguments controls aggregation of the metrics: + - ``None`` (default) aggregates across all dimensions + - ``0`` aggregates across the dataset + - ``1`` aggregates across the first data dimension (time, in the univariate setting) + - ``2`` aggregates across the second data dimension (time, in the multivariate setting) + + Return results as a dictionary. 
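+
+    For example, ``axis=None`` yields a single aggregate value per metric,
+    while ``axis=1`` (in the univariate setting) aggregates only over time,
+    giving one value per series in the dataset.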
+ """ + label_ndim = first(test_data.label)["target"].ndim + + assert label_ndim in [1, 2] + + if axis is None: + axis = tuple(range(label_ndim + 1)) + if isinstance(axis, int): + axis = (axis,) + + assert all(ax in range(3) for ax in axis) + + evaluators = {} + for metric in metrics: + evaluator = metric(axis=axis) + evaluators[evaluator.name] = evaluator + + index_data = [] + + input_batches = batcher(test_data.input, batch_size=batch_size) + label_batches = batcher(test_data.label, batch_size=batch_size) + forecast_batches = batcher(forecasts, batch_size=batch_size) + + pbar = tqdm() + for input_batch, label_batch, forecast_batch in zip( + input_batches, label_batches, forecast_batches + ): + if 0 not in axis: + index_data.extend( + [(forecast.item_id, forecast.start_date) for forecast in forecast_batch] + ) + + data_batch = _get_data_batch( + input_batch, + label_batch, + forecast_batch, + seasonality=seasonality, + mask_invalid_label=mask_invalid_label, + allow_nan_forecast=allow_nan_forecast, + ) + + for evaluator in evaluators.values(): + evaluator.update(data_batch) + + pbar.update(len(forecast_batch)) + pbar.close() + + metrics_values = { + metric_name: evaluator.get() for metric_name, evaluator in evaluators.items() + } + + if index_data: + metrics_values["__index_0"] = index_data + + return metrics_values + + +def evaluate_forecasts( + forecasts: Iterable[Forecast], + *, + test_data: TestData, + metrics, + axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, + mask_invalid_label: bool = True, + allow_nan_forecast: bool = False, + seasonality: Optional[int] = None, +) -> pd.DataFrame: + """ + Evaluate ``forecasts`` by comparing them with ``test_data``, according + to ``metrics``. + + .. note:: This feature is experimental and may be subject to changes. + + The optional ``axis`` arguments controls aggregation of the metrics: + - ``None`` (default) aggregates across all dimensions + - ``0`` aggregates across the dataset + - ``1`` aggregates across the first data dimension (time, in the univariate setting) + - ``2`` aggregates across the second data dimension (time, in the multivariate setting) + + Return results as a Pandas ``DataFrame``. + """ + metrics_values = evaluate_forecasts_raw( + forecasts=forecasts, + test_data=test_data, + metrics=metrics, + axis=axis, + batch_size=batch_size, + mask_invalid_label=mask_invalid_label, + allow_nan_forecast=allow_nan_forecast, + seasonality=seasonality, + ) + index0 = metrics_values.pop("__index_0", None) + + metric_shape = metrics_values[first(metrics_values)].shape + if metric_shape == (): + index = [None] + else: + index_arrays = np.unravel_index(range(prod(metric_shape)), metric_shape) + if index0 is not None: + index0_repeated = np.take(index0, indices=index_arrays[0], axis=0) + index_arrays = (*zip(*index0_repeated), *index_arrays[1:]) # type: ignore + index = pd.MultiIndex.from_arrays(index_arrays) + + flattened_metrics = valmap(np.ravel, metrics_values) + + return pd.DataFrame(flattened_metrics, index=index) + + +def evaluate_model( + model: Predictor, + *, + test_data: TestData, + metrics, + axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, + mask_invalid_label: bool = True, + allow_nan_forecast: bool = False, + seasonality: Optional[int] = None, +) -> pd.DataFrame: + """ + Evaluate ``model`` when applied to ``test_data``, according + to ``metrics``. + + .. note:: This feature is experimental and may be subject to changes. 
+ + The optional ``axis`` arguments controls aggregation of the metrics: + - ``None`` (default) aggregates across all dimensions + - ``0`` aggregates across the dataset + - ``1`` aggregates across the first data dimension (time, in the univariate setting) + - ``2`` aggregates across the second data dimension (time, in the multivariate setting) + + Return results as a Pandas ``DataFrame``. + """ + forecasts = model.predict(test_data.input) + + return evaluate_forecasts( + forecasts=forecasts, + test_data=test_data, + metrics=metrics, + axis=axis, + batch_size=batch_size, + mask_invalid_label=mask_invalid_label, + allow_nan_forecast=allow_nan_forecast, + seasonality=seasonality, + ) diff --git a/OATS/models/tsfm/eval_util/metrics.py b/OATS/models/tsfm/eval_util/metrics.py new file mode 100644 index 0000000..d2ec6dc --- /dev/null +++ b/OATS/models/tsfm/eval_util/metrics.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from functools import partial +from typing import Optional + +from gluonts.ev.aggregations import Mean +from gluonts.ev.metrics import BaseMetricDefinition, DirectMetric +from gluonts.ev.stats import squared_error + + +@dataclass +class MedianMSE(BaseMetricDefinition): + """Mean Squared Error""" + + forecast_type: str = "0.5" + + def __call__(self, axis: Optional[int] = None) -> DirectMetric: + return DirectMetric( + name=f"MSE[{self.forecast_type}]", + stat=partial(squared_error, forecast_type=self.forecast_type), + aggregate=Mean(axis=axis), + ) diff --git a/OATS/models/tsfm/eval_util/plot.py b/OATS/models/tsfm/eval_util/plot.py new file mode 100644 index 0000000..6072db0 --- /dev/null +++ b/OATS/models/tsfm/eval_util/plot.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Iterator, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from gluonts import maybe +from gluonts.model import Forecast + + +def plot_single( + inp: dict, + label: dict, + forecast: Forecast, + context_length: int, + intervals: tuple[float, ...] = (0.5, 0.9), + ax: Optional[plt.axis] = None, + dim: Optional[int] = None, + name: Optional[str] = None, + show_label: bool = False, +): + plt.close("all") + ax = maybe.unwrap_or_else(ax, plt.gca) + + target = np.concatenate([inp["target"], label["target"]], axis=-1) + start = inp["start"] + if dim is not None: + target = target[dim] + forecast = forecast.copy_dim(dim) + + index = pd.period_range(start, periods=len(target), freq=start.freq) + ax.plot( + index.to_timestamp()[-context_length - forecast.prediction_length :], + target[-context_length - forecast.prediction_length :], + label="target", + color="black", + ) + forecast.plot( + intervals=intervals, + ax=ax, + color="blue", + name=name, + show_label=show_label, + ) + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right") + ax.legend(loc="lower left") + + +def plot_next_multi( + axes: np.ndarray, + input_it: Iterator[dict], + label_it: Iterator[dict], + forecast_it: Iterator[Forecast], + context_length: int, + intervals: tuple[float, ...] 
= (0.5, 0.9), + dim: Optional[int] = None, + name: Optional[str] = None, + show_label: bool = False, +): + axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes] + for ax, inp, label, forecast in zip(axes, input_it, label_it, forecast_it): + plot_single( + inp, + label, + forecast, + context_length, + intervals=intervals, + ax=ax, + dim=dim, + name=name, + show_label=show_label, + ) diff --git a/OATS/models/tsfm/loss/__init__.py b/OATS/models/tsfm/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/OATS/models/tsfm/loss/packed/__init__.py b/OATS/models/tsfm/loss/packed/__init__.py new file mode 100644 index 0000000..75ef80c --- /dev/null +++ b/OATS/models/tsfm/loss/packed/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from ._base import PackedDistributionLoss, PackedLoss, PackedPointLoss +from .distribution import PackedNLLLoss +from .normalized import ( + PackedNMAELoss, + PackedNMLSELoss, + PackedNMSELoss, + PackedNRMSELoss, + PackedPointNormalizedLoss, + PointNormType, +) +from .percentage_error import PackedMAPELoss, PackedSMAPELoss +from .point import PackedMAELoss, PackedMSELoss, PackedRMSELoss + +__all__ = [ + "PackedDistributionLoss", + "PackedLoss", + "PackedMAELoss", + "PackedMAPELoss", + "PackedMSELoss", + "PackedNLLLoss", + "PackedNMAELoss", + "PackedNMLSELoss", + "PackedNMSELoss", + "PackedNRMSELoss", + "PackedPointLoss", + "PackedPointNormalizedLoss", + "PackedRMSELoss", + "PackedSMAPELoss", + "PointNormType", +] diff --git a/OATS/models/tsfm/loss/packed/_base.py b/OATS/models/tsfm/loss/packed/_base.py new file mode 100644 index 0000000..19acc61 --- /dev/null +++ b/OATS/models/tsfm/loss/packed/_base.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import abc +from typing import Any, Optional + +import torch +from einops import rearrange, reduce +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution + +from tsfm.common.torch_util import safe_div + + +class PackedLoss(abc.ABC): + def __call__( + self, + pred: Any, + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Optional[Bool[torch.Tensor, "*batch seq_len"]], + observed_mask: Optional[Bool[torch.Tensor, "*batch seq_len #dim"]] = None, + sample_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None, + variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None, + ) -> Float[torch.Tensor, ""]: + if observed_mask is None: + observed_mask = torch.ones_like(target, dtype=torch.bool) + if sample_id is None: + sample_id = torch.zeros_like(prediction_mask, dtype=torch.long) + if variate_id is None: + variate_id = torch.zeros_like(prediction_mask, dtype=torch.long) + + loss = self._loss_func( + pred, target, prediction_mask, observed_mask, sample_id, variate_id + ) + return self.reduce_loss( + loss, prediction_mask, observed_mask, sample_id, variate_id + ) + + @abc.abstractmethod + def _loss_func( + self, + pred: Any, + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: ... 
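+
+    # Note on the reduction below: `reduce_loss` weights the packed loss so
+    # that every (sample_id, variate_id) group contributes equally to the
+    # final scalar: `id_mask` matches tokens belonging to the same series,
+    # `tobs` counts that series' observed prediction values (turning each
+    # group's summed loss into a mean), and `nobs` sums to the number of
+    # groups carrying predictions (averaging those means across groups).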
+ + def reduce_loss( + self, + loss: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Optional[Bool[torch.Tensor, "*batch seq_len"]], + observed_mask: Optional[Bool[torch.Tensor, "*batch seq_len #dim"]], + sample_id: Optional[Int[torch.Tensor, "*batch seq_len"]], + variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]], + ) -> Float[torch.Tensor, ""]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + mask = prediction_mask.unsqueeze(-1) * observed_mask + tobs = reduce( + id_mask + * reduce( + mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + nobs = reduce( + id_mask * rearrange(prediction_mask, "... seq -> ... 1 seq"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) * prediction_mask.unsqueeze(-1) + nobs = torch.where(nobs == 0, nobs, 1 / nobs).sum() + loss = safe_div(loss, tobs * nobs) + return (loss * mask).sum() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class PackedPointLoss(PackedLoss): + @abc.abstractmethod + def _loss_func( + self, + pred: Float[torch.Tensor, "*batch seq_len #dim"], + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: ... + + +class PackedDistributionLoss(PackedLoss): + @abc.abstractmethod + def _loss_func( + self, + pred: Distribution, + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: ... diff --git a/OATS/models/tsfm/loss/packed/distribution.py b/OATS/models/tsfm/loss/packed/distribution.py new file mode 100644 index 0000000..98675a2 --- /dev/null +++ b/OATS/models/tsfm/loss/packed/distribution.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution + +from ._base import PackedDistributionLoss + + +class PackedNLLLoss(PackedDistributionLoss): + def _loss_func( + self, + pred: Distribution, + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return -pred.log_prob(target) diff --git a/OATS/models/tsfm/loss/packed/normalized.py b/OATS/models/tsfm/loss/packed/normalized.py new file mode 100644 index 0000000..d2517e2 --- /dev/null +++ b/OATS/models/tsfm/loss/packed/normalized.py @@ -0,0 +1,269 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
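A quick usage sketch of the packed-loss interface defined in `_base.py` and `distribution.py` above: `PackedNLLLoss` consumes a `torch.distributions.Distribution` plus the packing masks and reduces to a scalar. Shapes and values here are illustrative only:

```python
import torch
from torch.distributions import Normal

from tsfm.loss.packed import PackedNLLLoss

# Hypothetical packed batch: 2 sequences, 6 tokens, patch width 4.
target = torch.randn(2, 6, 4)
pred = Normal(loc=torch.zeros(2, 6, 4), scale=torch.ones(2, 6, 4))

# The last two tokens of each sequence are prediction positions; the
# observed/sample/variate arguments default to one fully observed series.
prediction_mask = torch.tensor([[False, False, False, False, True, True]] * 2)

loss = PackedNLLLoss()(pred, target, prediction_mask)
print(loss)  # scalar: NLL averaged per (sample_id, variate_id) group
```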
+import abc +from enum import Enum +from typing import Callable, Optional + +import torch +from einops import reduce +from jaxtyping import Bool, Float, Int + +from tsfm.common.core import abstract_class_property +from tsfm.common.torch_util import safe_div + +from ._base import PackedPointLoss + + +class PointNormType(Enum): + NONE = "none" + ABS_TARGET = "absolute_target" # normalize by mean abs_target for each obs + ABS_TARGET_SQ = "absolute_target_squared" # matfact def of NRMSE/ND + TARGET = "target" # normalize by mean target for each obs + TARGET_SQ = "target_squared" # classical def of NRMSE/NMAE + STD_DEV = ( + "standard_deviation" # normalize by standard deviation of target for each obs + ) + VAR = "variance" # normalize by variance of target for each obs + # MAX_MIN = "max_min" + # IQR = "interquartile_range" + + +@abstract_class_property("error_func") +class PackedPointNormalizedLoss(PackedPointLoss, abc.ABC): + error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = NotImplemented + + def __init__( + self, + normalize: PointNormType = PointNormType.NONE, + correction: int = 1, + epsilon: float = 1e-5, + ): + super().__init__() + self.normalize = PointNormType(normalize) + self.correction = correction + self.epsilon = epsilon + + def _loss_func( + self, + pred: Float[torch.Tensor, "*batch seq_len #dim"], + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + loss = self.error_func(pred, target) + denominator = self.denominator_func( + target, observed_mask, sample_id, variate_id + ) + loss = safe_div(loss, denominator) + return loss + + @property + def denominator_func(self) -> Callable: + func_map = { + PointNormType.NONE: self.none_denominator, + PointNormType.ABS_TARGET: self.abs_target_denominator, + PointNormType.ABS_TARGET_SQ: self.abs_target_sq_denominator, + PointNormType.TARGET: self.target_denominator, + PointNormType.TARGET_SQ: self.target_sq_denominator, + PointNormType.STD_DEV: self.std_dev_denominator, + PointNormType.VAR: self.var_denominator, + } + if self.normalize not in func_map: + raise ValueError(f"Invalid normalize type '{self.normalize}'") + return func_map[self.normalize] + + @staticmethod + def none_denominator( + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return torch.ones_like(target) + + @staticmethod + def reduce_denominator( + value: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + tobs = reduce( + id_mask * reduce(observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + value = reduce( + id_mask * reduce(value * observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... 
seq1 1", + "sum", + ) + value = safe_div(value, tobs) + return value + + def abs_target_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return self.reduce_denominator( + target.abs(), observed_mask, sample_id, variate_id + ) + + def abs_target_sq_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return torch.pow( + self.reduce_denominator(target.abs(), observed_mask, sample_id, variate_id), + 2, + ) + + def target_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return self.reduce_denominator(target, observed_mask, sample_id, variate_id) + + def target_sq_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + return torch.pow( + self.reduce_denominator(target, observed_mask, sample_id, variate_id), 2 + ) + + def std_dev_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + var = self.var_denominator(target, observed_mask, sample_id, variate_id) + std_dev = torch.sqrt(var + self.epsilon) + return std_dev + + def var_denominator( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + tobs = reduce( + id_mask * reduce(observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loc = reduce( + id_mask * reduce(target * observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loc = safe_div(loc, tobs) + var = reduce( + id_mask + * reduce( + ((target - loc) ** 2) * observed_mask, "... seq dim -> ... 1 seq", "sum" + ), + "... seq1 seq2 -> ... 
seq1 1", + "sum", + ) + var = safe_div(var, (tobs - self.correction)) + return var + + +class PackedNMAELoss(PackedPointNormalizedLoss): + error_func = torch.nn.L1Loss(reduction="none") + + +class PackedNMSELoss(PackedPointNormalizedLoss): + error_func = torch.nn.MSELoss(reduction="none") + + +class PackedNRMSELoss(PackedPointNormalizedLoss): + error_func = torch.nn.MSELoss(reduction="none") + + def reduce_loss( + self, + loss: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Optional[Bool[torch.Tensor, "*batch seq_len"]], + observed_mask: Optional[Bool[torch.Tensor, "*batch seq_len #dim"]], + sample_id: Optional[Int[torch.Tensor, "*batch seq_len"]], + variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]], + ) -> Float[torch.Tensor, ""]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + mask = prediction_mask.unsqueeze(-1) * observed_mask + loss = reduce( + id_mask + * reduce( + loss * mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loss = torch.sqrt(loss + self.epsilon) + tobs = reduce( + id_mask + * reduce( + mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loss = safe_div(loss, torch.sqrt(tobs)) + + return super().reduce_loss( + loss, prediction_mask, observed_mask, sample_id, variate_id + ) + + +class PackedNMLSELoss(PackedPointNormalizedLoss): + error_func = torch.nn.MSELoss(reduction="none") + + def __init__(self, df: float = 1.0): + super().__init__(PointNormType.VAR) + self.df = df + + def _loss_func( + self, + pred: Float[torch.Tensor, "*batch seq_len #dim"], + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + loss = super()._loss_func( + pred, target, prediction_mask, observed_mask, sample_id, variate_id + ) + return torch.log1p(loss / self.df) diff --git a/OATS/models/tsfm/loss/packed/percentage_error.py b/OATS/models/tsfm/loss/packed/percentage_error.py new file mode 100644 index 0000000..be7f00e --- /dev/null +++ b/OATS/models/tsfm/loss/packed/percentage_error.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
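To make `PointNormType` concrete, a small hand-computed sketch (assuming the interfaces above): with a constant target of 1 and a constant prediction of 2, the squared error is 1 everywhere, and the `ABS_TARGET` denominator, the per-series mean of |target|, is also 1, so the normalized loss comes out to exactly 1:

```python
import torch

from tsfm.loss.packed import PackedNMSELoss, PointNormType

loss_fn = PackedNMSELoss(normalize=PointNormType.ABS_TARGET)

pred = torch.full((1, 4, 2), 2.0)  # batch=1, 4 tokens, 2 dims
target = torch.ones(1, 4, 2)
prediction_mask = torch.ones(1, 4, dtype=torch.bool)

# Squared error (2 - 1)^2 = 1; mean |target| per series = 1.
print(loss_fn(pred, target, prediction_mask))  # tensor(1.)
```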
+import torch +from jaxtyping import Bool, Float, Int +from torch.nn import functional as F + +from tsfm.common.torch_util import safe_div + +from ._base import PackedPointLoss + + +class PackedMAPELoss(PackedPointLoss): + def _loss_func( + self, + pred: Float[torch.Tensor, "*batch seq_len #dim"], + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + loss = F.l1_loss(pred, target, reduction="none") + loss = safe_div(loss, target.abs()) + return 100 * loss + + +class PackedSMAPELoss(PackedPointLoss): + def _loss_func( + self, + pred: Float[torch.Tensor, "*batch seq_len #dim"], + target: Float[torch.Tensor, "*batch seq_len #dim"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> Float[torch.Tensor, "*batch seq_len #dim"]: + loss = F.l1_loss(pred, target, reduction="none") + loss = safe_div(loss, target.abs() + pred.detach().abs()) + return 200 * loss diff --git a/OATS/models/tsfm/loss/packed/point.py b/OATS/models/tsfm/loss/packed/point.py new file mode 100644 index 0000000..fa5413d --- /dev/null +++ b/OATS/models/tsfm/loss/packed/point.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from .normalized import PackedNMAELoss, PackedNMSELoss, PackedNRMSELoss, PointNormType + + +class PackedMAELoss(PackedNMAELoss): + def __init__(self): + super().__init__(normalize=PointNormType.NONE) + + +class PackedMSELoss(PackedNMSELoss): + def __init__(self): + super().__init__(normalize=PointNormType.NONE) + + +class PackedRMSELoss(PackedNRMSELoss): + def __init__(self): + super().__init__(normalize=PointNormType.NONE) diff --git a/OATS/models/tsfm/model/encoder/__init__.py b/OATS/models/tsfm/model/encoder/__init__.py new file mode 100644 index 0000000..c95d31b --- /dev/null +++ b/OATS/models/tsfm/model/encoder/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from .pretrain import TransformerEncoderPretrain +from .forecast import TransformerEncoderForecast + +__all__ = ["TransformerEncoderPretrain", "TransformerEncoderForecast"] \ No newline at end of file diff --git a/OATS/models/tsfm/model/encoder/forecast.py b/OATS/models/tsfm/model/encoder/forecast.py new file mode 100644 index 0000000..aece8e8 --- /dev/null +++ b/OATS/models/tsfm/model/encoder/forecast.py @@ -0,0 +1,883 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
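A quick sanity check for the percentage losses above (hand-picked numbers): `PackedSMAPELoss` computes 200 · |pred − target| / (|target| + |pred|) per element, so predicting 3 against a target of 1 gives exactly 100:

```python
import torch

from tsfm.loss.packed import PackedSMAPELoss

pred = torch.full((1, 2, 1), 3.0)
target = torch.ones(1, 2, 1)
prediction_mask = torch.ones(1, 2, dtype=torch.bool)

# 200 * |3 - 1| / (|1| + |3|) = 100 for every element.
print(PackedSMAPELoss()(pred, target, prediction_mask))  # tensor(100.)
```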
+import math +from contextlib import contextmanager +from copy import deepcopy +from typing import Any, Generator, Optional + +import lightning as L +import numpy as np +import torch +from einops import rearrange, reduce, repeat +from gluonts.model import Input, InputSpec +from gluonts.torch import PyTorchPredictor +from gluonts.transform import ( + AddObservedValuesIndicator, + AsNumpyArray, + ExpandDimArray, + TestSplitSampler, + Transformation, +) +from gluonts.transform.split import TFTInstanceSplitter +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution + +from tsfm.common.torch_util import safe_div +from tsfm.loss.packed import PackedNLLLoss as _PackedNLLLoss + +from .module import BasicModule + + +class SampleNLLLoss(_PackedNLLLoss): + def reduce_loss( + self, + loss: Float[torch.Tensor, "batch seq_len #dim"], + prediction_mask: Optional[Bool[torch.Tensor, "batch seq_len"]], + observed_mask: Optional[Bool[torch.Tensor, "batch seq_len #dim"]], + sample_id: Optional[Int[torch.Tensor, "batch seq_len"]], + variate_id: Optional[Int[torch.Tensor, "batch seq_len"]], + ) -> Float[torch.Tensor, "batch"]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + mask = prediction_mask.unsqueeze(-1) * observed_mask + tobs = reduce( + id_mask + * reduce( + mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loss = safe_div(loss, tobs) + return (loss * mask).sum(dim=(-1, -2)) + + +class TransformerEncoderForecast(L.LightningModule): + def __init__( + self, + prediction_length: int, + target_dim: int, + feat_dynamic_real_dim: int, + past_feat_dynamic_real_dim: int, + context_length: int, + module_kwargs: Optional[dict[str, Any]] = None, + module: Optional[BasicModule] = None, + patch_size: int | str = "auto", + num_samples: int = 100, + ): + assert (module is not None) or ( + module_kwargs is not None + ), "if module is not provided, module_kwargs is required" + super().__init__() + self.save_hyperparameters(ignore=["module"]) + self.module = BasicModule(**module_kwargs) if module is None else module + self.per_sample_loss_func = SampleNLLLoss() + + @contextmanager + def hparams_context( + self, + prediction_length: Optional[int] = None, + target_dim: Optional[int] = None, + feat_dynamic_real_dim: Optional[int] = None, + past_feat_dynamic_real_dim: Optional[int] = None, + context_length: Optional[int] = None, + patch_size: Optional[int | str] = None, + num_samples: Optional[int] = None, + ) -> Generator["TransformerEncoderForecast", None, None]: + kwargs = { + "prediction_length": prediction_length, + "target_dim": target_dim, + "feat_dynamic_real_dim": feat_dynamic_real_dim, + "past_feat_dynamic_real_dim": past_feat_dynamic_real_dim, + "context_length": context_length, + "patch_size": patch_size, + "num_samples": num_samples, + } + old_hparams = deepcopy(self.hparams) + for kw, arg in kwargs.items(): + if arg is not None: + self.hparams[kw] = arg + + yield self + + for kw in kwargs: + self.hparams[kw] = old_hparams[kw] + + def create_predictor( + self, + batch_size: int, + device: str = "auto", + ) -> PyTorchPredictor: + ts_fields = [] + if self.hparams.feat_dynamic_real_dim > 0: + ts_fields.append("feat_dynamic_real") + ts_fields.append("observed_feat_dynamic_real") + past_ts_fields = [] + if self.hparams.past_feat_dynamic_real_dim > 0: + past_ts_fields.append("past_feat_dynamic_real") + 
past_ts_fields.append("past_observed_feat_dynamic_real") + instance_splitter = TFTInstanceSplitter( + instance_sampler=TestSplitSampler(), + past_length=self.past_length, + future_length=self.hparams.prediction_length, + observed_value_field="observed_target", + time_series_fields=ts_fields, + past_time_series_fields=past_ts_fields, + ) + return PyTorchPredictor( + input_names=self.prediction_input_names, + prediction_net=self, + batch_size=batch_size, + prediction_length=self.hparams.prediction_length, + input_transform=self.get_default_transform() + instance_splitter, + device=device, + ) + + def describe_inputs(self, batch_size: int = 1) -> InputSpec: + data = { + "past_target": Input( + shape=( + batch_size, + self.past_length, + self.hparams.target_dim, + ), + dtype=torch.float, + ), + "past_observed_target": Input( + shape=( + batch_size, + self.past_length, + self.hparams.target_dim, + ), + dtype=torch.bool, + ), + "past_is_pad": Input( + shape=(batch_size, self.past_length), + dtype=torch.bool, + ), + } + if self.hparams.feat_dynamic_real_dim > 0: + data["feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length + self.hparams.prediction_length, + self.hparams.feat_dynamic_real_dim, + ), + dtype=torch.float, + ) + data["observed_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length + self.hparams.prediction_length, + self.hparams.feat_dynamic_real_dim, + ), + dtype=torch.bool, + ) + if self.hparams.past_feat_dynamic_real_dim > 0: + data["past_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length, + self.hparams.past_feat_dynamic_real_dim, + ), + dtype=torch.float, + ) + data["past_observed_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length, + self.hparams.past_feat_dynamic_real_dim, + ), + dtype=torch.bool, + ) + return InputSpec(data=data, zeros_fn=torch.zeros) + + @property + def prediction_input_names(self) -> list[str]: + return list(self.describe_inputs()) + + @property + def training_input_names(self): + return self.prediction_input_names + ["future_target", "future_observed_values"] + + @property + def past_length(self) -> int: + return ( + self.hparams.context_length + self.hparams.prediction_length + if self.hparams.patch_size == "auto" + else self.hparams.context_length + ) + + def context_token_length(self, patch_size: int) -> int: + return math.ceil(self.hparams.context_length / patch_size) + + def prediction_token_length(self, patch_size) -> int: + return math.ceil(self.hparams.prediction_length / patch_size) + + @property + def max_patch_size(self) -> int: + return self.module.patch_size + + def forward( + self, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + num_samples: Optional[int] = None, + ) -> Float[torch.Tensor, "batch sample future_time *tgt"]: + distr = self._get_distr( + self.hparams.patch_size, + past_target, + past_observed_target, + past_is_pad, + feat_dynamic_real, + observed_feat_dynamic_real, + past_feat_dynamic_real, + past_observed_feat_dynamic_real, + ) + preds = 
distr.sample(torch.Size((num_samples or self.hparams.num_samples,))) + return self._format_preds( + self.hparams.patch_size, preds, past_target.shape[-1] + ) + + def _val_loss( + self, + patch_size: int, + target: Float[torch.Tensor, "batch time tgt"], + observed_target: Bool[torch.Tensor, "batch time tgt"], + is_pad: Bool[torch.Tensor, "batch time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> Float[torch.Tensor, "batch"]: + # convert format + ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) = self._convert( + patch_size, + past_target=target[..., : self.hparams.context_length, :], + past_observed_target=observed_target[..., : self.hparams.context_length, :], + past_is_pad=is_pad[..., : self.hparams.context_length], + future_target=target[..., self.hparams.context_length :, :], + future_observed_target=observed_target[ + ..., self.hparams.context_length :, : + ], + future_is_pad=is_pad[..., self.hparams.context_length :], + feat_dynamic_real=feat_dynamic_real, + observed_feat_dynamic_real=observed_feat_dynamic_real, + past_feat_dynamic_real=past_feat_dynamic_real, + past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, + ) + # get predictions + distr = self.module( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + torch.ones_like(time_id, dtype=torch.long) * patch_size, + ) + val_loss = self.per_sample_loss_func( + pred=distr, + target=target, + prediction_mask=prediction_mask, + observed_mask=observed_mask, + sample_id=sample_id, + variate_id=variate_id, + ) + return val_loss + + def _get_distr( + self, + patch_size: int, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> Distribution: + # convert format + ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) = self._convert( + patch_size, + past_target, + past_observed_target, + past_is_pad, + feat_dynamic_real=feat_dynamic_real, + observed_feat_dynamic_real=observed_feat_dynamic_real, + past_feat_dynamic_real=past_feat_dynamic_real, + past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, + ) + + # get predictions + distr = self.module( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + torch.ones_like(time_id, dtype=torch.long) * patch_size, + ) + return distr + + @staticmethod + def _patched_seq_pad( + patch_size: int, + x: torch.Tensor, + dim: int, + left: bool = True, + value: Optional[float] = None, + ) -> torch.Tensor: + if dim >= 0: + dim = -x.ndim + dim + pad_length = -x.size(dim) % patch_size + if left: + pad = (pad_length, 0) + else: + pad = (0, pad_length) + pad = (0, 0) * (abs(dim) - 1) + pad + return 
torch.nn.functional.pad(x, pad, value=value) + + def _generate_time_id( + self, + patch_size: int, + past_observed_target: Bool[torch.Tensor, "batch past_seq tgt"], + ) -> tuple[ + Int[torch.Tensor, "batch past_token"], Int[torch.Tensor, "batch future_token"] + ]: + past_seq_id = reduce( + self._patched_seq_pad(patch_size, past_observed_target, -2, left=True), + "... (seq patch) dim -> ... seq", + "max", + patch=patch_size, + ) + past_seq_id = torch.clamp(past_seq_id.cumsum(dim=-1) - 1, min=0) + batch_shape = " ".join(map(str, past_observed_target.shape[:-2])) + future_seq_id = ( + repeat( + torch.arange( + self.prediction_token_length(patch_size), + device=past_observed_target.device, + ), + f"prediction -> {batch_shape} prediction", + ) + + past_seq_id.max(dim=-1, keepdim=True).values + + 1 + ) + return past_seq_id, future_seq_id + + def _convert( + self, + patch_size: int, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + future_target: Optional[Float[torch.Tensor, "batch future_time tgt"]] = None, + future_observed_target: Optional[ + Bool[torch.Tensor, "batch future_time tgt"] + ] = None, + future_is_pad: Optional[Bool[torch.Tensor, "batch future_time"]] = None, + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> tuple[ + Float[torch.Tensor, "batch combine_seq patch"], # target + Bool[torch.Tensor, "batch combine_seq patch"], # observed_mask + Int[torch.Tensor, "batch combine_seq"], # sample_id + Int[torch.Tensor, "batch combine_seq"], # time_id + Int[torch.Tensor, "batch combine_seq"], # variate_id + Bool[torch.Tensor, "batch combine_seq"], # prediction_mask + ]: + batch_shape = past_target.shape[:-2] + device = past_target.device + + target = [] + observed_mask = [] + sample_id = [] + time_id = [] + variate_id = [] + prediction_mask = [] + dim_count = 0 + + past_seq_id, future_seq_id = self._generate_time_id( + patch_size, past_observed_target + ) + + if future_target is None: + future_target = torch.zeros( + batch_shape + + ( + self.hparams.prediction_length, + past_target.shape[-1], + ), + dtype=past_target.dtype, + device=device, + ) + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad(patch_size, past_target, -2, left=True), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + if future_observed_target is None: + future_observed_target = torch.ones( + batch_shape + + ( + self.hparams.prediction_length, + past_observed_target.shape[-1], + ), + dtype=torch.bool, + device=device, + ) + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_target, -2, left=True + ), + "... (seq patch) dim -> ... 
(dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_observed_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + if future_is_pad is None: + future_is_pad = torch.zeros( + batch_shape + (self.hparams.prediction_length,), + dtype=torch.long, + device=device, + ) + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_target.shape[-1], + ), + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, future_is_pad, -1, left=False, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_target.shape[-1], + ), + ] + ) + time_id.extend( + [past_seq_id] * past_target.shape[-1] + + [future_seq_id] * past_target.shape[-1] + ) + variate_id.extend( + [ + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += past_target.shape[-1] + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + (self.context_token_length(patch_size) * past_target.shape[-1],), + dtype=torch.bool, + device=device, + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * past_target.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + if feat_dynamic_real is not None: + if observed_feat_dynamic_real is None: + raise ValueError( + "observed_feat_dynamic_real must be provided if feat_dynamic_real is provided" + ) + + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[ + ..., : self.hparams.context_length, : + ], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[ + ..., self.hparams.context_length :, : + ], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., : self.hparams.context_length, : + ], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., self.hparams.context_length :, : + ], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... 
(dim seq)", + dim=feat_dynamic_real.shape[-1], + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.long, + device=device, + ), + ] + ) + time_id.extend( + [past_seq_id] * feat_dynamic_real.shape[-1] + + [future_seq_id] * feat_dynamic_real.shape[-1] + ) + variate_id.extend( + [ + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += feat_dynamic_real.shape[-1] + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + torch.zeros( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + if past_feat_dynamic_real is not None: + if past_observed_feat_dynamic_real is None: + raise ValueError( + "past_observed_feat_dynamic_real must be provided if past_feat_dynamic_real is provided" + ) + target.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ) + ) + observed_mask.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ) + ) + sample_id.append( + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_feat_dynamic_real.shape[-1], + ) + ) + time_id.extend([past_seq_id] * past_feat_dynamic_real.shape[-1]) + + variate_id.append( + repeat( + torch.arange(past_feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ) + ) + dim_count += past_feat_dynamic_real.shape[-1] + prediction_mask.append( + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * past_feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ) + ) + + target = torch.cat(target, dim=-2) + observed_mask = torch.cat(observed_mask, dim=-2) + sample_id = torch.cat(sample_id, dim=-1) + time_id = torch.cat(time_id, dim=-1) + variate_id = torch.cat(variate_id, dim=-1) + prediction_mask = torch.cat(prediction_mask, dim=-1) + return ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) + + def _format_preds( + self, + patch_size: int, + preds: Float[torch.Tensor, "sample batch combine_seq patch"], + target_dim: int, + ) -> Float[torch.Tensor, "batch sample future_time *tgt"]: + start = target_dim * self.context_token_length(patch_size) + end = start + target_dim * self.prediction_token_length(patch_size) + preds = preds[..., start:end, :patch_size] + preds = rearrange( + preds, + "sample ... (dim seq) patch -> ... 
sample (seq patch) dim", + dim=target_dim, + )[..., : self.hparams.prediction_length, :] + return preds.squeeze(-1) + + def get_default_transform(self) -> Transformation: + transform = AsNumpyArray( + field="target", + expected_ndim=1 if self.hparams.target_dim == 1 else 2, + dtype=np.float32, + ) + if self.hparams.target_dim == 1: + transform += ExpandDimArray(field="target", axis=0) + transform += AddObservedValuesIndicator( + target_field="target", + output_field="observed_target", + dtype=bool, + ) + + if self.hparams.feat_dynamic_real_dim > 0: + transform += AsNumpyArray( + field="feat_dynamic_real", + expected_ndim=2, + dtype=np.float32, + ) + transform += AddObservedValuesIndicator( + target_field="feat_dynamic_real", + output_field="observed_feat_dynamic_real", + dtype=bool, + ) + + if self.hparams.past_feat_dynamic_real_dim > 0: + transform += AsNumpyArray( + field="past_feat_dynamic_real", + expected_ndim=2, + dtype=np.float32, + ) + transform += AddObservedValuesIndicator( + target_field="past_feat_dynamic_real", + output_field="past_observed_feat_dynamic_real", + dtype=bool, + ) + return transform diff --git a/OATS/models/tsfm/model/encoder/module.py b/OATS/models/tsfm/model/encoder/module.py new file mode 100644 index 0000000..9b2c9c4 --- /dev/null +++ b/OATS/models/tsfm/model/encoder/module.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from functools import partial + +import torch +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin +from hydra.utils import instantiate +from jaxtyping import Bool, Float, Int +from torch import nn +from torch.distributions import Distribution +from torch.utils._pytree import tree_map + +from tsfm.common.torch_util import mask_fill, packed_attention_mask +from tsfm.distribution import DistributionOutput +from tsfm.module.norm import RMSNorm +from tsfm.module.packed_scaler import PackedNOPScaler, PackedStdScaler +from tsfm.module.position import ( + QueryKeyProjection, + RotaryProjection, +) +from tsfm.module.transformer import TransformerEncoder + + +def encode_distr_output( + distr_output: DistributionOutput, +) -> dict[str, str | float | int]: + def _encode(val): + if not isinstance(val, DistributionOutput): + return val + + return { + "_target_": f"{val.__class__.__module__}.{val.__class__.__name__}", + **tree_map(_encode, val.__dict__), + } + + return _encode(distr_output) + + +def decode_distr_output(config: dict[str, str | float | int]) -> DistributionOutput: + return instantiate(config, _convert_="all") + + +class BasicModule( + nn.Module, + PyTorchModelHubMixin, + coders={DistributionOutput: (encode_distr_output, decode_distr_output)}, +): + """Contains components of Moirai to ensure implementation is identical across models""" + + def __init__( + self, + distr_output: DistributionOutput, + d_model: int, + num_layers: int, + patch_size: int, # tuple[int, ...] 
| list[int] + max_seq_len: int, + attn_dropout_p: float, + dropout_p: float, + num_heads: int = None, + scaling: bool = True, + ): + super().__init__() + self.d_model = d_model + self.num_layers = num_layers + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.scaling = scaling + + self.mask_encoding = nn.Embedding(num_embeddings=1, embedding_dim=d_model) + self.scaler = PackedStdScaler() if scaling else PackedNOPScaler() + self.in_proj = nn.Linear(patch_size, d_model) + self.encoder = TransformerEncoder( + d_model, + num_layers, + num_heads=num_heads, + pre_norm=True, + attn_dropout_p=attn_dropout_p, + dropout_p=dropout_p, + norm_layer=RMSNorm, + activation=F.silu, + use_glu=True, + use_qk_norm=True, + var_attn_bias_layer=None, + time_qk_proj_layer=partial( + QueryKeyProjection, + proj_layer=RotaryProjection, + kwargs=dict(max_len=max_seq_len), + partial_factor=(0.0, 0.5), + ), + shared_var_attn_bias=False, + shared_time_qk_proj=True, + d_ff=None, + ) + self.distr_output = distr_output + self.param_proj = self.distr_output.get_param_proj(d_model, patch_size) + + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len max_patch"], + observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + patch_size: Int[torch.Tensor, "*batch seq_len"], + ) -> Distribution: + loc, scale = self.scaler( + target, + observed_mask * ~prediction_mask.unsqueeze(-1), + sample_id, + variate_id, + ) + scaled_target = (target - loc) / scale + reprs = self.in_proj(scaled_target) + masked_reprs = mask_fill(reprs, prediction_mask, self.mask_encoding.weight) + + atten_mask = packed_attention_mask(sample_id) + + reprs = self.encoder( + masked_reprs, + atten_mask, + time_id=time_id, + var_id=variate_id, + ) + + distr_param = self.param_proj(reprs) + distr = self.distr_output.distribution(distr_param, loc=loc, scale=scale) + return distr + \ No newline at end of file diff --git a/OATS/models/tsfm/model/encoder/pretrain.py b/OATS/models/tsfm/model/encoder/pretrain.py new file mode 100644 index 0000000..ab607fe --- /dev/null +++ b/OATS/models/tsfm/model/encoder/pretrain.py @@ -0,0 +1,1953 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
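+#
+# Overview of this file (summary of the code below): the pretraining
+# LightningModule computes per-sample training gradients and scores them
+# against validation gradients (dot product or cosine similarity), drops
+# low-scoring batch positions each step, and can replace them with samples
+# from a conditional generator guided by crops of the current batch.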
+
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from typing import Any, Optional
+
+import lightning as L
+import numpy as np
+import torch
+from jaxtyping import Bool, Float, Int
+from torch import nn
+from torch.distributions import Distribution
+import math
+import time
+import random
+
+from tsfm.loss.packed import (
+    PackedDistributionLoss,
+    PackedLoss,
+    PackedNLLLoss,
+)
+from tsfm.module.norm import RMSNorm
+from tsfm.module.position import (
+    LearnedEmbedding,
+    LearnedProjection,
+)
+from tsfm.optim import SchedulerType, get_scheduler
+from tsfm.transform import (
+    AddObservedMask,
+    AddTimeIndex,
+    AddVariateIndex,
+    EvalCrop_AdaLength,
+    EvalPad_AdaLength,
+    EvalMaskedPrediction,
+    DummyValueImputation,
+    ExtendMask,
+    FlatPackCollection,
+    FlatPackFields,
+    GetPatchSize,
+    ImputeTimeSeries,
+    MaskedPrediction,
+    PackFields,
+    PatchCrop,
+    Patchify,
+    SampleDimension,
+    SelectFields,
+    SequencifyField,
+    Transformation,
+)
+
+from .module import BasicModule
+from tsfm.val.metrics import (
+    MSE_mean,
+    MAE_mean,
+    MSE_median,
+    MAE_median,
+    MASE,
+    MAPE,
+    SMAPE,
+    RMSE,
+    NRMSE,
+    ND,
+    CRPS
+)
+
+
+class TransformerEncoderPretrain(L.LightningModule):
+    seq_fields: tuple[str, ...] = (
+        "target",
+        "observed_mask",
+        "time_id",
+        "variate_id",
+        "prediction_mask",
+        "patch_size",
+        "label",
+        "label_observed_mask",
+    )
+    train_seq_fields: tuple[str, ...] = (
+        "target",
+        "observed_mask",
+        "time_id",
+        "variate_id",
+        "prediction_mask",
+        "patch_size",
+    )
+    pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] = {
+        "target": np.zeros,
+        "observed_mask": np.zeros,
+        "time_id": np.zeros,
+        "variate_id": np.zeros,
+        "prediction_mask": np.zeros,
+        "patch_size": np.zeros,
+    }
+
+    def __init__(
+        self,
+        min_patches: int,
+        min_mask_ratio: float,
+        max_mask_ratio: float,
+        num_training_steps: int,
+        num_warmup_steps: int,
+        max_dim: int = 1,
+        module_kwargs: Optional[dict[str, Any]] = None,
+        module: Optional[BasicModule] = None,
+        num_samples: int = 100,
+        beta1: float = 0.9,
+        beta2: float = 0.98,
+        loss_func: PackedDistributionLoss = PackedNLLLoss(),
+        val_metric: Optional[PackedLoss | list[PackedLoss]] = [
+            MSE_mean(), MAE_mean(), MSE_median(), MAE_median(), MASE(), MAPE(), SMAPE(), RMSE(), NRMSE(), ND(), CRPS()
+        ],
+        lr: float = 1e-3,
+        weight_decay: float = 1e-2,
+        log_on_step: bool = False,
+        num_low_influence_to_remove: int = 16,
+        enable_influence_scoring: bool = False,
+        enable_dataset_contribution_logging: bool = False,
+        enable_reweighting: bool = False,
+        add_noise: bool = False,
+        influence_filter_ratio: float = 0.7,
+        use_cosine_similarity: bool = False,
+        select_from_generated: bool = False,
+        generate_after_epoch: int = 0,
+        mixup: bool = False,
+    ):
+        assert (module is not None) or (
+            module_kwargs is not None
+        ), "if module is not provided, module_kwargs is required"
+        assert (
+            num_warmup_steps <= num_training_steps
+        ), f"num_warmup_steps ({num_warmup_steps}) should be <= num_training_steps ({num_training_steps})."
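+        # Note on the explore/exploit knob: influence_filter_ratio is the
+        # per-step probability that influence-based filtering (and guided
+        # generation) runs in training_step; step 0 always uses it so the
+        # score history is seeded before any stochastic skipping.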
+        super().__init__()
+        self.save_hyperparameters(ignore=["module"])
+        self.module = BasicModule(**module_kwargs) if module is None else module
+        self.influence_scores = {}
+        self.recommended_weights = {}  # Initialize for recommended weights-based filtering
+        self.threshold = 4000
+        self.generation_model = None  # Placeholder for generation model
+        self.generation_data = None  # Placeholder for generation data
+        self.cache_val_batch = None  # This is used to cache the validation batch for TS influence scoring
+
+    def forward(
+        self,
+        target: Float[torch.Tensor, "*batch seq_len max_patch"],
+        observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"],
+        sample_id: Int[torch.Tensor, "*batch seq_len"],
+        time_id: Int[torch.Tensor, "*batch seq_len"],
+        variate_id: Int[torch.Tensor, "*batch seq_len"],
+        patch_size: Int[torch.Tensor, "*batch seq_len"],
+        prediction_mask: Int[torch.Tensor, "*batch seq_len"],
+    ) -> Distribution:
+        output = self.module(
+            target=target,
+            observed_mask=observed_mask,
+            sample_id=sample_id,
+            time_id=time_id,
+            variate_id=variate_id,
+            prediction_mask=prediction_mask,
+            patch_size=patch_size,
+        )
+        return output
+
+    def infer(
+        self,
+        target: Float[torch.Tensor, "*batch seq_len max_patch"],
+        observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"],
+        sample_id: Int[torch.Tensor, "*batch seq_len"],
+        time_id: Int[torch.Tensor, "*batch seq_len"],
+        variate_id: Int[torch.Tensor, "*batch seq_len"],
+        patch_size: Int[torch.Tensor, "*batch seq_len"],
+        prediction_mask: Int[torch.Tensor, "*batch seq_len"],
+    ) -> tuple[Distribution, torch.Tensor]:
+        distr = self.forward(
+            target=target,
+            observed_mask=observed_mask,
+            sample_id=sample_id,
+            time_id=time_id,
+            variate_id=variate_id,
+            patch_size=patch_size,
+            prediction_mask=prediction_mask,
+        )
+
+        preds = distr.sample(torch.Size((self.hparams.num_samples,)))  # sample batch time features
+        preds = preds.transpose(0, 1)  # batch sample time features
+        return distr, preds
+
+    def training_step(
+        self, batch: dict[str, torch.Tensor], batch_idx: int
+    ) -> torch.Tensor:
+        # Determine whether to use influence-based or random filtering
+        current_step = self.global_step
+        # use_influence_filtering = (current_step % self.hparams.influence_filter_frequency == 0) and self.hparams.enable_influence_scoring
+        if self.hparams.enable_influence_scoring:
+            # enable influence filtering with probability influence_filter_ratio
+            use_influence_filtering = random.random() < self.hparams.influence_filter_ratio
+            if self.global_step == 0:
+                use_influence_filtering = True
+        else:
+            use_influence_filtering = False
+
+        # Always apply some form of filtering (influence-based with probability influence_filter_ratio, random otherwise)
+        print(f"Step {current_step}: This batch originally has {len(batch['dataset_index'])} samples")
+
+        if use_influence_filtering:
+            print(f"Step {current_step}: Using influence-based filtering (p={self.hparams.influence_filter_ratio})")
+            batch, threshold = self._filter_low_influence_samples(
+                batch,
+                num_to_remove=self.hparams.num_low_influence_to_remove,
+                use_influence_scores=True
+            )
+            # threshold = 0 # tmp
+
+            if self.hparams.generate_after_epoch <= current_step:
+                generated_batch = self._generated_similar_samples(batch)
+                if self.hparams.select_from_generated:
+                    gen_grad = self._compute_per_sample_gradients_with_indices(generated_batch, generated_batch["_dataset_idx"])
+                    val_gradients = self.get_validation_gradients_from_trainer(max_val_samples=32)  # changed for case-study
+                    # calculate the influence score of the generated samples
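+                    # Each generated sample's gradient is scored against the
+                    # validation gradient; only generated samples whose score
+                    # exceeds `threshold` (the median influence score of the
+                    # real samples in this batch) are merged back in.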
+                    influence_scores_gen = self.compute_influence_scores_batched_on_generated(val_gradients, gen_grad)
+
+                    # Select only high-scoring generated samples (above threshold)
+                    keep_indices = [i for i, score in enumerate(influence_scores_gen) if score > threshold]
+                    print(f"Generated {len(influence_scores_gen)} samples, keeping {len(keep_indices)} samples above threshold {threshold}")
+
+                    # Filter generated batch to keep only high-scoring samples
+                    if keep_indices:
+                        filtered_generated_batch = {}
+                        for key, value in generated_batch.items():
+                            if torch.is_tensor(value) and value.shape[0] == len(influence_scores_gen):
+                                filtered_generated_batch[key] = value[keep_indices]
+                            else:
+                                filtered_generated_batch[key] = value
+                        batch = self._merge_batches(batch, filtered_generated_batch)
+                    else:
+                        print("No generated samples above threshold, skipping merge")
+                else:
+                    # Merge generated samples directly
+                    batch = self._merge_batches(batch, generated_batch)
+                    print(f"Step {current_step}: Merged generated samples, new batch size is {len(batch['dataset_index'])}")
+        else:
+            # Determine filtering strategy based on available recommended weights
+            if hasattr(self, 'recommended_weights') and self.recommended_weights:
+                print(f"Step {current_step}: Using recommended weights-based filtering")
+            else:
+                print(f"Step {current_step}: Using random filtering (no recommended weights available yet)")
+
+            batch, threshold = self._filter_low_influence_samples(
+                batch,
+                num_to_remove=self.hparams.num_low_influence_to_remove,
+                use_influence_scores=False
+            )
+
+        # add noise to the batch if self.hparams.add_noise is True
+        if self.hparams.add_noise:
+            batch = self._add_noise_to_batch(batch)
+
+        if self.hparams.mixup:
+            batch = self._mixup_batch(batch)
+
+        if not self.hparams.add_noise and not self.hparams.mixup:
+            generated_batch = self._generated_similar_samples(batch)
+            batch = self._merge_batches(batch, generated_batch)
+            print(f"Step {current_step}: Merged generated samples, new batch size is {len(batch['dataset_index'])}")
+
+        print(f"Step {current_step}: This batch after filtering has {len(batch['dataset_index'])} samples")
+
+        output = self(
+            **{field: batch[field] for field in list(self.train_seq_fields) + ["sample_id"]}
+        )
+        loss = self.hparams.loss_func(
+            pred=output,
+            target=batch["label"],
+            observed_mask=batch["label_observed_mask"],
+            prediction_mask=batch["prediction_mask"],
+            sample_id=batch["sample_id"],
+            variate_id=batch["variate_id"],
+        )
+        batch_size = (
+            batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None
+        )
+        self.log(
+            f"train/{self.hparams.loss_func.__class__.__name__}",
+            loss,
+            on_step=self.hparams.log_on_step,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+            batch_size=batch_size,
+            rank_zero_only=True,
+        )
+        return loss
+
+    def _merge_batches(self, original_batch, generated_batch):
+        """Merge the generated batch with the original batch"""
+        for key in original_batch.keys():
+            assert key in generated_batch, f"Key {key} not found in generated batch"
+            original_batch[key] = torch.cat((original_batch[key], generated_batch[key]), dim=0)
+        return original_batch
+
+    def _add_noise_to_batch(self, batch):
+        """Add Gaussian noise scaled to 20% of each sample's std to batch['label'] and batch['target']"""
+        # noise_std = 0.03
+
+        # noise = torch.randn_like(batch["label"]) * noise_std
+        bs = batch["label"].shape[0]
+        noise = torch.randn_like(batch["label"]) * torch.std(batch["label"], dim=(1,2)).view(bs, 1, 1) * 0.2
+
+        # Use the same noise tensor for both fields
+        if "label" in
batch: + batch["label"] = batch["label"] + noise + + if "target" in batch: + batch["target"] = batch["target"] + noise + + return batch + + def _mixup_batch(self, batch): + """Apply mixup augmentation to 50% of samples in the batch. + + For the replacing samples, change them to be 0.5*sample1 + 0.5*sample2. + sample1/2 are samples randomly picked within the batch. + For label and target, use weighted sum. For other values, use the sample + (from 1 and 2) with longer length. + """ + batch_size = len(batch['dataset_index']) + if batch_size < 2: + return batch # Need at least 2 samples for mixup + + # Select 50% of samples to replace + num_to_replace = batch_size // 2 + replace_indices = torch.randperm(batch_size)[:num_to_replace] + + for idx in replace_indices: + # Randomly pick two samples from the batch (excluding the current one) + sample_indices = torch.randperm(batch_size) + # Ensure we don't pick the same sample twice and not the current index + candidate_indices = [i for i in sample_indices if i != idx][:2] + if len(candidate_indices) < 2: + # If we can't find 2 different samples, use any 2 samples + candidate_indices = sample_indices[:2].tolist() + + idx1, idx2 = candidate_indices[0], candidate_indices[1] + + # For numerical fields (label, target), use weighted sum (0.5 + 0.5) + numerical_fields = ['label', 'target'] + for field in numerical_fields: + if field in batch: + batch[field][idx] = 0.5 * batch[field][idx1] + 0.5 * batch[field][idx2] + + # For other fields, choose the sample with longer sequence length + # We'll use the sample_id field to determine which sample has more data + other_fields = [k for k in batch.keys() if k not in numerical_fields + ['sample_id']] + + # Determine which sample has longer length by comparing sample_id max values + if 'sample_id' in batch: + len1 = batch['sample_id'][idx1].max().item() + len2 = batch['sample_id'][idx2].max().item() + longer_idx = idx1 if len1 >= len2 else idx2 + else: + # Fallback: just use the first sample + longer_idx = idx1 + + # Copy other fields from the sample with longer length + for field in other_fields: + if field in batch: + batch[field][idx] = batch[field][longer_idx].clone() + + return batch + + def _generated_similar_samples(self, batch, num_samples=16): + print("Generating similar samples for influence scoring...") + prompt_list = self._crop_sample_from_patches(batch) + + # load the model and datamodule for generation + import sys + from generation.generate_batch_preview import generate_conditional_batch, load_model + if self.generation_model is None or self.generation_data is None: + ckpt_path = "000060-0.0853.ckpt" + generation_model, self.generation_data = load_model(ckpt_path, seed=0) + object.__setattr__(self, 'generation_model', generation_model) + + # print("prompt list length:", len(prompt_list)) + generated_samples_list = [] + dataset_names_list = [] + subset_id_list = [] + new_prompt_list = [] + for idx in range(len(prompt_list)): + dataset_name = self._get_dataset_name_for_index(batch["_dataset_idx"][idx].item()) + print(dataset_name, + prompt_list[idx].shape if prompt_list is not None else None) + if dataset_name not in self.generation_data.key_list: + print(f"Dataset {dataset_name} not found in generation data, skipping...") + continue + dataset_names_list.append(dataset_name) + new_prompt_list.append(prompt_list[idx]) + + subset_id = self.generation_data.key_list.index(dataset_name) + subset_id_list.append(subset_id) + + generated_samples = generate_conditional_batch( + prompt=new_prompt_list, + 
subset_id=subset_id_list,
+            model=self.generation_model,
+            dataset_name=dataset_names_list,
+            data_module=self.generation_data,
+            num_samples=1,  # generation 1:1
+            ddim=True,
+            ddim_steps=20,
+            dataset_idx=(-batch["_dataset_idx"]).tolist(),
+        )
+
+        # move generated_samples to the same device as the model
+        device = self.device
+        for key, value in generated_samples.items():
+            if isinstance(value, torch.Tensor):
+                generated_samples[key] = value.to(device)
+
+        return generated_samples
+
+    def _crop_sample_from_patches(self, batch, target_length=320, patch_size=32):
+        """Crop a sample from the patches in the batch to match the target length"""
+
+        # calculate the number of patches to crop
+        num_patches = target_length // patch_size
+
+        label = batch["label"]  # [16, 512, 32]
+        batch_size, total_patches, patch_dim = label.shape
+
+        # check number of available patches using label_observed_mask
+        label_observed_mask = batch["label_observed_mask"]  # [16, 512, 32]
+
+        # Get the availability mask by checking if any element in the patch dimension is observed
+        # This assumes that if a patch is available, all elements in that patch should be 1
+        patch_available_mask = label_observed_mask.any(dim=-1)  # [16, 512] - True where patches are available
+
+        # random sample consecutive patches from the label with length=num_patches
+        if num_patches > 0:
+            cropped_samples_list = []
+
+            for batch_idx in range(batch_size):
+                # Find available patches for this batch item
+                available_patches = patch_available_mask[batch_idx]  # [512]
+                num_available = available_patches.sum().item()
+
+                if num_available >= num_patches:
+                    # Find the range of available patches (assuming they are consecutive from start)
+                    available_indices = torch.where(available_patches)[0]  # Get indices where patches are available
+
+                    if len(available_indices) >= num_patches:
+                        # Check if we can find consecutive patches
+                        max_start_for_consecutive = available_indices[-num_patches].item() if len(available_indices) >= num_patches else 0
+                        min_start = available_indices[0].item()
+
+                        # Randomly select a starting position that ensures num_patches consecutive available patches
+                        if max_start_for_consecutive >= min_start:
+                            start_idx = torch.randint(min_start, max_start_for_consecutive + 1, (1,), device=label.device).item()
+                        else:
+                            start_idx = min_start
+
+                        # Extract consecutive patches
+                        end_idx = start_idx + num_patches
+                        patch_indices = torch.arange(start_idx, end_idx, device=label.device)
+
+                        # Get the patches for this batch item
+                        selected_patches = label[batch_idx, patch_indices]  # [num_patches, patch_dim]
+                        cropped_sample = selected_patches.reshape(-1)  # [target_length]
+                    else:
+                        # Not enough available patches, skip this sample (None)
+                        cropped_sample = None
+                else:
+                    # Not enough available patches, skip this sample (None)
+                    cropped_sample = None
+
+                cropped_samples_list.append(cropped_sample)
+
+            # Stack all batch items
+            # cropped_samples = torch.stack(cropped_samples_list, dim=0)  # [batch_size, target_length]
+
+            # remove all the None in the cropped_samples_list
+            cropped_samples_list = [sample for sample in cropped_samples_list if sample is not None]
+            return cropped_samples_list
+        else:
+            # If num_patches is 0, there is nothing to crop
+            return None
+
+    def on_train_batch_start(self, batch, batch_idx):
+        """Calculate per-example gradients for influence function computation and filter low-influence samples"""
+
+        # Only compute per-example gradients during training
+        if not self.training:
+            return
+
+        # Only compute influence scores when influence-based filtering will be used
+        current_step = self.global_step
+        # use_influence_filtering = (current_step % self.hparams.influence_filter_frequency == 0)
+        if self.hparams.enable_influence_scoring:
+            # enable influence filtering with probability influence_filter_ratio
+            use_influence_filtering = random.random() < self.hparams.influence_filter_ratio
+            if self.global_step == 0:
+                use_influence_filtering = True
+        else:
+            use_influence_filtering = False
+
+        if not use_influence_filtering and not self.hparams.enable_dataset_contribution_logging:
+            print(f"Step {current_step}: Skipping influence computation (random filtering step)")
+            return
+
+        # Skip influence scoring if disabled
+        if not self.hparams.enable_influence_scoring:
+            print(f"Step {current_step}: Influence scoring disabled - skipping gradient computation")
+            return
+
+        # Get original dataset indices from the new field
+        dataset_indices = batch.get("dataset_index", None)
+        if dataset_indices is None:
+            print(f"Warning: No dataset_index found in batch {batch_idx}")
+            print("Make sure to use PadCollateWithDatasetIndex and TimeSeriesDatasetWithIndex")
+            return
+
+        # Get unique dataset indices (filter out padding with -1)
+        unique_indices = dataset_indices.flatten().unique()
+        unique_indices = unique_indices[unique_indices >= 0]  # Remove padding (-1)
+
+        # print(f"Unique dataset indices in batch {batch_idx}: {unique_indices.tolist()}")
+        print(f"Unique dataset indices in batch {batch_idx} length: {len(unique_indices)}")
+
+        # Initialize per-example gradient storage if not exists
+        if not hasattr(self, 'per_example_gradients'):
+            self.per_example_gradients = {}
+
+        # Initialize influence score history if not exists
+        if not hasattr(self, 'influence_score_history'):
+            self.influence_score_history = {}
+
+        # Calculate per-sample gradients using original dataset indices
+        per_sample_grads = self._compute_per_sample_gradients_with_indices(batch, unique_indices)
+
+        # Store gradients indexed by original dataset index
+        for i, dataset_idx in enumerate(unique_indices):
+            dataset_idx_item = dataset_idx.item()
+
+            # Extract gradients for this specific sample
+            sample_grads = {}
+            for param_name, grad_batch in per_sample_grads.items():
+                if grad_batch is not None and i < len(grad_batch):
+                    sample_grads[param_name] = grad_batch[i].clone().detach()
+
+            # Store in per-example gradient dict
+            if dataset_idx_item not in self.per_example_gradients:
+                self.per_example_gradients[dataset_idx_item] = []
+
+            self.per_example_gradients[dataset_idx_item].append({
+                'gradients': sample_grads,
+                'step': self.global_step,
+                'epoch': self.current_epoch,
+                'loss': None  # We don't have outputs yet in on_train_batch_start
+            })
+
+        # directly calculate the influence scores on validation gradients
+        val_gradients = self.get_validation_gradients_from_trainer(max_val_samples=32)  # changed for case-study
+        influence_scores = self.compute_influence_scores_batched(val_gradients)
+
+        # update the influence scores, but append the new scores to the end, don't overwrite the existing scores
+        for sample_idx, scores in influence_scores.items():
+            if sample_idx in self.influence_scores:
+                self.influence_scores[sample_idx].extend(scores)
+            else:
+                self.influence_scores[sample_idx] = scores
+
+        # Clean up old influence scores to keep only recent 4000 steps
+        self._cleanup_old_influence_scores(keep_recent_steps=4000)  # 4000 for case-study
+
+        # Only compute and log dataset contributions if dataset contribution logging is enabled
+        if self.hparams.enable_dataset_contribution_logging:
+            # aggregate the influence scores by dataset name
+            dataset_influence_scores = {}
+            count_dataset_scores = {}
+            for sample_idx, scores in self.influence_scores.items():
+                for score in scores:
+                    dataset_name = score.get('dataset_name', 'Unknown')
+                    if dataset_name not in dataset_influence_scores:
+                        dataset_influence_scores[dataset_name] = 0
+                        count_dataset_scores[dataset_name] = 0
+                    # use running average to aggregate the influence scores
+                    dataset_influence_scores[dataset_name] = (dataset_influence_scores[dataset_name] * count_dataset_scores[dataset_name] + score['influence_score']) / (count_dataset_scores[dataset_name] + 1)
+                    count_dataset_scores[dataset_name] += 1
+
+            # sort the dataset_influence_scores by the score
+            dataset_influence_scores = sorted(dataset_influence_scores.items(), key=lambda x: x[1], reverse=True)
+
+            # print the dataset_influence_scores in order
+            print("=" * 80)
+            print("DATASET INFLUENCE SCORES:")
+            for dataset_name, score in dataset_influence_scores:
+                print(f"  {dataset_name:25} | Score: {score:10.4f} | Step: {self.global_step:6d} | Epoch: {self.current_epoch:6d} | Count: {count_dataset_scores[dataset_name]:6d}")
+            print("=" * 80)
+        else:
+            print("Dataset contribution logging disabled (enable_dataset_contribution_logging=False)")
+
+        # Reweighting relies on the per-dataset aggregation computed in the
+        # contribution-logging branch above, so both flags must be enabled.
+        if self.hparams.enable_reweighting and self.hparams.enable_dataset_contribution_logging:
+            # Calculate updated sampling ratios based on dataset influence scores
+
+            scores = np.array([score for _, score in dataset_influence_scores])
+            dataset_names = [name for name, _ in dataset_influence_scores]
+
+            # Method 1: Linear scaling (more conservative)
+            # Normalize to [0.1, 2.0] range to avoid extreme ratios
+            if scores.max() > scores.min():
+                normalized_scores = (scores - scores.min()) / (scores.max() - scores.min())
+                linear_ratios = 0.1 + 1.9 * normalized_scores  # Scale to [0.1, 2.0]
+                linear_ratios = linear_ratios / linear_ratios.mean()  # Normalize so mean = 1.0
+            else:
+                linear_ratios = np.ones(len(scores))  # All equal if no variation
+
+            # Create sampling ratio dictionary
+            sampling_ratios = {}
+            for i, dataset_name in enumerate(dataset_names):
+                sampling_ratios[dataset_name] = {
+                    'influence_score': scores[i],
+                    'linear_ratio': float(linear_ratios[i]),
+                    'count': count_dataset_scores[dataset_name]
+                }
+
+            # Store sampling ratios for potential use by external components
+            self.latest_sampling_ratios = sampling_ratios
+
+            # Option: Use linear ratios as the recommended sampling weights
+            recommended_weights = {name: info['linear_ratio'] for name, info in sampling_ratios.items()}
+            print(f"Recommended sampling weights: {recommended_weights}")
+
+            # Save recommended weights as member variable for future filtering rounds
+            self.recommended_weights = recommended_weights
+
+            # Dataset weights calculation completed - weights not applied to maintain decoupling
+        else:
+            print("Dataset reweighting disabled (requires enable_reweighting and enable_dataset_contribution_logging)")
+
+        # Update influence score history for future filtering
+        self._update_influence_score_history(influence_scores)
+
+        # clear the per-example gradients
+        self.clear_per_example_gradients()
+
+    def _filter_low_influence_samples(self, batch, num_to_remove=16, use_influence_scores=True):
+        """Filter out samples from the current batch using influence scores or random selection"""
+
+        # use_influence_scores=True: influence-score filtering
+        # use_influence_scores=False: recommended weights-based filtering or random filtering
+
+        if "dataset_index" not in batch:
+            print("Warning: No dataset_index found in batch, skipping filtering")
+            return batch, 0
+
+        dataset_indices = batch["dataset_index"]
+        batch_size, seq_len = dataset_indices.shape
+
+        # Get unique dataset indices in this batch
+        unique_indices = dataset_indices.flatten().unique()
+        unique_indices = unique_indices[unique_indices >= 0]  # Remove padding (-1)
+
+        if len(unique_indices) <= num_to_remove:
+            print(f"Batch has only {len(unique_indices)} unique samples, not removing any")
+            return batch, 0
+
+        threshold = 0
+        if use_influence_scores:
+            # Use influence-based filtering (original behavior)
+            if not self.hparams.enable_influence_scoring:
+                print("Warning: Influence scoring disabled but influence-based filtering requested, using random filtering instead")
+                use_influence_scores = False
+            else:
+                # Get influence scores for samples in this batch
+                sample_scores = []
+                for idx in unique_indices:
+                    idx_item = idx.item()
+
+                    # Get latest influence score for this sample
+                    if (hasattr(self, 'influence_score_history') and
+                            idx_item in self.influence_score_history):
+                        latest_score = self.influence_score_history[idx_item]
+                        sample_scores.append((idx_item, latest_score))
+                    else:
+                        # If no history, assign neutral score (0.0)
+                        sample_scores.append((idx_item, 0.0))
+                        print(f"No influence score history found for sample {idx_item}")
+
+                # Sort by influence score (ascending) and get the lowest scoring samples
+                sample_scores.sort(key=lambda x: x[1])
+                samples_to_remove = [idx for idx, score in sample_scores[:num_to_remove]]
+
+                # save the batch["target"] of the samples that have the top-3 and lowest-3 influence scores
+                if self.global_step % 500 == 0:
+                    import os
+
+                    # Get top-3 and lowest-3 influence score samples
+                    # sample_scores is already sorted by influence score (ascending)
+                    lowest_3_samples = sample_scores[:3] if len(sample_scores) >= 3 else sample_scores
+                    top_3_samples = sample_scores[-3:] if len(sample_scores) >= 3 else sample_scores
+
+                    # Create directory for saving targets if it doesn't exist
+                    # save_dir = os.path.join(os.getcwd(), "influence_target_analysis")
+                    save_dir = "./influence_target_analysis"
+                    os.makedirs(save_dir, exist_ok=True)
+
+                    # Function to extract target data for specific indices
+                    def extract_targets_for_indices(batch, target_indices):
+                        extracted_targets = []
+                        dataset_indices = batch["dataset_index"]
+                        targets = batch["target"] if "target" in batch else None
+
+                        if targets is None:
+                            print("Warning: No 'target' found in batch")
+                            return []
+
+                        for target_idx in target_indices:
+                            # Find positions in batch where this dataset index appears
+                            mask = (dataset_indices == target_idx)
+                            if mask.any():
+                                # Extract target for this sample
+                                target_data = targets[mask]
+                                extracted_targets.append({
+                                    'dataset_index': target_idx,
+                                    'target_data': target_data.cpu(),
+                                    'positions': mask.nonzero().cpu()
+                                })
+                        return extracted_targets
+
+                    # Extract targets for lowest-3 and top-3 samples
+                    lowest_indices = [idx for idx, score in lowest_3_samples]
+                    top_indices = [idx for idx, score in top_3_samples]
+
+                    lowest_targets = extract_targets_for_indices(batch, lowest_indices)
+                    top_targets = extract_targets_for_indices(batch, top_indices)
+
+                    # Get dataset names for the samples
+                    lowest_dataset_names = [self._get_dataset_name_for_index(idx) for idx, score in lowest_3_samples]
+                    top_dataset_names = [self._get_dataset_name_for_index(idx) for idx, score in top_3_samples]
+
+                    # Save the data
+                    save_data = {
+                        'step': self.global_step,
+                        'lowest_3_influence': {
+                            'samples': lowest_3_samples,
+                            'targets':
lowest_targets, + 'dataset_names': lowest_dataset_names + }, + 'top_3_influence': { + 'samples': top_3_samples, + 'targets': top_targets, + 'dataset_names': top_dataset_names + } + } + + # Save to file + save_path = os.path.join(save_dir, f"influence_targets_step_{self.global_step}.pt") + torch.save(save_data, save_path) + print(f"Saved influence target analysis to {save_path}") + print(f"Lowest 3 influence scores: {[score for _, score in lowest_3_samples]}") + print(f"Top 3 influence scores: {[score for _, score in top_3_samples]}") + + + print(f"Influence-based filtering: removing {num_to_remove} batch positions from low-influence samples: {samples_to_remove}") + print(f"Average influence score: {sum(score for _, score in sample_scores) / len(sample_scores)}") + print(f"Median influence score: {sample_scores[len(sample_scores)//2][1]}") + threshold = sample_scores[len(sample_scores)//2][1] + print(f"Samples to remove average influence score: {sum(score for _, score in sample_scores[:num_to_remove]) / num_to_remove}") + + if not use_influence_scores: + # Use recommended weights-based filtering if available, otherwise random filtering + if hasattr(self, 'recommended_weights') and self.recommended_weights: + # Get dataset names and weights for samples in this batch + sample_weights = [] + for idx in unique_indices: + idx_item = idx.item() + dataset_name = self._get_dataset_name_for_index(idx_item) + weight = self.recommended_weights.get(dataset_name, 1.0) # Default to 1.0 if not found + sample_weights.append((idx_item, weight, dataset_name)) + + # Use probabilistic sampling based on inverse weights to maintain diversity + # Lower weights = higher probability of removal, but still maintains randomness + import numpy as np + + # Extract weights and indices + indices = [idx for idx, weight, dataset_name in sample_weights] + weights = np.array([weight for idx, weight, dataset_name in sample_weights]) + dataset_names = [dataset_name for idx, weight, dataset_name in sample_weights] + + # Calculate removal probabilities (inverse of weights, normalized) + # Add small epsilon to avoid division by zero + epsilon = 1e-6 + inverse_weights = 1.0 / (weights + epsilon) + removal_probs = inverse_weights / inverse_weights.sum() + + # Sample indices to remove based on probabilities + try: + selected_indices = np.random.choice( + len(indices), + size=min(num_to_remove, len(indices)), + replace=False, + p=removal_probs + ) + samples_to_remove = [indices[i] for i in selected_indices] + + print(f"Recommended weights-based filtering: probabilistically removing {len(samples_to_remove)} batch positions") + for i in selected_indices: + idx, weight, dataset_name = indices[i], weights[i], dataset_names[i] + prob = removal_probs[i] + print(f" Removing sample {idx} from {dataset_name} (weight: {weight:.3f}, removal_prob: {prob:.3f})") + + except Exception as e: + print(f"Error in probabilistic sampling: {e}, falling back to random selection") + import random + samples_to_remove = random.sample(indices, min(num_to_remove, len(indices))) + else: + # Fallback to random filtering if no recommended weights available + import random + # unique_indices_list = unique_indices.cpu().numpy().tolist() + # samples_to_remove = random.sample(unique_indices_list, min(num_to_remove, len(unique_indices_list))) + samples_to_remove = [] + + print(f"Random filtering (no recommended weights available): removing {len(samples_to_remove)} batch positions from randomly selected samples: {samples_to_remove}") + + if len(samples_to_remove) == 0: + 
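# Nothing was selected for removal: hand back the unchanged batch along
+            # with the current threshold (the median influence score when
+            # influence-based filtering ran above, otherwise 0), so callers
+            # can always unpack (batch, threshold).
+            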
return batch, threshold + + # Create mask for samples to keep + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=dataset_indices.device) + + # Track how many batch positions we've removed + removed_count = 0 + + for sample_idx in samples_to_remove: + if removed_count >= num_to_remove: + break + + # Find batch positions that contain this sample + sample_mask = (dataset_indices == sample_idx).any(dim=1) + sample_positions = sample_mask.nonzero().squeeze(1) + + # Remove only one instance of this sample (or fewer if we're at the limit) + positions_to_remove = min(len(sample_positions), num_to_remove - removed_count) + if positions_to_remove > 0: + keep_mask[sample_positions[:positions_to_remove]] = False + removed_count += positions_to_remove + + # Filter all tensors in the batch + filtered_batch = {} + for key, value in batch.items(): + if torch.is_tensor(value): + if value.dim() > 0 and value.shape[0] == batch_size: + # This tensor has batch dimension, filter it + filtered_batch[key] = value[keep_mask] + else: + # This tensor doesn't have batch dimension, keep as is + filtered_batch[key] = value + else: + # Non-tensor values, keep as is + filtered_batch[key] = value + + original_batch_size = batch_size + new_batch_size = keep_mask.sum().item() + actual_removed = original_batch_size - new_batch_size + + # Determine filter type for logging + if use_influence_scores: + filter_type = "influence-based" + elif hasattr(self, 'recommended_weights') and self.recommended_weights: + filter_type = "recommended weights-based" + else: + filter_type = "random" + + print(f"Filtered batch size ({filter_type}): {original_batch_size} -> {new_batch_size} (actually removed {actual_removed} positions)") + + return filtered_batch, threshold + + def _update_influence_score_history(self, influence_scores): + """Update the influence score history for batch filtering""" + + if not hasattr(self, 'influence_score_history'): + self.influence_score_history = {} + + # Process each sample's influence scores + for sample_idx, score_entries in influence_scores.items(): + if len(score_entries) > 0: + # Use the most recent influence score for this sample + latest_score_entry = score_entries[-1] + influence_score = latest_score_entry['influence_score'] + + # Store the latest influence score + self.influence_score_history[sample_idx] = influence_score + + print(f"Updated influence score history for {len(influence_scores)} samples") + + def _cleanup_old_influence_scores(self, keep_recent_steps=1600): + """Clean up old influence scores to keep only recent steps for memory management""" + + if not hasattr(self, 'influence_scores') or not self.influence_scores: + return + + current_step = self.global_step + cutoff_step = current_step - keep_recent_steps + + # Count entries before cleanup + total_entries_before = sum(len(scores) for scores in self.influence_scores.values()) + samples_before = len(self.influence_scores) + + # Clean up old entries from each sample's influence scores + samples_to_remove = [] + + for sample_idx, score_entries in self.influence_scores.items(): + # Filter out old entries based on step + recent_entries = [ + entry for entry in score_entries + if entry.get('step', 0) > cutoff_step + ] + + if recent_entries: + # Keep only recent entries + self.influence_scores[sample_idx] = recent_entries + else: + # Mark sample for removal if no recent entries + samples_to_remove.append(sample_idx) + + # Remove samples with no recent entries + for sample_idx in samples_to_remove: + del 
self.influence_scores[sample_idx] + + # Count entries after cleanup + total_entries_after = sum(len(scores) for scores in self.influence_scores.values()) + samples_after = len(self.influence_scores) + + entries_removed = total_entries_before - total_entries_after + samples_removed = samples_before - samples_after + + if entries_removed > 0 or samples_removed > 0: + print(f"Cleaned up influence scores: removed {entries_removed} old entries from {samples_removed} samples") + print(f"Kept influence scores from step {cutoff_step + 1} onwards (recent {keep_recent_steps} steps)") + print(f"Remaining: {total_entries_after} entries across {samples_after} samples") + + def _compute_per_sample_gradients_with_indices(self, batch, unique_indices): + """Compute per-sample gradients for samples with specific dataset indices.""" + per_sample_grads = {} + + # Initialize gradient storage + for name, param in self.named_parameters(): + if param.requires_grad: + per_sample_grads[name] = [] + + # Get dataset indices tensor + dataset_indices = batch["dataset_index"] + + # Compute gradient for each unique dataset index + for dataset_idx in unique_indices: + # Zero gradients + self.zero_grad() + + # Create mask for this dataset index + mask = (dataset_indices == dataset_idx) + + # Find positions where this dataset index appears + batch_indices, seq_indices = torch.where(mask) + + if len(batch_indices) == 0: + # No data for this index, store None + print(f"WARNING: No data for dataset index {dataset_idx}, skipping...") + for name, param in self.named_parameters(): + if param.requires_grad: + per_sample_grads[name].append(None) + continue + + # Get unique batch indices (samples in the batch containing this dataset index) + unique_batch_indices = batch_indices.unique() + + # Create a mini-batch with only the relevant samples + single_batch = {} + for key, value in batch.items(): + if torch.is_tensor(value): + if value.dim() > 1: + # Take only the samples that contain this dataset index + single_batch[key] = value[unique_batch_indices] + else: + single_batch[key] = value + else: + single_batch[key] = value + + # Forward pass for this dataset index + try: + output = self(**{ + field: single_batch[field] + for field in list(self.train_seq_fields) + ["sample_id"] + if field in single_batch + }) + + # Compute loss for this dataset index + loss = self.hparams.loss_func( + pred=output, + target=single_batch.get("label"), + observed_mask=single_batch.get("label_observed_mask"), + prediction_mask=single_batch.get("prediction_mask"), + sample_id=single_batch.get("sample_id"), + variate_id=single_batch.get("variate_id"), + ) + + # Scale loss by the proportion of data from this dataset index + total_elements = mask.sum().item() + loss = loss * total_elements / mask.numel() + + # Backward pass + loss.backward(retain_graph=True) + + # Store gradients + for name, param in self.named_parameters(): + if param.requires_grad and param.grad is not None: + per_sample_grads[name].append(param.grad.clone().detach()) + else: + per_sample_grads[name].append(None) + + except Exception as e: + print(f"Error computing gradient for dataset index {dataset_idx}: {e}") + # Store None for this sample + for name, param in self.named_parameters(): + if param.requires_grad: + per_sample_grads[name].append(None) + + # Convert lists to tensors + for name in per_sample_grads: + valid_grads = [g for g in per_sample_grads[name] if g is not None] + if valid_grads: + per_sample_grads[name] = torch.stack(valid_grads, dim=0) + else: + per_sample_grads[name] = 
None + + # Clear gradients + self.zero_grad() + + return per_sample_grads + + def compute_influence_scores(self, val_gradients): + """Compute influence scores by inner product with validation gradients""" + if not hasattr(self, 'per_example_gradients'): + print("No per-example gradients stored") + return {} + + influence_scores = {} + + # Try to get dataset metadata for mapping global indices to dataset names + dataset_metadata = None + try: + if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'train_dataset'): + train_dataset = self.trainer.datamodule.train_dataset + # Navigate to the ConcatDatasetBuilderWithGlobalIndex if it exists + # This assumes the train_dataset was created by instantiating a config with a ConcatDatasetBuilderWithGlobalIndex + # We need to trace back to find the dataset builder that created this dataset + if hasattr(train_dataset, 'datasets'): # ConcatDataset + # Look for global index metadata in any of the sub-datasets + for sub_dataset in train_dataset.datasets: + if hasattr(sub_dataset, 'global_offset'): + # This indicates we're using the enhanced datasets with global indexing + # We need to find the builder that created the dataset hierarchy + pass + except Exception as e: + print(f"Warning: Could not access dataset metadata for sub-dataset names: {e}") + + for sample_idx, grad_history in self.per_example_gradients.items(): + sample_scores = [] + + for grad_entry in grad_history: + train_grads = grad_entry['gradients'] + + # Compute inner product or cosine similarity with validation gradients + similarity_score = 0.0 + param_count = 0 + + if self.hparams.use_cosine_similarity: + # Compute cosine similarity for each parameter separately and average + for param_name in train_grads: + if param_name in val_gradients and train_grads[param_name] is not None: + train_grad = train_grads[param_name].flatten() + val_grad = val_gradients[param_name].flatten() + + # Ensure same size + if train_grad.shape == val_grad.shape: + # Compute cosine similarity: cos(θ) = (A·B) / (||A|| * ||B||) + train_norm = torch.norm(train_grad) + val_norm = torch.norm(val_grad) + + if train_norm > 0 and val_norm > 0: + cosine_sim = torch.dot(train_grad, val_grad) / (train_norm * val_norm) + similarity_score += cosine_sim.item() + param_count += 1 + else: + # Original dot product computation + for param_name in train_grads: + if param_name in val_gradients and train_grads[param_name] is not None: + train_grad = train_grads[param_name].flatten() + val_grad = val_gradients[param_name].flatten() + + # Ensure same size + if train_grad.shape == val_grad.shape: + similarity_score += torch.dot(train_grad, val_grad).item() + param_count += 1 + + if param_count > 0: + score_entry = { + 'influence_score': similarity_score, + 'step': grad_entry['step'], + 'epoch': grad_entry['epoch'], + 'loss': None, # We don't have outputs yet in on_train_batch_start + 'global_dataset_index': sample_idx # Store the global index + } + + # Try to add dataset name if we can map global index to dataset + if dataset_metadata and sample_idx in dataset_metadata: + score_entry['dataset_name'] = dataset_metadata[sample_idx] + else: + # Fallback: try to map using a simpler approach + score_entry['dataset_name'] = self._get_dataset_name_for_index(sample_idx) + + sample_scores.append(score_entry) + + influence_scores[sample_idx] = sample_scores + + return influence_scores + + def compute_influence_scores_batched(self, val_gradients): + """Compute influence scores using batched operations for major speedup""" + 
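+        # The batched path below flattens each stored per-sample gradient into
+        # a vector, stacks them into a matrix G of shape [num_samples, num_params],
+        # and scores all samples at once:
+        #   scores = G @ v           (dot-product influence)
+        #   scores = Gn @ vn         (cosine variant, with rows of G and v
+        #                             L2-normalized)
+        # where v is the flattened validation gradient. This replaces the
+        # per-sample Python loop used in compute_influence_scores above.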
if not hasattr(self, 'per_example_gradients'): + print("No per-example gradients stored") + return {} + + if not val_gradients: + print("No validation gradients provided") + return {} + + # Pre-process validation gradients - get consistent parameter order + param_names = sorted([n for n in val_gradients.keys() if val_gradients[n] is not None]) + if not param_names: + print("No valid validation gradients found") + return {} + + val_grad_flat = torch.cat([val_gradients[name].flatten() for name in param_names]) + device = val_grad_flat.device + + # Collect all training gradients into batches + all_sample_indices = [] + all_train_grads = [] + all_metadata = [] + + for sample_idx, grad_history in self.per_example_gradients.items(): + for grad_entry in grad_history: + train_grads = grad_entry['gradients'] + + # Check if all required gradients exist and are valid + valid_grads = [] + valid = True + + for name in param_names: + if name in train_grads and train_grads[name] is not None: + valid_grads.append(train_grads[name].flatten()) + else: + valid = False + break + + if valid and len(valid_grads) == len(param_names): + try: + train_grad_flat = torch.cat(valid_grads) + # Ensure same device + train_grad_flat = train_grad_flat.to(device) + + all_sample_indices.append(sample_idx) + all_train_grads.append(train_grad_flat) + all_metadata.append(grad_entry) + except Exception as e: + print(f"Warning: Could not process gradients for sample {sample_idx}: {e}") + continue + + if not all_train_grads: + print("No valid training gradients found for batch processing") + return {} + + print(f"Batch processing influence scores for {len(all_train_grads)} gradient entries...") + + try: + # Stack into matrix: [num_samples, num_params] + train_grad_matrix = torch.stack(all_train_grads) + + if self.hparams.use_cosine_similarity: + # Normalize training gradients (each row) + train_grad_norms = torch.norm(train_grad_matrix, dim=1, keepdim=True) + # Avoid division by zero + train_grad_norms = torch.clamp(train_grad_norms, min=1e-8) + train_grad_matrix_normalized = train_grad_matrix / train_grad_norms + + # Normalize validation gradients + val_grad_norm = torch.norm(val_grad_flat) + val_grad_norm = torch.clamp(val_grad_norm, min=1e-8) + val_grad_flat_normalized = val_grad_flat / val_grad_norm + + # Compute cosine similarity: [num_samples] + influence_scores_flat = torch.matmul(train_grad_matrix_normalized, val_grad_flat_normalized) + else: + # Original dot product computation: [num_samples] + influence_scores_flat = torch.matmul(train_grad_matrix, val_grad_flat) + + # Process results back into the expected format + influence_scores = {} + for i, (sample_idx, metadata) in enumerate(zip(all_sample_indices, all_metadata)): + score_entry = { + 'influence_score': influence_scores_flat[i].item(), + 'step': metadata['step'], + 'epoch': metadata['epoch'], + 'loss': metadata.get('loss', None), + 'global_dataset_index': sample_idx + } + + # Try to add dataset name + try: + score_entry['dataset_name'] = self._get_dataset_name_for_index(sample_idx) + except Exception as e: + score_entry['dataset_name'] = f"Unknown_Idx_{sample_idx}" + + if sample_idx not in influence_scores: + influence_scores[sample_idx] = [] + influence_scores[sample_idx].append(score_entry) + + print(f"Successfully computed influence scores for {len(influence_scores)} unique samples") + return influence_scores + + except Exception as e: + print(f"Error in batched influence computation: {e}") + print("Falling back to original method...") + return 
self.compute_influence_scores(val_gradients) + + def compute_influence_scores_batched_on_generated(self, val_gradients, gen_grad): + """Compute influence scores for generated samples using batched operations""" + if not val_gradients: + print("No validation gradients provided") + return {} + + if not gen_grad: + print("No generated gradients provided") + return {} + + # Pre-process validation gradients - get consistent parameter order + param_names = sorted([n for n in val_gradients.keys() if val_gradients[n] is not None]) + if not param_names: + print("No valid validation gradients found") + return {} + + val_grad_flat = torch.cat([val_gradients[name].flatten() for name in param_names]) + device = val_grad_flat.device + + # Collect generated gradients into batches + all_train_grads = [] + + # Process generated gradients + for name in param_names: + if name in gen_grad and gen_grad[name] is not None: + # gen_grad[name] should be [num_generated_samples, param_shape...] + gen_grad_param = gen_grad[name] + if gen_grad_param.dim() > 1: + # Flatten each sample's gradient for this parameter + flattened_grads = gen_grad_param.view(gen_grad_param.shape[0], -1) + if len(all_train_grads) == 0: + # Initialize with the first parameter's gradients + all_train_grads = [[] for _ in range(gen_grad_param.shape[0])] + + for i in range(gen_grad_param.shape[0]): + all_train_grads[i].append(flattened_grads[i]) + else: + print(f"Warning: Unexpected gradient shape for {name}: {gen_grad_param.shape}") + continue + else: + print(f"Warning: Parameter {name} not found in generated gradients") + return {} + + if not all_train_grads: + print("No valid generated gradients found for batch processing") + return {} + + # Concatenate gradients for each sample + try: + processed_grads = [] + for sample_grads in all_train_grads: + if len(sample_grads) == len(param_names): + sample_grad_flat = torch.cat(sample_grads) + sample_grad_flat = sample_grad_flat.to(device) + processed_grads.append(sample_grad_flat) + + if not processed_grads: + print("No valid processed gradients found") + return {} + + print(f"Batch processing influence scores for {len(processed_grads)} generated samples...") + + # Stack into matrix: [num_generated_samples, num_params] + train_grad_matrix = torch.stack(processed_grads) + + # Original dot product computation: [num_generated_samples] + influence_scores_flat = torch.matmul(train_grad_matrix, val_grad_flat) + + # Return the influence scores as a simple list or tensor + influence_scores_list = influence_scores_flat.detach().cpu().numpy().tolist() + + print(f"Successfully computed influence scores for {len(influence_scores_list)} generated samples") + print(f"Generated samples influence scores: {influence_scores_list}") + + return influence_scores_list + + except Exception as e: + print(f"Error in generated samples influence computation: {e}") + return {} + + def _get_dataset_name_for_index(self, global_idx: int) -> str: + """Helper method to get dataset name for a global index.""" + try: + # BOUNDS CHECK: Validate global_idx is within reasonable range + dataset_size = len(self.trainer.datamodule.train_dataset) # Reasonable upper bound + if global_idx > dataset_size: + print(f"WARNING: Global index {global_idx} exceeds reasonable bounds (>{dataset_size}). 
This might indicate stale influence scores.") + return f"OutOfBounds_Idx_{global_idx}" + + # Method 1: Try to use ConcatDatasetBuilderWithGlobalIndex metadata (preferred) + if (hasattr(self.trainer, 'datamodule') and + hasattr(self.trainer.datamodule, 'data_builder') and + hasattr(self.trainer.datamodule.data_builder, 'get_dataset_name_for_global_index')): + + # Additional bounds check using actual dataset size + if hasattr(self.trainer.datamodule, 'train_dataset'): + dataset_size = len(self.trainer.datamodule.train_dataset) + if global_idx >= dataset_size: + print(f"WARNING: Global index {global_idx} >= dataset size {dataset_size}. Clearing stale influence scores.") + # Clear stale per_example_gradients to prevent future issues + if hasattr(self, 'per_example_gradients'): + # Remove any indices beyond the current dataset size + stale_indices = [idx for idx in self.per_example_gradients.keys() if idx >= dataset_size] + for idx in stale_indices: + del self.per_example_gradients[idx] + if stale_indices: + print(f"Cleared {len(stale_indices)} stale influence scores with indices: {stale_indices[:10]}{'...' if len(stale_indices) > 10 else ''}") + return f"Stale_Idx_{global_idx}_Cleared" + + return self.trainer.datamodule.data_builder.get_dataset_name_for_global_index(global_idx) + + # Method 2: Fallback to manual traversal of ConcatDataset + elif hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'train_dataset'): + train_dataset = self.trainer.datamodule.train_dataset + + # Bounds check against actual dataset + if global_idx >= len(train_dataset): + print(f"WARNING: Global index {global_idx} >= actual dataset size {len(train_dataset)}") + return f"OutOfRange_Idx_{global_idx}" + + # Check if it's a ConcatDataset with sub-datasets that have global_offset + if hasattr(train_dataset, 'datasets'): + cumulative_size = 0 + for i, sub_dataset in enumerate(train_dataset.datasets): + dataset_size = len(sub_dataset) + if global_idx < cumulative_size + dataset_size: + # This global index belongs to this sub-dataset + # Try to get the dataset name from various sources + + # Method 2a: Check if dataset has indexer with dataset info + if hasattr(sub_dataset, 'indexer') and hasattr(sub_dataset.indexer, 'dataset'): + if hasattr(sub_dataset.indexer.dataset, 'info') and hasattr(sub_dataset.indexer.dataset.info, 'dataset_name'): + return sub_dataset.indexer.dataset.info.dataset_name + + # Method 2b: Check if dataset itself has info + if hasattr(sub_dataset, 'info') and hasattr(sub_dataset.info, 'dataset_name'): + return sub_dataset.info.dataset_name + + # Method 2c: Return a descriptive name based on position + return f"SubDataset_{i}" + + cumulative_size += dataset_size + + # Final fallback + return f"Dataset_GlobalIdx_{global_idx}" + + except Exception as e: + return f"Unknown_Idx_{global_idx}_Error_{str(e)[:30]}" + + def get_validation_gradients_from_trainer(self, max_val_samples=32): # None + """Compute gradients using trainer's validation dataloader + + Args: + max_val_samples: Maximum number of validation samples to use. If None, uses all samples. 
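+
+        Returns:
+            Dict mapping parameter names to detached gradient tensors from a
+            single backward pass over the collected validation batches
+            (empty dict on failure).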
+ """ + + if not hasattr(self.trainer, 'val_dataloaders') or not self.trainer.val_dataloaders: + print("Warning: No validation dataloader found in trainer") + return {} + + # Get validation dataloader from trainer + val_dataloader = self.trainer.val_dataloaders[0] # Use first validation dataloader + + # Check sampler for shuffle as well + if hasattr(val_dataloader, 'sampler') and hasattr(val_dataloader.sampler, 'shuffle'): + if val_dataloader.sampler.shuffle: + print("WARNING: Validation dataloader sampler has shuffle=True!") + print("This may still cause non-deterministic ordering even if shuffle=false in config.") + print("Val dataloader sampler shuffle: ", val_dataloader.sampler.shuffle) + + return self.get_validation_gradients(val_dataloader, max_val_samples) + + def get_validation_gradients(self, dataloader_or_dataset, max_val_samples=None): + """Compute gradients on validation dataset for influence computation + + Args: + dataloader_or_dataset: Validation dataloader or dataset + max_val_samples: Maximum number of validation samples to use. If None, uses all samples. + """ + + # Save current training state + was_training = self.training + self.eval() + self.zero_grad() + + # Handle both dataset and dataloader inputs + val_dataloader = dataloader_or_dataset + + total_samples = 0 + samples_processed = 0 + + print(f"Computing validation gradients over {len(val_dataloader)} batches...") + if max_val_samples is not None: + print(f"Limiting to maximum {max_val_samples} validation samples") + print(f"NOTE: Validation dataset indices will NOT interfere with training influence scores") + + # Collect all validation data first for single backward pass + all_batches = [] + + for batch_idx, val_batch in enumerate(val_dataloader): + try: + # IMPORTANT: Remove dataset_index from validation batch to prevent contamination + if self.cache_val_batch is not None: + val_batch_clean = self.cache_val_batch + print("Using cached val batch --------------------------------") + else: + val_batch_clean = {k: v for k, v in val_batch.items() if k != 'dataset_index'} + self.cache_val_batch = val_batch_clean + # Move tensors to device + val_batch_clean = { + k: v.to(self.device) if torch.is_tensor(v) else v + for k, v in val_batch_clean.items() + } + + # Get batch size for this batch + batch_size = ( + val_batch_clean["sample_id"].max(dim=1).values.sum().item() + if "sample_id" in val_batch_clean and val_batch_clean["sample_id"].dim() > 1 + else val_batch_clean["target"].shape[0] + ) + + # Check if we've reached the limit + if max_val_samples is not None and samples_processed + batch_size > max_val_samples: + # Only take what we need from this batch + remaining_samples = max_val_samples - samples_processed + if remaining_samples <= 0: + break + + # Truncate the batch to only take remaining_samples + for key in val_batch_clean: + if torch.is_tensor(val_batch_clean[key]) and val_batch_clean[key].dim() > 0: + val_batch_clean[key] = val_batch_clean[key][:remaining_samples] + + batch_size = remaining_samples + + # Store the batch for later processing + all_batches.append(val_batch_clean) + + total_samples += batch_size + samples_processed += batch_size + + # Clear intermediate variables to free memory + del val_batch + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + if (batch_idx + 1) % 10 == 0: + print(f"Processed {batch_idx + 1}/{len(val_dataloader)} validation batches") + + # Check if we've reached the limit + if max_val_samples is not None and samples_processed >= max_val_samples: + 
print(f"Reached maximum validation samples limit: {max_val_samples}") + break + + except Exception as e: + print(f"Error in validation batch {batch_idx}: {e}") + continue + + if total_samples == 0: + print("Warning: No validation samples processed") + return {} + + print(f"Collected {total_samples} validation samples, computing gradients...") + + # Now compute the total loss and do a single backward pass + try: + # Compute total loss across all collected data with gradients enabled + total_loss = 0.0 + num_batches = len(all_batches) + + for i, val_batch_clean in enumerate(all_batches): + # Forward pass with gradient computation enabled + with torch.enable_grad(): + output = self(**{ + field: val_batch_clean[field] + for field in list(self.train_seq_fields) + ["sample_id"] + if field in val_batch_clean + }) + + # Compute loss for this batch + batch_loss = self.hparams.loss_func( + pred=output, + target=val_batch_clean["label"], + observed_mask=val_batch_clean["label_observed_mask"], + prediction_mask=val_batch_clean["prediction_mask"], + sample_id=val_batch_clean["sample_id"], + variate_id=val_batch_clean["variate_id"], + ) + + # Add to total loss + total_loss += batch_loss + + # Average the loss + total_loss = total_loss / num_batches + + # Single backward pass + total_loss.backward() + + # Extract gradients + val_gradients = {} + for name, param in self.named_parameters(): + if param.requires_grad and param.grad is not None: + val_gradients[name] = param.grad.clone().detach() + + # Clear gradients and restore training state + self.zero_grad() + if was_training: + self.train() + + print(f"Computed validation gradients for {len(val_gradients)} parameters over {total_samples} samples") + + return val_gradients + + except Exception as e: + print(f"Error in single backward pass: {e}") + print("Falling back to original method...") + + # Fallback: Clear gradients and restore training state + self.zero_grad() + if was_training: + self.train() + + # Return empty dict to indicate failure + return {} + + + def clear_per_example_gradients(self, keep_recent_steps=None): + """Clear old per-example gradients to manage memory""" + if not hasattr(self, 'per_example_gradients'): + return + + self.per_example_gradients = {} + + def validation_step( + self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0 + ) -> torch.Tensor: + distr, preds = self.infer( + **{field: batch[field] for field in list(self.train_seq_fields) + ["sample_id"]} + ) + val_loss = self.hparams.loss_func( + pred=distr, + target=batch["label"], + observed_mask=batch["label_observed_mask"], + **{ + field: batch[field] + for field in [ + "prediction_mask", + "sample_id", + "variate_id", + ] + }, + ) + batch_size = ( + batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None + ) + self.log( + f"val/{self.hparams.loss_func.__class__.__name__}", + val_loss, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + add_dataloader_idx=True, + ) + + if self.hparams.val_metric is not None: + val_metrics = ( + self.hparams.val_metric + if isinstance(self.hparams.val_metric, list) + else [self.hparams.val_metric] + ) + for metric_func in val_metrics: + metric = metric_func( + pred=preds, + target=batch["label"], + observed_mask=batch["label_observed_mask"], + **{ + field: batch[field] + for field in [ + "prediction_mask", + "sample_id", + "variate_id", + ] + }, + ) + + self.log( + 
f"val/{metric_func.__class__.__name__}", + metric, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + add_dataloader_idx=True, + ) + + return val_loss + + def configure_optimizers(self) -> dict: + decay = set() + no_decay = set() + + whitelist_params = ( + LearnedProjection, + nn.Linear, + ) + blacklist_params = ( + LearnedEmbedding, + RMSNorm, + nn.Embedding, + nn.LayerNorm, + ) + + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + if not p.requires_grad: + continue + + fpn = f"{mn}.{pn}" if mn else pn + if pn.endswith("bias"): + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_params): + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_params): + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert ( + len(param_dict.keys() - union_params) == 0 + ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + + optim_groups = [ + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(decay))], + ), + "weight_decay": self.hparams.weight_decay, + }, + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(no_decay))], + ), + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW( + optim_groups, + lr=self.hparams.lr, + betas=(self.hparams.beta1, self.hparams.beta2), + eps=1e-6, + ) + scheduler = get_scheduler( + SchedulerType.COSINE_WITH_RESTARTS, + optimizer, + num_warmup_steps=self.hparams.num_warmup_steps, + num_training_steps=self.hparams.num_training_steps, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "train_loss", + "interval": "step", + }, + } + + @property + def train_transform_map(self) -> dict[str, Callable[..., Transformation]]: + def default_train_transform(): + return ( + SampleDimension( + max_dim=self.hparams.max_dim, + fields=("target",), + optional_fields=(), + ) + + GetPatchSize( + min_time_patches=self.hparams.min_patches, + target_field="target", + patch_size=self.module.patch_size, + patch_size_constraints=None, + offset=True, + ) + + PatchCrop( + min_time_patches=self.hparams.min_patches, + max_patches=self.module.max_seq_len, + will_flatten=True, + offset=True, + fields=("target",), + optional_fields=(), + ) + + PackFields( + output_field="target", + fields=("target",), + feat=False, + ) + + AddObservedMask( + fields=("target",), + optional_fields=(), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",), + optional_fields=(), + imputation_method=DummyValueImputation(value=0.0), + ) + + Patchify( + max_patch_size=self.module.patch_size, + fields=("target", "observed_mask"), + optional_fields=(), + ) + + MaskedPrediction( + min_mask_ratio=self.hparams.min_mask_ratio, + max_mask_ratio=self.hparams.max_mask_ratio, + target_field="target", + truncate_fields=(), + optional_truncate_fields=(), + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + + AddVariateIndex( + fields=("target",), + optional_fields=(), + variate_id_field="variate_id", + expected_ndim=3, + 
max_dim=self.hparams.max_dim, + randomize=False, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",), + optional_fields=(), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackCollection( + field="label_observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="label", + fields=("label",), + optional_fields=(), + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",), + optional_fields=(), + feat=True, + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields) + ["_dataset_idx"]) + ) + + return defaultdict(lambda: default_train_transform) + + @property + def val_transform_map( + self, + ) -> dict[str | type, Callable[..., Transformation]]: + def default_val_transform( + offset: int, + distance: int, + prediction_length: int, + context_length: int, + patch_size: int, + ): + return ( + SampleDimension( + max_dim=1, + fields=("target",), + optional_fields=(), + ) + + GetPatchSize( + min_time_patches=2, + target_field="target", + patch_size=self.module.patch_size, + patch_size_constraints=None, + offset=True, + ) + + EvalCrop_AdaLength( + offset, + distance, + prediction_length, + context_length, + fields=("target",), + optional_fields=(), + ) + + PackFields( + output_field="target", + fields=("target",), + ) + + EvalPad_AdaLength( + prediction_length=prediction_length, + context_length=context_length, + patch_size=self.module.patch_size, + fields=("target",), + optional_fields=() + ) + + AddObservedMask( + fields=("target",), + optional_fields=(), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",), + optional_fields=(), + imputation_method=DummyValueImputation(value=0.0), + ) + + Patchify( + max_patch_size=self.module.patch_size, + fields=("target", "observed_mask"), + optional_fields=(), + ) + + AddVariateIndex( + fields=("target",), + optional_fields=(), + variate_id_field="variate_id", + expected_ndim=3, + max_dim=self.hparams.max_dim, + randomize=False, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",), + optional_fields=(), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + EvalMaskedPrediction( + mask_length=math.ceil(prediction_length / patch_size), + target_field="target", + truncate_fields=(), + optional_truncate_fields=(), + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + + ExtendMask( + fields=tuple(), + optional_fields=(), + mask_field="prediction_mask", + expected_ndim=3, + ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",), + optional_fields=(), + feat=True, + ) + + FlatPackCollection( + field="label_observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="label", + fields=("label",), + optional_fields=(), + feat=True, + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields) + ["_dataset_idx"]) + ) + + return defaultdict(lambda: 
default_val_transform) \ No newline at end of file diff --git a/OATS/models/tsfm/module/__init__.py b/OATS/models/tsfm/module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/OATS/models/tsfm/module/attention.py b/OATS/models/tsfm/module/attention.py new file mode 100644 index 0000000..cbdcb4d --- /dev/null +++ b/OATS/models/tsfm/module/attention.py @@ -0,0 +1,350 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import math +from collections.abc import Callable +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int +from torch import nn + +from .position import AttentionBias, QueryKeyProjection + +# TODO: Support returning weights +# TODO: Support caching (return past_key_value) + + +def native_scaled_dot_product_attention( + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + value: Float[torch.Tensor, "*batch group hpg kv_len dim"], + attn_mask: Optional[ + Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] + | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] + ] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, +): + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_weight = query @ key.transpose(-2, -1) * scale_factor + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias = torch.zeros_like(attn_weight) + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +class GroupedQueryAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__() + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.num_heads = num_heads + self.num_groups = num_groups + self.head_dim = dim // num_heads + self.heads_per_group = num_heads // num_groups + self.var_attn_bias = var_attn_bias() if var_attn_bias is not None else None + self.time_attn_bias = time_attn_bias() if time_attn_bias is not None else None + self.var_qk_proj = var_qk_proj() if var_qk_proj is not None else None + self.time_qk_proj = time_qk_proj() if time_qk_proj is not None else None + + self.softmax_scale = softmax_scale or 1 / math.sqrt(self.head_dim) + + self.q_proj = nn.Linear(dim, dim, bias=bias) + self.k_proj = nn.Linear(dim, self.head_dim * num_groups, bias=bias) + self.v_proj = nn.Linear(dim, self.head_dim * num_groups, bias=bias) + self.q_norm = ( + norm_layer(self.head_dim) if norm_layer is not None else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim) if norm_layer is not None else nn.Identity() + ) + self.attn_dropout_p = attn_dropout_p + self.out_proj = nn.Linear(dim, dim, bias=bias) + + def _get_var_id( + self, + query: 
Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + ) -> tuple[ + Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ]: + if self.var_attn_bias is not None or self.var_qk_proj is not None: + if query_var_id is None: + query_var_id = repeat( + torch.zeros((), device=query.device, dtype=torch.long), + f" -> {' '.join(map(str, query.shape[:-4]))} 1 1 {query.shape[-2]}", + ) + else: + query_var_id = rearrange(query_var_id, "... q_len -> ... 1 1 q_len") + + if kv_var_id is None: + kv_var_id = repeat( + torch.zeros((), device=key.device, dtype=torch.long), + f" -> {' '.join(map(str, key.shape[:-4]))} 1 1 {key.shape[-2]}", + ) + else: + kv_var_id = rearrange(kv_var_id, "... kv_len -> ... 1 1 kv_len") + + return query_var_id, kv_var_id + + def _get_time_id( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + ) -> tuple[ + Optional[Int[torch.Tensor, "*batch 1 1 q_len"]], + Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]], + ]: + if self.time_attn_bias is not None or self.time_qk_proj is not None: + if query_time_id is None: + query_time_id = repeat( + torch.arange( + query.shape[-2], device=query.device, dtype=torch.long + ), + f"q_len -> {' '.join(map(str, query.shape[:-4]))} 1 1 q_len", + ) + else: + query_time_id = rearrange(query_time_id, "... q_len -> ... 1 1 q_len") + + if kv_time_id is None: + kv_time_id = repeat( + torch.arange(key.shape[-2], device=key.device, dtype=torch.long), + f"kv_len -> {' '.join(map(str, key.shape[:-4]))} 1 1 kv_len", + ) + else: + kv_time_id = rearrange(kv_time_id, "... kv_len-> ... 1 1 kv_len") + + return query_time_id, kv_time_id + + def _update_attn_mask( + self, + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]], + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + ) -> Optional[ + Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] + | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] + ]: + if attn_mask is not None: + attn_mask = rearrange( + attn_mask, + "... q_len kv_len -> ... 
1 1 q_len kv_len", + ) + + attn_bias = 0 + if self.var_attn_bias is not None: + attn_bias = attn_bias + self.var_attn_bias( + query, + key, + query_id=query_var_id, + kv_id=kv_var_id, + ) + + if self.time_attn_bias is not None: + attn_bias = attn_bias + self.time_attn_bias( + query, + key, + query_id=query_time_id, + kv_id=kv_time_id, + ) + + attn_mask = ( + attn_mask + if isinstance(attn_bias, int) + else ( + attn_bias + if attn_mask is None + else attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) + ) + ) + return attn_mask + + def _qk_proj( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + query_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ) -> tuple[ + Float[torch.Tensor, "*batch group hpg q_len dim"], + Float[torch.Tensor, "*batch group hpg kv_len dim"], + ]: + if self.var_qk_proj is not None: + query, key = self.var_qk_proj( + query, key, query_id=query_var_id, kv_id=kv_var_id + ) + + if self.time_qk_proj is not None: + query, key = self.time_qk_proj( + query, key, query_id=query_time_id, kv_id=kv_time_id + ) + + return query, key + + def forward( + self, + query: Float[torch.Tensor, "*batch q_len dim"], + key: Float[torch.Tensor, "*batch kv_len dim"], + value: Float[torch.Tensor, "*batch kv_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + ) -> Float[torch.Tensor, "*batch q_len dim"]: + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + query = self.q_norm( + rearrange( + query, + "... q_len (group hpg dim) -> ... group hpg q_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + ) + key = self.k_norm( + repeat( + key, + "... kv_len (group dim) -> ... group hpg kv_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + ) + value = repeat( + value, + "... kv_len (group dim) -> ... group hpg kv_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + + query_var_id, kv_var_id = self._get_var_id(query, key, query_var_id, kv_var_id) + query_time_id, kv_time_id = self._get_time_id( + query, + key, + query_time_id, + kv_time_id, + ) + + attn_mask = self._update_attn_mask( + attn_mask, + query, + key, + query_var_id=query_var_id, + kv_var_id=kv_var_id, + query_time_id=query_time_id, + kv_time_id=kv_time_id, + ) + + query, key = self._qk_proj( + query, + key, + query_var_id=query_var_id, + kv_var_id=kv_var_id, + query_time_id=query_time_id, + kv_time_id=kv_time_id, + ) + + out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=self.attn_dropout_p, + scale=self.softmax_scale, + ) + out = rearrange(out, "... group hpg q_len dim -> ... 
q_len (group hpg dim)") + return self.out_proj(out) + + +class MultiQueryAttention(GroupedQueryAttention): + def __init__( + self, + dim: int, + num_heads: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + num_groups=1, + bias=bias, + norm_layer=norm_layer, + softmax_scale=softmax_scale, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) + + +class MultiHeadAttention(GroupedQueryAttention): + def __init__( + self, + dim: int, + num_heads: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + num_groups=num_heads, + bias=bias, + norm_layer=norm_layer, + softmax_scale=softmax_scale, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) diff --git a/OATS/models/tsfm/module/ffn.py b/OATS/models/tsfm/module/ffn.py new file mode 100644 index 0000000..7d88612 --- /dev/null +++ b/OATS/models/tsfm/module/ffn.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import nn + + +class FeedForward(nn.Module): + def __init__( + self, + in_dim: int, + hidden_dim: Optional[int] = None, + out_dim: Optional[int] = None, + activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu, + bias: bool = True, + ffn_dropout_p: float = 0.0, + ): + super().__init__() + hidden_dim = hidden_dim or 4 * in_dim + out_dim = out_dim or in_dim + + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.bias = bias + self.ffn_dropout_p = ffn_dropout_p + + self.fc1 = nn.Linear(in_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias) + self.dropout1 = nn.Dropout(ffn_dropout_p) + self.dropout2 = nn.Dropout(ffn_dropout_p) + self.activation = activation + + def forward( + self, x: Float[torch.Tensor, "... in_dim"] + ) -> Float[torch.Tensor, "... out_dim"]: + x = self._in_proj(x) + return self.dropout2(self.fc2(self.dropout1(x))) + + def _in_proj( + self, x: Float[torch.Tensor, "... in_dim"] + ) -> Float[torch.Tensor, "... 
out_dim"]: + return self.activation(self.fc1(x)) + + +class GatedLinearUnitFeedForward(FeedForward): + def __init__( + self, + in_dim: int, + hidden_dim: Optional[int] = None, + out_dim: Optional[int] = None, + activation: Callable[[torch.Tensor], torch.Tensor] = F.silu, + bias: bool = True, + ffn_dropout_p: float = 0.0, + ): + super().__init__( + in_dim, + hidden_dim=hidden_dim or self.adjust_hidden_dim(4 * in_dim), + out_dim=out_dim, + activation=activation, + bias=bias, + ffn_dropout_p=ffn_dropout_p, + ) + self.fc_gate = nn.Linear(self.in_dim, self.hidden_dim, bias=self.bias) + + @staticmethod + def adjust_hidden_dim(dim): + return (int(dim * 2 / 3) + 7) // 8 * 8 + + def _in_proj( + self, x: Float[torch.Tensor, "... in_dim"] + ) -> Float[torch.Tensor, "... out_dim"]: + return self.activation(self.fc_gate(x)) * self.fc1(x) diff --git a/OATS/models/tsfm/module/norm.py b/OATS/models/tsfm/module/norm.py new file mode 100644 index 0000000..b665f4e --- /dev/null +++ b/OATS/models/tsfm/module/norm.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from typing import Optional + +import torch +from jaxtyping import Float +from torch import nn + + +class RMSNorm(nn.Module): + def __init__( + self, + normalized_shape: int | list[int] | torch.Size, + eps: float = 1e-5, + weight: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + + self.normalized_shape = normalized_shape + self.eps = eps + self.mean_dim = tuple(range(-len(normalized_shape), 0)) + + if weight: + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward( + self, x: Float[torch.Tensor, "*batch normalized_shape"] + ) -> Float[torch.Tensor, "*batch normalized_shape"]: + output = x * torch.rsqrt( + x.pow(2).mean(dim=self.mean_dim, keepdim=True) + self.eps + ) + if self.weight is not None: + return output * self.weight + return output + + def extra_repr(self) -> str: + return ( + f"normalized_shape={self.normalized_shape}, " + f"eps={self.eps}, " + f"weight={self.weight is not None}" + ) diff --git a/OATS/models/tsfm/module/packed_scaler.py b/OATS/models/tsfm/module/packed_scaler.py new file mode 100644 index 0000000..3bfbdb0 --- /dev/null +++ b/OATS/models/tsfm/module/packed_scaler.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+from typing import Optional + +import torch +from einops import reduce +from jaxtyping import Bool, Float, Int +from torch import nn + +from tsfm.common.torch_util import safe_div + + +class PackedScaler(nn.Module): + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"] = None, + sample_id: Int[torch.Tensor, "*batch seq_len"] = None, + variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None, + ): + if observed_mask is None: + observed_mask = torch.ones_like(target, dtype=torch.bool) + if sample_id is None: + sample_id = torch.zeros( + target.shape[:-1], dtype=torch.long, device=target.device + ) + if variate_id is None: + variate_id = torch.zeros( + target.shape[:-1], dtype=torch.long, device=target.device + ) + + loc, scale = self._get_loc_scale( + target.double(), observed_mask, sample_id, variate_id + ) + return loc.float(), scale.float() + + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch seq_len #dim"], + Float[torch.Tensor, "*batch seq_len #dim"], + ]: + raise NotImplementedError + + +class PackedNOPScaler(PackedScaler): + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch 1 #dim"], Float[torch.Tensor, "*batch 1 #dim"] + ]: + loc = torch.zeros_like(target, dtype=target.dtype) + scale = torch.ones_like(target, dtype=target.dtype) + return loc, scale + + +class PackedStdScaler(PackedScaler): + def __init__(self, correction: int = 1, minimum_scale: float = 1e-5): + super().__init__() + self.correction = correction + self.minimum_scale = minimum_scale + + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch 1 #dim"], Float[torch.Tensor, "*batch 1 #dim"] + ]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + tobs = reduce( + id_mask * reduce(observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loc = reduce( + id_mask * reduce(target * observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loc = safe_div(loc, tobs) + var = reduce( + id_mask + * reduce( + ((target - loc) ** 2) * observed_mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... 
seq1 1", + "sum", + ) + var = safe_div(var, (tobs - self.correction)) + scale = torch.sqrt(var + self.minimum_scale) + loc[sample_id == 0] = 0 + scale[sample_id == 0] = 1 + return loc, scale + + +class PackedAbsMeanScaler(PackedScaler): + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch 1 #dim"], Float[torch.Tensor, "*batch 1 #dim"] + ]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + tobs = reduce( + id_mask * reduce(observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + scale = reduce( + id_mask + * reduce(target.abs() * observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + scale = safe_div(scale, tobs) + loc = torch.zeros_like(scale) + + loc[sample_id == 0] = 0 + scale[sample_id == 0] = 1 + return loc, scale diff --git a/OATS/models/tsfm/module/position/__init__.py b/OATS/models/tsfm/module/position/__init__.py new file mode 100644 index 0000000..f37f968 --- /dev/null +++ b/OATS/models/tsfm/module/position/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .additive import LearnedEmbedding, SinusoidalPositionEncoding +from .attn_bias import ( + AttentionBias, + BinaryAttentionBias, + LinearAttentionBias, + RelativeAttentionBias, +) +from .attn_projection import ( + IdentityProjection, + LearnedProjection, + Projection, + QueryKeyProjection, + RotaryProjection, +) + +__all__ = [ + "AttentionBias", + "IdentityProjection", + "RelativeAttentionBias", + "BinaryAttentionBias", + "LearnedEmbedding", + "LearnedProjection", + "LinearAttentionBias", + "Projection", + "QueryKeyProjection", + "RotaryProjection", + "SinusoidalPositionEncoding", +] diff --git a/OATS/models/tsfm/module/position/additive.py b/OATS/models/tsfm/module/position/additive.py new file mode 100644 index 0000000..ba1facf --- /dev/null +++ b/OATS/models/tsfm/module/position/additive.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import math + +import torch +from jaxtyping import Float, Int +from torch import nn + + +class SinusoidalPositionEncoding(nn.Module): + def __init__( + self, + *, + width: int, + max_len: int, + normalize: bool = True, + ): + """ + Construct a sinusoidal positional embedding module. + + :param width: + Width of the embedding. + :param max_len: + Maximum length of the embedding. + :param normalize: + Perform L2 normalization of the embedding. 
+ """ + super().__init__() + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, width, 2) * (-math.log(10000.0) / width)) + + pe = torch.zeros(max_len, width) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + if normalize: + l2 = torch.linalg.vector_norm(pe, dim=-1) + pe /= l2.unsqueeze(-1) + + self.register_buffer("pe", pe, persistent=False) + + def forward( + self, pos_id: Int[torch.Tensor, "*batch length"] + ) -> Float[torch.Tensor, "*batch length dim"]: + return self.pe[pos_id] + + +class LearnedEmbedding(nn.Module): + def __init__( + self, + *, + width: int, + max_len: int, + ): + super().__init__() + self.pe = nn.Embedding( + max_len, + width, + ) + + def forward( + self, pos_id: Int[torch.Tensor, "*batch length"] + ) -> Float[torch.Tensor, "*batch length dim"]: + return self.pe(pos_id) diff --git a/OATS/models/tsfm/module/position/attn_bias.py b/OATS/models/tsfm/module/position/attn_bias.py new file mode 100644 index 0000000..773505f --- /dev/null +++ b/OATS/models/tsfm/module/position/attn_bias.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import abc + +import torch +from einops import rearrange +from jaxtyping import Float, Int +from torch import nn + + +class AttentionBias(nn.Module, abc.ABC): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + ): + super().__init__() + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.num_heads = num_heads + self.num_groups = num_groups + self.heads_per_group = num_heads // num_groups + self.head_dim = dim // num_heads + + @abc.abstractmethod + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: ... + + +class RelativeAttentionBias(AttentionBias): + def __init__(self, num_buckets: int, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + self.emb = nn.Embedding( + num_embeddings=num_buckets, embedding_dim=self.num_heads + ) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + raise NotImplementedError + + +class BinaryAttentionBias(AttentionBias): + def __init__(self, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2)) + weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1") + bias = rearrange( # try to avoid advanced indexing + ~ind * weight[:1] + ind * weight[1:], + "... 1 (group hpg) q_len kv_len -> ... 
group hpg q_len kv_len", + group=self.num_groups, + hpg=self.heads_per_group, + ) + return bias + + +class LinearAttentionBias(AttentionBias): + def __init__(self, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + m = 0.5 ** ((1 + torch.arange(self.num_heads)) * (8 / self.num_heads)) + m = rearrange( + m, + "(group hpg) -> group hpg 1 1", + group=self.num_groups, + hpg=self.heads_per_group, + ) + self.register_buffer("m", m) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + ind = kv_id.unsqueeze(-2) - query_id.unsqueeze(-1) + return self.m * ind diff --git a/OATS/models/tsfm/module/position/attn_projection.py b/OATS/models/tsfm/module/position/attn_projection.py new file mode 100644 index 0000000..1ffa118 --- /dev/null +++ b/OATS/models/tsfm/module/position/attn_projection.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import abc +import math +from functools import cached_property +from typing import Any, Optional + +import torch +from einops import einsum, rearrange, repeat +from jaxtyping import Float, Int +from torch import nn + + +class Projection(nn.Module, abc.ABC): + def __init__(self, proj_width: int, num_heads: int, num_groups: int, **kwargs: Any): + super().__init__() + self.proj_width = proj_width + self.num_heads = num_heads + self.num_groups = num_groups + self.heads_per_group = num_heads // num_groups + + @abc.abstractmethod + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: ... + + +class IdentityProjection(Projection): + def __init__(self, *, proj_width: int, num_heads: int, num_groups: int, **kwargs): + super().__init__(proj_width, num_heads, num_groups) + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]] = None, + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + return x + + +class RotaryProjection(Projection): + def __init__( + self, + *, + proj_width: int, + num_heads: int, + num_groups: int, + max_len: int = 512, + base: int = 10000, + ): + super().__init__(proj_width, num_heads, num_groups) + assert ( + self.proj_width % 2 == 0 + ), f"proj_width must be even, got {self.proj_width}" + self.register_buffer( + "theta", + 1.0 + / torch.pow( + base, + torch.arange(0, self.proj_width, 2, dtype=torch.float) + / self.proj_width, + ), + persistent=False, + ) + self.register_buffer("cos", None, persistent=False) + self.register_buffer("sin", None, persistent=False) + self._init_freq(max_len=max_len) + + def _init_freq(self, max_len: int): + if self.cos is None or self.cos.size(-2) < max_len: + position = torch.arange( + max_len, device=self.theta.device, dtype=self.theta.dtype + ) + m_theta = einsum(position, self.theta, "length, width -> length width") + m_theta = repeat(m_theta, "length width -> length (width 2)") + self.register_buffer("cos", torch.cos(m_theta), persistent=False) + self.register_buffer("sin", torch.sin(m_theta), persistent=False) + + @staticmethod + def _rotate(x: Float[torch.Tensor, "... dim"]) -> Float[torch.Tensor, "... dim"]: + x1, x2 = rearrange(x, "... 
(dim r) -> r ... dim", r=2) + return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2) # noqa + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + self._init_freq(max_len=seq_id.max() + 1) + rot_cos = self.cos[seq_id] + rot_sin = self.sin[seq_id] + return rot_cos * x + rot_sin * self._rotate(x) + + +class LearnedProjection(Projection): + def __init__( + self, + *, + proj_width: int, + num_heads: int, + num_groups: int, + max_len: int = 512, + ): + super().__init__(proj_width, num_heads, num_groups) + self.max_len = max_len + self.weight = nn.Parameter( + torch.empty((max_len, self.proj_width, self.proj_width)) + ) + self.reset_parameters() + + def reset_parameters(self): + for idx in range(self.max_len): + nn.init.kaiming_uniform_(self.weight[idx], a=math.sqrt(5)) + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + weight = self.weight[seq_id] + return einsum(weight, x, "... out inp, ... inp -> ... out") + + +class QueryKeyProjection(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + proj_layer: type[Projection], + kwargs: Optional[dict[str, Any]] = None, + key_proj_layer: Optional[type[Projection]] = None, + key_kwargs: Optional[dict[str, Any]] = None, + partial_factor: Optional[tuple[float, float]] = None, + ): + super().__init__() + if partial_factor is not None: + assert ( + 0.0 <= partial_factor[0] < partial_factor[1] <= 1.0 + ), f"got {partial_factor[0]}, {partial_factor[1]}" + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.head_dim = dim // num_heads + self.partial_factor = partial_factor + self.query_proj = proj_layer( + proj_width=self.proj_width, + num_heads=num_heads, + num_groups=num_groups, + **(kwargs or {}), + ) + if key_proj_layer is None: + self.key_proj = self.query_proj + else: + self.key_proj = key_proj_layer( + proj_width=self.proj_width, + num_heads=num_heads, + num_groups=num_groups, + **(key_kwargs or {}), + ) + + @cached_property + def proj_width(self) -> int: + if self.partial_factor is None: + return self.head_dim + return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0])) + + @cached_property + def split_sizes(self) -> tuple[int, int, int]: + if self.partial_factor is None: + return 0, self.head_dim, 0 + return ( + int(self.partial_factor[0] * self.head_dim), + self.proj_width, + int((1.0 - self.partial_factor[1]) * self.head_dim), + ) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ) -> tuple[ + Float[torch.Tensor, "*batch group hpg seq dim"], + Float[torch.Tensor, "*batch group hpg seq dim"], + ]: + if self.partial_factor is not None: + queries = list(query.split(self.split_sizes, dim=-1)) + keys = list(key.split(self.split_sizes, dim=-1)) + queries[1] = self.query_proj(queries[1], seq_id=query_id) + keys[1] = self.key_proj(keys[1], seq_id=kv_id) + query = torch.cat(queries, dim=-1) + key = torch.cat(keys, dim=-1) + else: + query = self.query_proj(query, seq_id=query_id) + key = self.key_proj(key, 
seq_id=kv_id) + return query, key diff --git a/OATS/models/tsfm/module/transformer.py b/OATS/models/tsfm/module/transformer.py new file mode 100644 index 0000000..ff2ad45 --- /dev/null +++ b/OATS/models/tsfm/module/transformer.py @@ -0,0 +1,193 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from collections.abc import Callable +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from jaxtyping import Bool, Float, Int +from torch import nn + +from .attention import MultiHeadAttention +from .ffn import FeedForward, GatedLinearUnitFeedForward +from .position import AttentionBias, QueryKeyProjection + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + self_attn: MultiHeadAttention, + ffn: FeedForward, + norm1: Optional[nn.Module], + norm2: Optional[nn.Module], + post_attn_dropout_p: float = 0.0, + pre_norm: bool = True, + ): + super().__init__() + self.pre_norm = pre_norm + self.dropout_p = post_attn_dropout_p + + self.self_attn = self_attn + self.ffn = ffn + self.norm1 = norm1 or nn.Identity() + self.norm2 = norm2 or nn.Identity() + self.dropout = nn.Dropout(post_attn_dropout_p) + + def forward( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]] = None, + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + if self.pre_norm: + x = x + self._sa_block( + self.norm1(x), attn_mask, var_id=var_id, time_id=time_id + ) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1( + x + self._sa_block(x, attn_mask, var_id=var_id, time_id=time_id) + ) + x = self.norm2(x + self.ffn(x)) + + return x + + def _sa_block( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]], + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + query_var_id=var_id, + kv_var_id=var_id, + query_time_id=time_id, + kv_time_id=time_id, + ) + return self.dropout(x) + + +class TransformerEncoder(nn.Module): + def __init__( + self, + d_model: int, + num_layers: int, + num_heads: Optional[int] = None, + num_groups: Optional[int] = None, + pre_norm: bool = True, + attn_dropout_p: float = 0.0, + dropout_p: float = 0.0, + norm_layer: Optional[Callable[[int], nn.Module]] = nn.LayerNorm, + activation: Callable[[torch.Tensor], torch.Tensor] = F.silu, + use_glu: bool = True, + use_qk_norm: bool = True, + var_attn_bias_layer: Optional[Callable[[int, int, int], AttentionBias]] = None, + time_attn_bias_layer: Optional[Callable[[int, int, int], AttentionBias]] = None, + var_qk_proj_layer: Optional[ + Callable[[int, int, int], QueryKeyProjection] + ] = None, + time_qk_proj_layer: Optional[ + Callable[[int, int, int], QueryKeyProjection] + ] = None, + shared_var_attn_bias: bool = False, + shared_time_attn_bias: bool = False, + shared_var_qk_proj: bool = False, + shared_time_qk_proj: bool = False, + d_ff: Optional[int] = None, + ): + super().__init__() + num_heads = num_heads or d_model // 64 + num_groups = num_groups or num_heads # defaults to mha + + var_attn_bias = self.get_layer( + d_model, + num_heads, + num_groups, + var_attn_bias_layer, + shared_var_attn_bias, + ) + 
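+        # get_layer returns a factory: when the corresponding shared_* flag is
+        # set, a single bias/projection module is built once and reused by every
+        # encoder layer; otherwise each layer constructs its own instance.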
time_attn_bias = self.get_layer( + d_model, + num_heads, + num_groups, + time_attn_bias_layer, + shared_time_attn_bias, + ) + var_qk_proj = self.get_layer( + d_model, num_heads, num_groups, var_qk_proj_layer, shared_var_qk_proj + ) + time_qk_proj = self.get_layer( + d_model, num_heads, num_groups, time_qk_proj_layer, shared_time_qk_proj + ) + + get_self_attn = partial( + MultiHeadAttention, + dim=d_model, + num_heads=num_heads, + bias=False, + norm_layer=norm_layer if use_qk_norm else None, + softmax_scale=None, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) + get_ffn = partial( + GatedLinearUnitFeedForward if use_glu else FeedForward, + in_dim=d_model, + hidden_dim=d_ff, + out_dim=None, + activation=activation, + bias=False, + ffn_dropout_p=dropout_p, + ) + get_encoder_layer_norm = partial(norm_layer, d_model) + + self.layers = nn.ModuleList( + [ + TransformerEncoderLayer( + self_attn=get_self_attn(), + ffn=get_ffn(), + norm1=get_encoder_layer_norm(), + norm2=get_encoder_layer_norm(), + pre_norm=pre_norm, + post_attn_dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) + self.norm = norm_layer(d_model) + + @staticmethod + def get_layer( + dim: int, + num_heads: int, + num_groups: int, + layer: Callable, + shared_layer: bool, + ) -> Optional[Callable[[], nn.Module]]: + if layer is None: + return None + if shared_layer: + module = layer(dim=dim, num_heads=num_heads, num_groups=num_groups) + return lambda: module + return partial(layer, dim=dim, num_heads=num_heads, num_groups=num_groups) + + def forward( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]] = None, + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + for layer in self.layers: + x = layer(x, attn_mask, var_id=var_id, time_id=time_id) + return self.norm(x) diff --git a/OATS/models/tsfm/module/ts_embed.py b/OATS/models/tsfm/module/ts_embed.py new file mode 100644 index 0000000..8d03a6e --- /dev/null +++ b/OATS/models/tsfm/module/ts_embed.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import math +from typing import Optional + +import torch +from einops import einsum, rearrange +from jaxtyping import Float, Int +from torch import nn + +from tsfm.common.torch_util import size_to_mask + + +def fs2idx( + feat_size: Int[torch.Tensor, "*batch"], feat_sizes: Int[torch.Tensor, "num_feats"] +) -> Int[torch.Tensor, "*batch"]: + return ( + (rearrange(feat_size, "... -> ... 
1") == feat_sizes) + .to(torch.long) + .argmax(dim=-1) + ) + + +class MultiInSizeLinear(nn.Module): + def __init__( + self, + in_features_ls: tuple[int, ...], + out_features: int, + bias: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.in_features_ls = in_features_ls + self.out_features = out_features + + self.weight = nn.Parameter( + torch.empty( + (len(in_features_ls), out_features, max(in_features_ls)), dtype=dtype + ) + ) + + if bias: + self.bias = nn.Parameter( + torch.empty((len(in_features_ls), out_features), dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.register_buffer( + "mask", + rearrange( + size_to_mask(max(in_features_ls), torch.as_tensor(in_features_ls)), + "num_feats max_feat -> num_feats 1 max_feat", + ), + persistent=False, + ) + self.register_buffer( + "in_features_buffer", + torch.tensor(in_features_ls), + persistent=False, + ) + + def reset_parameters(self): + for idx, feat_size in enumerate(self.in_features_ls): + nn.init.kaiming_uniform_(self.weight[idx, :, :feat_size], a=math.sqrt(5)) + nn.init.zeros_(self.weight[idx, :, feat_size:]) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[idx, :, :feat_size] + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias[idx], -bound, bound) + + def forward( + self, + x: Float[torch.Tensor, "*batch max_feat"], + in_feat_size: Int[torch.Tensor, "*batch"], + ) -> Float[torch.Tensor, "*batch out_feat"]: + out = 0 + for idx, feat_size in enumerate(self.in_features_ls): + weight = self.weight[idx] * self.mask[idx] + bias = self.bias[idx] if self.bias is not None else 0 + out = out + ( + torch.eq(in_feat_size, feat_size).unsqueeze(-1) + * (einsum(weight, x, "out inp, ... inp -> ... 
out") + bias) + ) + return out + + def extra_repr(self) -> str: + return ( + f"in_features_ls={self.in_features_ls}, " + f"out_features={self.out_features}, " + f"bias={self.bias is not None}, " + f"dtype={self.weight.dtype}" + ) + + +class MultiOutSizeLinear(nn.Module): + def __init__( + self, + in_features: int, + out_features_ls: tuple[int, ...], + dim: int = 1, + bias: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.in_features = in_features + self.out_features_ls = out_features_ls + self.dim = dim + + self.weight = nn.Parameter( + torch.empty( + (len(out_features_ls), max(out_features_ls), in_features), dtype=dtype + ) + ) + + if bias: + self.bias = nn.Parameter( + torch.empty((len(out_features_ls), max(out_features_ls)), dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.register_buffer( + "mask", + rearrange( + size_to_mask(max(out_features_ls), torch.as_tensor(out_features_ls)), + "num_feats max_feat -> num_feats max_feat 1", + ), + persistent=False, + ) + self.register_buffer( + "out_features_buffer", + torch.tensor(out_features_ls), + persistent=False, + ) + + def reset_parameters(self): + for idx, feat_size in enumerate(self.out_features_ls): + nn.init.kaiming_uniform_(self.weight[idx, :feat_size], a=math.sqrt(5)) + nn.init.zeros_(self.weight[idx, feat_size:]) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[idx, :feat_size] + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias[idx, :feat_size], -bound, bound) + nn.init.zeros_(self.bias[idx, feat_size:]) + + def forward( + self, + x: Float[torch.Tensor, "*batch in_feat"], + out_feat_size: Int[torch.Tensor, "*batch"], + ) -> Float[torch.Tensor, "*batch max_feat"]: + out = 0 + for idx, feat_size in enumerate(self.out_features_ls): + weight = self.weight[idx] * self.mask[idx] + bias = self.bias[idx] if self.bias is not None else 0 + out = out + ( + torch.eq(out_feat_size, feat_size // self.dim).unsqueeze(-1) + * (einsum(weight, x, "out inp, ... inp -> ... out") + bias) + ) + return out + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, " + f"out_features_ls={self.out_features_ls}, " + f"bias={self.bias is not None}, " + f"dtype={self.weight.dtype}" + ) diff --git a/OATS/models/tsfm/optim/__init__.py b/OATS/models/tsfm/optim/__init__.py new file mode 100644 index 0000000..99b90bb --- /dev/null +++ b/OATS/models/tsfm/optim/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from .lr_scheduler import SchedulerType, get_scheduler + +__all__ = ["SchedulerType", "get_scheduler"] diff --git a/OATS/models/tsfm/optim/lr_scheduler.py b/OATS/models/tsfm/optim/lr_scheduler.py new file mode 100644 index 0000000..afb56c0 --- /dev/null +++ b/OATS/models/tsfm/optim/lr_scheduler.py @@ -0,0 +1,437 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import math +from enum import Enum +from functools import partial +from typing import Optional + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau + + +def _get_constant_lambda(_=None): + return 1 + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. 
+ + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch) + + +def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs): + """ + Create a schedule with a constant learning rate that decreases when a metric has stopped improving. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + kwargs (`dict`, *optional*): + Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau` + for possible parameters. + + Return: + `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule. + """ + + return ReduceLROnPlateau(optimizer, **kwargs) + + +def _get_constant_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + +def get_constant_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1 +): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps + ) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_linear_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps)), + ) + + +def get_linear_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
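+
+    Example (an illustrative sketch; assumes an existing `optimizer` and one
+    `scheduler.step()` per optimizer update):
+
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=500, num_training_steps=10_000
+        )
+        for _ in range(10_000):
+            optimizer.step()
+            scheduler.step()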
+ """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + return max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + ) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + if progress >= 1.0: + return 0.0 + return max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))) + ) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int = 1, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + lr_lambda = partial( + _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_polynomial_decay_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float, + power: float, + lr_init: int, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, + num_warmup_steps, + num_training_steps, + lr_end=1e-7, + power=1.0, + last_epoch=-1, +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError( + f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})" + ) + + lr_lambda = partial( + _get_polynomial_decay_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_end=lr_end, + power=power, + lr_init=lr_init, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_inverse_sqrt_schedule_lr_lambda( + current_step: int, *, num_warmup_steps: int, timescale: int = None +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + shift = timescale - num_warmup_steps + decay = 1.0 / math.sqrt((current_step + shift) / timescale) + return decay + + +def get_inverse_sqrt_schedule( + optimizer: Optimizer, + num_warmup_steps: int, + timescale: int = None, + last_epoch: int = -1, +): + """ + Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a + warmup period which increases lr linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + timescale (`int`, *optional*, defaults to `num_warmup_steps`): + Time scale. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. 
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    if timescale is None:
+        timescale = num_warmup_steps
+
+    lr_lambda = partial(
+        _get_inverse_sqrt_schedule_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        timescale=timescale,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+class SchedulerType(Enum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
+}
+
+
+def get_scheduler(
+    name: str | SchedulerType,
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    scheduler_specific_kwargs: Optional[dict] = None,
+):
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        scheduler_specific_kwargs (`dict`, *optional*):
+            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+            parameters will cause the scheduler function to raise a TypeError.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer)
+
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(
+            f"{name} requires `num_warmup_steps`, please provide that argument."
+        )
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(
+            f"{name} requires `num_training_steps`, please provide that argument."
+ ) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) diff --git a/OATS/models/tsfm/transform/__init__.py b/OATS/models/tsfm/transform/__init__.py new file mode 100644 index 0000000..2479f55 --- /dev/null +++ b/OATS/models/tsfm/transform/__init__.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified by Microsoft Corporation. +# Licensed under the MIT license. + + +from ._base import Chain, Identity, Transformation +from .crop import EvalCrop, PatchCrop, EvalCrop_AdaLength +from .feature import AddObservedMask, AddTimeIndex, AddVariateIndex +from .field import LambdaSetFieldIfNotPresent, RemoveFields, SelectFields, SetValue +from .imputation import DummyValueImputation, ImputeTimeSeries, LastValueImputation +from .pad import EvalPad, Pad, PadFreq, EvalPad_AdaLength +from .patch import ( + DefaultPatchSizeConstraints, + FixedPatchSizeConstraints, + GetPatchSize, + Patchify, + PatchSizeConstraints, +) +from .resample import SampleDimension +from .reshape import ( + FlatPackCollection, + FlatPackFields, + PackCollection, + PackFields, + SequencifyField, + Transpose, +) +from .task import EvalMaskedPrediction, ExtendMask, MaskedPrediction, NextTokenPrediction, EvalNextTokenPrediction + +__all__ = [ + "AddObservedMask", + "AddTimeIndex", + "AddVariateIndex", + "Chain", + "DefaultPatchSizeConstraints", + "DummyValueImputation", + "EvalCrop", + "EvalMaskedPrediction", + "EvalPad", + "ExtendMask", + "FixedPatchSizeConstraints", + "FlatPackCollection", + "FlatPackFields", + "GetPatchSize", + "Identity", + "ImputeTimeSeries", + "LambdaSetFieldIfNotPresent", + "LastValueImputation", + "MaskedPrediction", + "NextTokenPrediction", + "PackCollection", + "PackFields", + "Pad", + "PadFreq", + "PatchCrop", + "PatchSizeConstraints", + "Patchify", + "RemoveFields", + "SampleDimension", + "SelectFields", + "SequencifyField", + "SetValue", + "Transformation", + "Transpose", + "EvalPad_AdaLength", + "EvalCrop_AdaLength", + "EvalNextTokenPrediction" +] diff --git a/OATS/models/tsfm/transform/_base.py b/OATS/models/tsfm/transform/_base.py new file mode 100644 index 0000000..e1e9694 --- /dev/null +++ b/OATS/models/tsfm/transform/_base.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc +from dataclasses import dataclass +from typing import Any + + +class Transformation(abc.ABC): + @abc.abstractmethod + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: ... + + def chain(self, other: "Transformation") -> "Chain": + return Chain([self, other]) + + def __add__(self, other: "Transformation") -> "Chain": + return self.chain(other) + + def __radd__(self, other): + if other == 0: + return self + return other + self + + +@dataclass +class Chain(Transformation): + """ + Chain multiple transformations together. 
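+
+    `Identity` transformations are dropped and nested `Chain`s are flattened at
+    construction time, so `Chain([a, Chain([b, c])])` behaves like
+    `Chain([a, b, c])`; see `__post_init__` below.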
+ """ + + transformations: list[Transformation] + + def __post_init__(self) -> None: + transformations = [] + + for transformation in self.transformations: + if isinstance(transformation, Identity): + continue + elif isinstance(transformation, Chain): + transformations.extend(transformation.transformations) + else: + assert isinstance(transformation, Transformation) + transformations.append(transformation) + + self.transformations = transformations + self.__init_passed_kwargs__ = {"transformations": transformations} + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + for t in self.transformations: + data_entry = t(data_entry) + return data_entry + + +class Identity(Transformation): + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + return data_entry diff --git a/OATS/models/tsfm/transform/_mixin.py b/OATS/models/tsfm/transform/_mixin.py new file mode 100644 index 0000000..930d938 --- /dev/null +++ b/OATS/models/tsfm/transform/_mixin.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections.abc import Callable +from typing import Any + +import numpy as np + + +class MapFuncMixin: + @staticmethod + def map_func( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ): + for field in fields: + data_entry[field] = func(data_entry, field) + for field in optional_fields: + if field in data_entry: + data_entry[field] = func(data_entry, field) + + +class ApplyFuncMixin: + @staticmethod + def apply_func( + func: Callable[[dict[str, Any], str], None], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ): + for field in fields: + func(data_entry, field) + for field in optional_fields: + if field in data_entry: + func(data_entry, field) + + +class CollectFuncMixin: + @staticmethod + def collect_func_list( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ) -> list[Any]: + collect = [] + for field in fields: + collect.append(func(data_entry, field)) + for field in optional_fields: + if field in data_entry: + collect.append(func(data_entry, field)) + return collect + + @staticmethod + def collect_func_dict( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ) -> dict[str, Any]: + collect = {} + for field in fields: + collect[field] = func(data_entry, field) + for field in optional_fields: + if field in data_entry: + collect[field] = func(data_entry, field) + return collect + + def collect_func( + self, + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ) -> list[Any] | dict[str, Any]: + if not hasattr(self, "collection_type"): + raise NotImplementedError( + f"{self.__class__.__name__} has no attribute 'collection_type', " + "please use collect_func_list or collect_func_dict instead." 
+ ) + + collection_type = getattr(self, "collection_type") + if collection_type == list: + collect_func = self.collect_func_list + elif collection_type == dict: + collect_func = self.collect_func_dict + else: + raise ValueError(f"Unknown collection_type: {collection_type}") + + return collect_func( + func, + data_entry, + fields, + optional_fields=optional_fields, + ) + + +class CheckArrNDimMixin: + def check_ndim(self, name: str, arr: np.ndarray, expected_ndim: int): + if isinstance(arr, list): + self.check_ndim(name, arr[0], expected_ndim - 1) + return + + if arr.ndim != expected_ndim: + raise AssertionError( + f"Array '{name}' for {self.__class__.__name__} " + f"has expected ndim: {expected_ndim}, " + f"but got ndim: {arr.ndim} of shape {arr.shape}." + ) diff --git a/OATS/models/tsfm/transform/crop.py b/OATS/models/tsfm/transform/crop.py new file mode 100644 index 0000000..ac16ea9 --- /dev/null +++ b/OATS/models/tsfm/transform/crop.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections.abc import Sequence +from dataclasses import dataclass +from functools import partial +from typing import Any + +import numpy as np + +from tsfm.common.typing import UnivarTimeSeries + +from ._base import Transformation +from ._mixin import MapFuncMixin + + +@dataclass +class PatchCrop(MapFuncMixin, Transformation): + """ + Crop fields in a data_entry in the temporal dimension based on a patch_size. + :param rng: numpy random number generator + :param min_time_patches: minimum number of patches for time dimension + :param max_patches: maximum number of patches for time * dim dimension (if flatten) + :param will_flatten: whether time series fields will be flattened subsequently + :param offset: whether to offset the start of the crop + :param fields: fields to crop + """ + + min_time_patches: int + max_patches: int + will_flatten: bool = False + offset: bool = True + fields: tuple[str, ...] = ("target",) + optional_fields: tuple[str, ...] = ("past_feat_dynamic_real",) + + def __post_init__(self): + assert ( + self.min_time_patches <= self.max_patches + ), "min_patches must be <= max_patches" + assert len(self.fields) > 0, "fields must be non-empty" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + a, b = self._get_boundaries(data_entry) + self.map_func( + partial(self._crop, a=a, b=b), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + @staticmethod + def _crop(data_entry: dict[str, Any], field: str, a: int, b: int) -> Sequence: + return [ts[a:b] for ts in data_entry[field]] + + def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: + patch_size = data_entry["patch_size"] + field: list[UnivarTimeSeries] = data_entry[self.fields[0]] + time = field[0].shape[0] + nvar = ( + sum(len(data_entry[f]) for f in self.fields) + + sum(len(data_entry[f]) for f in self.optional_fields if f in data_entry) + if self.will_flatten + else 1 + ) + + offset = ( + np.random.randint( + time % patch_size + 1 + ) # offset by [0, patch_size) so that the start is not always a multiple of patch_size + if self.offset + else 0 + ) + total_patches = ( + time - offset + ) // patch_size # total number of patches in time series + + # 1. max_patches should be divided by nvar if the time series is subsequently flattened + # 2. 
cannot have more patches than total available patches + max_patches = min(self.max_patches // nvar, total_patches) + if max_patches < self.min_time_patches: + raise ValueError( + f"max_patches={max_patches} < min_time_patches={self.min_time_patches}" + ) + + min_time_patches = round(max_patches*0.5) + num_patches = np.random.randint( + min_time_patches, max_patches + 1 + ) # number of patches to consider + first = np.random.randint( + total_patches - num_patches + 1 + ) # first patch to consider + + start = offset + first * patch_size + stop = start + num_patches * patch_size + return start, stop + + +@dataclass +class EvalCrop(MapFuncMixin, Transformation): + offset: int + distance: int + prediction_length: int + context_length: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + a, b = self._get_boundaries(data_entry) + self.map_func( + partial(self._crop, a=a, b=b), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + @staticmethod + def _crop(data_entry: dict[str, Any], field: str, a: int, b: int) -> Sequence: + return [ts[a : b or None] for ts in data_entry[field]] + + def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: + field: list[UnivarTimeSeries] = data_entry[self.fields[0]] + time = field[0].shape[0] + window = data_entry["window"] + fcst_start = self.offset + window * self.distance + a = fcst_start - self.context_length + b = fcst_start + self.prediction_length + + if self.offset >= 0: + assert time >= b > a >= 0 + else: + assert 0 >= b > a >= -time + + return a, b + + +@dataclass +class EvalCrop_AdaLength(MapFuncMixin, Transformation): + offset: int + distance: int + prediction_length: int + context_length: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + a, b = self._get_boundaries(data_entry) + self.map_func( + partial(self._crop, a=a, b=b), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + @staticmethod + def _crop(data_entry: dict[str, Any], field: str, a: int, b: int) -> Sequence: + return [ts[a : b or None] for ts in data_entry[field]] + + def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: + field: list[UnivarTimeSeries] = data_entry[self.fields[0]] + time = field[0].shape[0] + + if self.context_length == -7777: + if time <= 1000 + self.prediction_length: + a, b = 0, time + else: + a = 0 + b = a + 1000 + self.prediction_length + return a, b + + window = data_entry["window"] + fcst_start = self.offset + window * self.distance + a = fcst_start - self.context_length + b = fcst_start + self.prediction_length + + if self.offset >= 0: + assert time >= b > a >= 0, f"{time}, {b}, {a}" + else: + assert 0 >= b > a >= -time + + return a, b diff --git a/OATS/models/tsfm/transform/feature.py b/OATS/models/tsfm/transform/feature.py new file mode 100644 index 0000000..f270daa --- /dev/null +++ b/OATS/models/tsfm/transform/feature.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
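+"""Feature transformations: attach variate ids, time ids, and observed masks to a data entry."""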
+ +from dataclasses import dataclass +from typing import Any + +import numpy as np +from einops import repeat + +from ._base import Transformation +from ._mixin import CheckArrNDimMixin, CollectFuncMixin + + +@dataclass +class AddVariateIndex(CollectFuncMixin, CheckArrNDimMixin, Transformation): + """ + Add variate_id to data_entry + """ + + fields: tuple[str, ...] + max_dim: int + optional_fields: tuple[str, ...] = tuple() + variate_id_field: str = "variate_id" + expected_ndim: int = 2 + randomize: bool = False + collection_type: type = list + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.counter = 0 + self.dimensions = ( + np.random.choice(self.max_dim, size=self.max_dim, replace=False) + if self.randomize + else list(range(self.max_dim)) + ) + data_entry[self.variate_id_field] = self.collect_func( + self._generate_variate_id, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _generate_variate_id( + self, data_entry: dict[str, Any], field: str + ) -> np.ndarray: + arr = data_entry[field] + self.check_ndim(field, arr, self.expected_ndim) + dim, time = arr.shape[:2] + if self.counter + dim > self.max_dim: + raise ValueError( + f"Variate ({self.counter + dim}) exceeds maximum variate {self.max_dim}. " + ) + field_dim_id = repeat( + np.asarray(self.dimensions[self.counter : self.counter + dim], dtype=int), + "var -> var time", + time=time, + ) + self.counter += dim + return field_dim_id + + +@dataclass +class AddTimeIndex(CollectFuncMixin, CheckArrNDimMixin, Transformation): + """ + Add time_id to data_entry + """ + + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + time_id_field: str = "time_id" + expected_ndim: int = 2 + collection_type: type = list + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + """ + add sequence_id + """ + data_entry[self.time_id_field] = self.collect_func( + self._generate_time_id, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _generate_time_id(self, data_entry: dict[str, Any], field: str) -> np.ndarray: + arr = data_entry[field] + self.check_ndim(field, arr, self.expected_ndim) + var, time = arr.shape[:2] + field_seq_id = np.arange(time) + field_seq_id = repeat(field_seq_id, "time -> var time", var=var) + return field_seq_id + + +@dataclass +class AddObservedMask(CollectFuncMixin, Transformation): + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + observed_mask_field: str = "observed_mask" + collection_type: type = list + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + observed_mask = self.collect_func( + self._generate_observed_mask, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + data_entry[self.observed_mask_field] = observed_mask + return data_entry + + @staticmethod + def _generate_observed_mask(data_entry: dict[str, Any], field: str) -> np.ndarray: + arr = data_entry[field] + return ~np.isnan(arr) diff --git a/OATS/models/tsfm/transform/field.py b/OATS/models/tsfm/transform/field.py new file mode 100644 index 0000000..a388f75 --- /dev/null +++ b/OATS/models/tsfm/transform/field.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
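+"""Field-level transformations: set, select, and remove fields of a data entry."""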
+ +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from ._base import Transformation + + +@dataclass +class SetValue: + value: Any + + def __call__(self, data_entry: dict[str, Any]) -> Any: + return self.value + + +@dataclass +class LambdaSetFieldIfNotPresent(Transformation): + field: str + get_value: Callable[[dict[str, Any]], Any] + + @staticmethod + def set_field(data_entry: dict[str, Any], field: str, value: Any) -> dict[str, Any]: + if field not in data_entry.keys(): + data_entry[field] = value + return data_entry + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + return self.set_field(data_entry, self.field, self.get_value(data_entry)) + + +@dataclass +class SelectFields(Transformation): + fields: list[str] + allow_missing: bool = False + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + if self.allow_missing: + return {f: data_entry[f] for f in self.fields if f in data_entry} + return {f: data_entry[f] for f in self.fields} + + +@dataclass +class RemoveFields(Transformation): + fields: list[str] + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + for k in self.fields: + data_entry.pop(k, None) + return data_entry diff --git a/OATS/models/tsfm/transform/imputation.py b/OATS/models/tsfm/transform/imputation.py new file mode 100644 index 0000000..61bc612 --- /dev/null +++ b/OATS/models/tsfm/transform/imputation.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Any + +import numpy as np +from jaxtyping import Num + +from ._base import Transformation +from ._mixin import ApplyFuncMixin + + +class ImputationMethod: + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: ... + + +@dataclass(frozen=True) +class DummyValueImputation(ImputationMethod): + value: int | float | complex = 0.0 + + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: + x[np.isnan(x)] = self.value + return x + + +@dataclass(frozen=True) +class LastValueImputation(ImputationMethod): + value: int | float | complex = 0.0 + + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: + x = x.T + x[0:1][np.isnan(x[0:1])] = self.value + mask = np.isnan(x) + idx = np.arange(len(x)) + if x.ndim == 2: + idx = np.expand_dims(idx, axis=1) + idx = np.where(~mask, idx, 0) + idx = np.maximum.accumulate(idx, axis=0) + if x.ndim == 2: + x = x[idx, np.arange(x.shape[1])] + else: + x = x[idx] + return x.T + + +class CausalMeanImputation(ImputationMethod): + # TODO: implement causal mean imputation + def __call__( + self, x: Num[np.ndarray, "length *dim"], value: int | float | complex = 0.0 + ) -> Num[np.ndarray, "length *dim"]: ... + + +@dataclass +class ImputeTimeSeries(ApplyFuncMixin, Transformation): + fields: tuple[str, ...] + optional_fields: tuple[str, ...] 
= tuple() + imputation_method: ImputationMethod = DummyValueImputation(value=0.0) + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.apply_func( + self._impute, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _impute(self, data_entry: dict[str, Any], field: str): + value = data_entry[field] + nan_entries = np.isnan(value) + if nan_entries.any(): + data_entry[field] = self.imputation_method(value) diff --git a/OATS/models/tsfm/transform/pad.py b/OATS/models/tsfm/transform/pad.py new file mode 100644 index 0000000..f9018a5 --- /dev/null +++ b/OATS/models/tsfm/transform/pad.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from ._base import Transformation +from ._mixin import MapFuncMixin + + +@dataclass +class Pad(MapFuncMixin, Transformation): + min_length: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.map_func( + self.map, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def map(self, data_entry: dict[str, Any], field: str) -> Any: + arr = data_entry[field] + length = arr.shape[-1] + if length < self.min_length: + pad_amount = self.min_length - length + front_pad = np.random.randint(0, pad_amount + 1) + back_pad = pad_amount - front_pad + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (front_pad, back_pad) + arr = np.pad(arr, pad_width, mode="constant", constant_values=np.nan) + return arr + + +@dataclass +class PadFreq(MapFuncMixin, Transformation): + freq_min_length_map: dict[str, int] + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + freq_field: str = "freq" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.map_func( + self.map, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def map(self, data_entry: dict[str, Any], field: str) -> Any: + arr = data_entry[field] + length = arr.shape[-1] + min_length = self.freq_min_length_map[data_entry[self.freq_field]] + if length < min_length: + pad_amount = min_length - length + front_pad = np.random.randint(0, pad_amount + 1) + back_pad = pad_amount - front_pad + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (front_pad, back_pad) + arr = np.pad(arr, pad_width, mode="constant", constant_values=np.nan) + return arr + + +@dataclass +class EvalPad(MapFuncMixin, Transformation): + prediction_pad: int + context_pad: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.map_func( + self.map, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def map(self, data_entry: dict[str, Any], field: str) -> Any: + arr = data_entry[field] + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (self.context_pad, self.prediction_pad) + arr = np.pad(arr, pad_width, mode="constant", constant_values=np.nan) + return arr + + +@dataclass +class EvalPad_AdaLength(MapFuncMixin, Transformation): + prediction_length: int + context_length: int + patch_size: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] 
= tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.map_func( + self.map, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def map(self, data_entry: dict[str, Any], field: str) -> Any: + if self.context_length == -7777: + context_length = data_entry["target"].shape[-1] - self.prediction_length + else: + context_length = self.context_length + prediction_length = self.prediction_length + + context_pad = (-context_length) % self.patch_size + prediction_pad = (-prediction_length) % self.patch_size + + arr = data_entry[field] + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (context_pad, prediction_pad) + arr = np.pad(arr, pad_width, mode="constant", constant_values=np.nan) + return arr diff --git a/OATS/models/tsfm/transform/patch.py b/OATS/models/tsfm/transform/patch.py new file mode 100644 index 0000000..e4dffe4 --- /dev/null +++ b/OATS/models/tsfm/transform/patch.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import abc +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional + +import numpy as np +import pandas as pd +from einops import rearrange +from gluonts.time_feature import norm_freq_str +from jaxtyping import Num + +from tsfm.common.typing import UnivarTimeSeries + +from ._base import Transformation +from ._mixin import MapFuncMixin + + +class PatchSizeConstraints(abc.ABC): + @abc.abstractmethod + def _get_boundaries(self, n: int, offset_name: str) -> tuple[int, int]: ... + + def __call__(self, freq: str) -> range: + offset = pd.tseries.frequencies.to_offset(freq) + start, stop = self._get_boundaries(offset.n, norm_freq_str(offset.name)) + return range(start, stop + 1) + + +@dataclass +class FixedPatchSizeConstraints(PatchSizeConstraints): + start: int + stop: Optional[int] = None + + def __post_init__(self): + if self.stop is None: + self.stop = self.start + assert self.start <= self.stop + + def _get_boundaries(self, n: int, offset_name: str) -> tuple[int, int]: + return self.start, self.stop + + +class DefaultPatchSizeConstraints(PatchSizeConstraints): + # https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases + DEFAULT_RANGES = { + "S": (64, 128), # 512s = 8.53min, 4096s = 68.26min + "T": (32, 128), # 64min = 1.07h, 512min = 8.53h + "H": (32, 64), # 128h = 5.33days + "D": (16, 32), + "B": (16, 32), + "W": (16, 32), + "M": (8, 32), + "Q": (1, 8), + "Y": (1, 8), + "A": (1, 8), + } + + def _get_boundaries(self, n: int, offset_name: str) -> tuple[int, int]: + start, stop = self.DEFAULT_RANGES[offset_name] + return start, stop + + +@dataclass +class GetPatchSize(Transformation): + min_time_patches: int + target_field: str = "target" + patch_size: tuple[int, ...] | list[int] | range = (8, 16, 32, 64, 128) + patch_size_constraints: PatchSizeConstraints = DefaultPatchSizeConstraints() + offset: bool = True + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + data_entry["patch_size"] = np.int64(32) + return data_entry + + +@dataclass +class Patchify(MapFuncMixin, Transformation): + max_patch_size: int + fields: tuple[str, ...] = ("target",) + optional_fields: tuple[str, ...] 
= ("past_feat_dynamic_real",) + pad_value: int | float = 0 + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + patch_size = data_entry["patch_size"] + self.map_func( + partial(self._patchify, patch_size=patch_size), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _patchify(self, data_entry: dict[str, Any], field: str, patch_size: int): + arr = data_entry[field] + if isinstance(arr, list): + return [self._patchify_arr(a, patch_size) for a in arr] + if isinstance(arr, dict): + for k, v in arr.items(): + if k in self.fields or k in self.optional_fields: + arr[k] = self._patchify_arr(v, patch_size) + return arr + return self._patchify_arr(arr, patch_size) + + def _patchify_arr( + self, arr: Num[np.ndarray, "var time*patch"], patch_size: int + ) -> Num[np.ndarray, "var time max_patch"]: + assert arr.shape[-1] % patch_size == 0 + arr = rearrange(arr, "... (time patch) -> ... time patch", patch=patch_size) + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (0, self.max_patch_size - patch_size) + arr = np.pad(arr, pad_width, mode="constant", constant_values=self.pad_value) + return arr diff --git a/OATS/models/tsfm/transform/resample.py b/OATS/models/tsfm/transform/resample.py new file mode 100644 index 0000000..7d92ee4 --- /dev/null +++ b/OATS/models/tsfm/transform/resample.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from dataclasses import dataclass +from functools import partial +from typing import Any + +import numpy as np + +from tsfm.common.sampler import Sampler, get_sampler +from tsfm.common.typing import UnivarTimeSeries + +from ._base import Transformation +from ._mixin import CheckArrNDimMixin, CollectFuncMixin, MapFuncMixin + + +@dataclass +class SampleDimension( + CheckArrNDimMixin, CollectFuncMixin, MapFuncMixin, Transformation +): + max_dim: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + sampler: Sampler = get_sampler("uniform") + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + total_field_dim = sum( + self.collect_func_list( + self._get_dim, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + ) + self.map_func( + partial(self._process, total_field_dim=total_field_dim), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _get_dim(self, data_entry: dict[str, Any], field: str) -> int: + self.check_ndim(field, data_entry[field], 2) + return len(data_entry[field]) + + def _process( + self, data_entry: dict[str, Any], field: str, total_field_dim: int + ) -> list[UnivarTimeSeries]: + arr: list[UnivarTimeSeries] = data_entry[field] + rand_idx = np.random.permutation(len(arr)) + field_max_dim = (self.max_dim * len(arr)) // total_field_dim + n = self.sampler(min(len(arr), field_max_dim)) + return [arr[idx] for idx in rand_idx[:n]] + + +@dataclass +class Subsample(Transformation): # just take every n-th element + fields: tuple[str, ...] 
= ("target", "past_feat_dynamic_real") + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + pass + + +class GaussianFilterSubsample( + Subsample +): # blur using gaussian filter before subsampling + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + # gaussian filter + return super()(data_entry) + + +class Downsample(Transformation): # aggregate + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + pass + + +class Upsample(Transformation): + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + pass diff --git a/OATS/models/tsfm/transform/reshape.py b/OATS/models/tsfm/transform/reshape.py new file mode 100644 index 0000000..cb55219 --- /dev/null +++ b/OATS/models/tsfm/transform/reshape.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +from einops import pack + +from ._base import Transformation +from ._mixin import CollectFuncMixin, MapFuncMixin + + +@dataclass +class SequencifyField(Transformation): + field: str + axis: int = 0 + target_field: str = "target" + target_axis: int = 0 + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + data_entry[self.field] = data_entry[self.field].repeat( + data_entry[self.target_field].shape[self.target_axis], axis=self.axis + ) + return data_entry + + +@dataclass +class PackFields(CollectFuncMixin, Transformation): + output_field: str + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + feat: bool = False + + def __post_init__(self): + self.pack_str: str = "* time feat" if self.feat else "* time" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + fields = self.collect_func_list( + self.pop_field, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + if len(fields) > 0: + output_field = pack(fields, self.pack_str)[0] + data_entry |= {self.output_field: output_field} + return data_entry + + @staticmethod + def pop_field(data_entry: dict[str, Any], field: str) -> Any: + return np.asarray(data_entry.pop(field)) + + +@dataclass +class FlatPackFields(CollectFuncMixin, Transformation): + output_field: str + fields: tuple[str, ...] + optional_fields: tuple[str, ...] 
= tuple() + feat: bool = False + + def __post_init__(self): + self.pack_str: str = "* feat" if self.feat else "*" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + fields = self.collect_func_list( + self.pop_field, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + if len(fields) > 0: + output_field = pack(fields, self.pack_str)[0] + data_entry |= {self.output_field: output_field} + return data_entry + + @staticmethod + def pop_field(data_entry: dict[str, Any], field: str) -> Any: + return np.asarray(data_entry.pop(field)) + + +@dataclass +class PackCollection(Transformation): + field: str + feat: bool = False + + def __post_init__(self): + self.pack_str: str = "* time feat" if self.feat else "* time" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + collection = data_entry[self.field] + if isinstance(collection, dict): + collection = list(collection.values()) + data_entry[self.field] = pack(collection, self.pack_str)[0] + return data_entry + + +@dataclass +class FlatPackCollection(Transformation): + field: str + feat: bool = False + + def __post_init__(self): + self.pack_str: str = "* feat" if self.feat else "*" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + collection = data_entry[self.field] + if isinstance(collection, dict): + collection = list(collection.values()) + data_entry[self.field] = pack(collection, self.pack_str)[0] + return data_entry + + +@dataclass +class Transpose(MapFuncMixin, Transformation): + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + axes: Optional[tuple[int, ...]] = None + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.map_func( + self.transpose, + data_entry, + fields=self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def transpose(self, data_entry: dict[str, Any], field: str) -> Any: + out = data_entry[field].transpose(self.axes) + return out diff --git a/OATS/models/tsfm/transform/task.py b/OATS/models/tsfm/transform/task.py new file mode 100644 index 0000000..c3635c7 --- /dev/null +++ b/OATS/models/tsfm/transform/task.py @@ -0,0 +1,264 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from dataclasses import dataclass +from functools import partial +from typing import Any + +import numpy as np +from jaxtyping import Bool, Float + +from tsfm.transform._base import Chain + +from ._base import Transformation +from ._mixin import CheckArrNDimMixin, CollectFuncMixin, MapFuncMixin + + +@dataclass +class MaskedPrediction(MapFuncMixin, CheckArrNDimMixin, Transformation): + min_mask_ratio: float + max_mask_ratio: float + target_field: str = "target" + truncate_fields: tuple[str, ...] = tuple() + optional_truncate_fields: tuple[str, ...] 
= tuple() + prediction_mask_field: str = "prediction_mask" + expected_ndim: int = 2 + + def __post_init__(self): + assert ( + self.min_mask_ratio <= self.max_mask_ratio + ), "min_mask_ratio must be <= max_mask_ratio" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + target = data_entry[self.target_field] + data_entry['label'] = target + data_entry['label_observed_mask'] = data_entry['observed_mask'] + prediction_mask = self._generate_prediction_mask(target) + self.map_func( + partial(self._truncate, mask=prediction_mask), # noqa + data_entry, + self.truncate_fields, + optional_fields=self.optional_truncate_fields, + ) + data_entry[self.prediction_mask_field] = prediction_mask + return data_entry + + def _generate_prediction_mask( + self, target: Float[np.ndarray, "var time *feat"] + ) -> Bool[np.ndarray, "var time"]: + self.check_ndim("target", target, self.expected_ndim) + var, time = target.shape[:2] + prediction_mask = np.zeros((var, time), dtype=bool) + mask_ratio = np.random.uniform(self.min_mask_ratio, self.max_mask_ratio) + mask_length = max(1, round(time * mask_ratio)) + prediction_mask[:, -mask_length:] = True + return prediction_mask + + def _truncate( + self, + data_entry: dict[str, Any], + field: str, + mask: np.ndarray, + ) -> np.ndarray | list[np.ndarray] | dict[str, np.ndarray]: + arr: np.ndarray | list[np.ndarray] | dict[str, np.ndarray] = data_entry[field] + if isinstance(arr, list): + return [self._truncate_arr(a, mask) for a in arr] + if isinstance(arr, dict): + for k, v in arr.items(): + if k in self.truncate_fields or k in self.optional_truncate_fields: + arr[k] = self._truncate_arr(v, mask) + return arr + return self._truncate_arr(arr, mask) + + @staticmethod + def _truncate_arr( + arr: Float[np.ndarray, "var time *feat"], mask: Bool[np.ndarray, "var time"] + ) -> Float[np.ndarray, "var time-mask_len *feat"]: + return arr[:, ~mask[0]] + + +@dataclass +class ExtendMask(CheckArrNDimMixin, CollectFuncMixin, Transformation): + fields: tuple[str, ...] + mask_field: str + optional_fields: tuple[str, ...] = tuple() + expected_ndim: int = 2 + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + target_mask: np.ndarray = data_entry[self.mask_field] + aux_target_mask: list[np.ndarray] = self.collect_func_list( + self._generate_target_mask, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + data_entry[self.mask_field] = [target_mask] + aux_target_mask + return data_entry + + def _generate_target_mask( + self, data_entry: dict[str, Any], field: str + ) -> np.ndarray: + arr: np.ndarray = data_entry[field] + self.check_ndim(field, arr, self.expected_ndim) + var, time = arr.shape[:2] + field_target_mask = np.zeros((var, time), dtype=bool) + return field_target_mask + + +@dataclass +class EvalMaskedPrediction(MapFuncMixin, CheckArrNDimMixin, Transformation): + mask_length: int + target_field: str = "target" + truncate_fields: tuple[str, ...] = tuple() + optional_truncate_fields: tuple[str, ...] 
= tuple() + prediction_mask_field: str = "prediction_mask" + expected_ndim: int = 2 + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + target = data_entry[self.target_field] + data_entry['label'] = target + data_entry['label_observed_mask'] = data_entry['observed_mask'] + prediction_mask = self._generate_prediction_mask(target) + self.map_func( + partial(self._truncate, mask=prediction_mask), # noqa + data_entry, + self.truncate_fields, + optional_fields=self.optional_truncate_fields, + ) + data_entry[self.prediction_mask_field] = prediction_mask + return data_entry + + def _generate_prediction_mask( + self, target: Float[np.ndarray, "var time *feat"] + ) -> Bool[np.ndarray, "var time"]: + self.check_ndim("target", target, self.expected_ndim) + var, time = target.shape[:2] + prediction_mask = np.zeros((var, time), dtype=bool) + prediction_mask[:, -self.mask_length :] = True + return prediction_mask + + def _truncate( + self, + data_entry: dict[str, Any], + field: str, + mask: np.ndarray, + ) -> np.ndarray | list[np.ndarray] | dict[str, np.ndarray]: + arr: np.ndarray | list[np.ndarray] | dict[str, np.ndarray] = data_entry[field] + if isinstance(arr, list): + return [self._truncate_arr(a, mask) for a in arr] + if isinstance(arr, dict): + for k, v in arr.items(): + if k in self.truncate_fields or k in self.optional_truncate_fields: + arr[k] = self._truncate_arr(v, mask) + return arr + return self._truncate_arr(arr, mask) + + @staticmethod + def _truncate_arr( + arr: Float[np.ndarray, "var time *feat"], mask: Bool[np.ndarray, "var time"] + ) -> Float[np.ndarray, "var time-mask_len *feat"]: + return arr[:, ~mask[0]] + + +@dataclass +class NextTokenPrediction(CollectFuncMixin, MapFuncMixin, CheckArrNDimMixin, Transformation): + min_mask_ratio: float + max_mask_ratio: float + target_field: str = "target" + prediction_mask_field: str = "prediction_mask" + expected_ndim: int = 2 + + def __post_init__(self): + assert ( + self.min_mask_ratio <= self.max_mask_ratio + ), "min_mask_ratio must be <= max_mask_ratio" + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + # label and label_observed_mask + label = self._wrap(data_entry, self.target_field, lambda x: x[:, 1:]) + label_observed_mask = self._wrap(data_entry, 'observed_mask', lambda x: x[:, 1:]) + data_entry['label'] = label + data_entry['label_observed_mask'] = label_observed_mask + + # target and observed_mask + target = self._wrap(data_entry, self.target_field, lambda x: x[:, :-1]) + observed_mask = self._wrap(data_entry, 'observed_mask', lambda x: x[:, :-1]) + data_entry['target'] = target + data_entry['observed_mask'] = observed_mask + + # generate the prediction mask + target = data_entry[self.target_field] + prediction_mask = self._generate_prediction_mask(target) + data_entry[self.prediction_mask_field] = prediction_mask + + return data_entry + + def _wrap( + self, + data_entry: dict[str, Any], + field: str, + func, + ) -> np.ndarray | list[np.ndarray] | dict[str, np.ndarray]: + arr: np.ndarray | list[np.ndarray] | dict[str, np.ndarray] = data_entry[field].copy() + if isinstance(arr, list): + return [func(a) for a in arr] + if isinstance(arr, dict): + for k, v in arr.items(): + arr[k] = func(v) + return arr + return func(arr) + + def _generate_prediction_mask( + self, target: Float[np.ndarray, "var time *feat"] + ) -> Bool[np.ndarray, "var time"]: + self.check_ndim("target", target, self.expected_ndim) + var, time = target.shape[:2] + prediction_mask = np.zeros((var, time), dtype=bool) + mask_ratio = 
np.random.uniform(self.min_mask_ratio, self.max_mask_ratio) + mask_length = max(1, round(time * mask_ratio)) + prediction_mask[:, -mask_length:] = True + return prediction_mask + + +@dataclass +class EvalNextTokenPrediction(CollectFuncMixin, MapFuncMixin, CheckArrNDimMixin, Transformation): + mask_length: int + target_field: str = "target" + prediction_mask_field: str = "prediction_mask" + expected_ndim: int = 2 + + def __post_init__(self): + pass + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + # label and label_observed_mask + target = data_entry[self.target_field] + data_entry['label'] = target + data_entry['label_observed_mask'] = data_entry['observed_mask'] + prediction_mask = self._generate_prediction_mask(target) + data_entry[self.prediction_mask_field] = prediction_mask + + return data_entry + + def _wrap( + self, + data_entry: dict[str, Any], + field: str, + func, + ) -> np.ndarray | list[np.ndarray] | dict[str, np.ndarray]: + arr: np.ndarray | list[np.ndarray] | dict[str, np.ndarray] = data_entry[field].copy() + if isinstance(arr, list): + return [func(a) for a in arr] + if isinstance(arr, dict): + for k, v in arr.items(): + arr[k] = func(v) + return arr + return func(arr) + + def _generate_prediction_mask( + self, target: Float[np.ndarray, "var time *feat"] + ) -> Bool[np.ndarray, "var time"]: + self.check_ndim("target", target, self.expected_ndim) + var, time = target.shape[:2] + prediction_mask = np.zeros((var, time), dtype=bool) + prediction_mask[:, -self.mask_length:] = True + return prediction_mask \ No newline at end of file diff --git a/OATS/models/tsfm/val/metrics.py b/OATS/models/tsfm/val/metrics.py new file mode 100644 index 0000000..9734070 --- /dev/null +++ b/OATS/models/tsfm/val/metrics.py @@ -0,0 +1,245 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
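+"""Validation metrics built on gluonts error statistics, aggregated over masked, packed sequences."""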
+import torch +import numpy as np +from einops import rearrange, reduce +from functools import partial +from typing import ( + Collection, + Optional, + Callable, + Mapping, + Dict, + List, + Iterator, +) +from gluonts.ev.stats import ( + error, + absolute_error, + absolute_label, + absolute_percentage_error, + absolute_scaled_error, + coverage, + quantile_loss, + scaled_interval_score, + scaled_quantile_loss, + squared_error, + symmetric_absolute_percentage_error, + num_masked_target_values, +) +from gluonts.evaluation.metrics import calculate_seasonal_error + + +def mse(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: mean + return squared_error(data, forecast_type) + + +def mae(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: mean + return absolute_error(data, forecast_type) + + +def mase(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: mean + return absolute_scaled_error(data, forecast_type) + + +def mape(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: mean + return safe_div(absolute_error(data, forecast_type), absolute_label(data)) + + +def smape(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: mean + # return symmetric_absolute_percentage_error(data, forecast_type) + return safe_div( + 2 * absolute_error(data, forecast_type), + (absolute_label(data) + np.abs(data[forecast_type])), + ) + + +def msis(data: Dict[str, torch.Tensor]) -> torch.Tensor: + # aggregate: mean + # data['seasonal_error'] + return scaled_interval_score(data, alpha=0.05) + + +def rmse(data: Dict[str, torch.Tensor], forecast_type: str = "mean") -> torch.Tensor: + # aggregate: mean + return np.sqrt(squared_error(data, forecast_type)) + + +def nrmse(data: Dict[str, torch.Tensor], forecast_type: str = "mean") -> torch.Tensor: + # aggregate: mean + return safe_div(np.sqrt(squared_error(data, forecast_type)), absolute_label(data)) + + +def nd(data: Dict[str, torch.Tensor], forecast_type: str = "0.5") -> torch.Tensor: + # aggregate: sum + return safe_div(absolute_error(data, forecast_type), absolute_label(data)) + + +def mean_weighted_seq_quantile_loss(data: Dict[str, torch.Tensor]) -> torch.Tensor: + # MeanWeightedSumQuantileLoss: aggregate: sum + stacked_quantile_losses = [] + for q in np.arange(0.1, 1, 0.1): + stacked_quantile_losses.append( + safe_div(quantile_loss(data, q), absolute_label(data)) + ) + stacked_quantile_losses = np.stack(stacked_quantile_losses, axis=0) + mean_quantile_loss = stacked_quantile_losses.mean(axis=0) + return mean_quantile_loss + + +def safe_div( + numer: torch.Tensor, + denom: torch.Tensor, +) -> torch.Tensor: + return numer / np.where( + denom == 0, + 1.0, + denom, + ) + + +class ValMetric: + stat: Callable + aggregate: str + + def __init__(self, stat, aggregate) -> None: + self.stat = stat + self.aggregate = aggregate + + def __call__( + self, pred, target, observed_mask, prediction_mask, sample_id, variate_id + ) -> torch.Tensor: + """ + pred: torch.Tensor, shape (batch_size, patch_len, num_features) + target: torch.Tensor, shape (batch_size, patch_len, num_features) + observed_mask: torch.Tensor, shape (batch_size, patch_len, num_features) + prediction_mask: torch.Tensor, shape (batch_size, patch_len) + sample_id: torch.Tensor, shape (batch_size, patch_len) + variate_id: torch.Tensor, shape (batch_size, patch_len) + freq: str, shape batch_size + """ + + data = {} + for q in 
np.arange(0.1, 1, 0.1): + data[str(q)] = torch.quantile(pred, q=q, dim=1).cpu().numpy() + data["mean"] = pred.mean(dim=1).cpu().numpy() + data["label"] = target.cpu().numpy() + data["seasonal_error"] = calculate_seasonal_error(data["label"], seasonality=1) + + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + mask = prediction_mask.unsqueeze(-1) * observed_mask + tobs = reduce( + id_mask + * reduce( + mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + nobs = reduce( + id_mask * rearrange(prediction_mask, "... seq -> ... 1 seq"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) * prediction_mask.unsqueeze(-1) + nobs = torch.where(nobs == 0, nobs, 1 / nobs).sum() + + tobs = tobs.cpu().numpy() + nobs = nobs.cpu().numpy() + mask = mask.cpu().numpy() + + metric = self.stat(data) + + if self.aggregate == "mean": + metric = safe_div(metric, tobs * nobs) + elif self.aggregate == "sum": + pass + metric = (metric * mask).sum() + + return metric + + +class MSE_mean(ValMetric): + def __init__(self) -> None: + stat = partial(mse, forecast_type="mean") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class MAE_mean(ValMetric): + def __init__(self) -> None: + stat = partial(mae, forecast_type="mean") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class MSE_median(ValMetric): + def __init__(self) -> None: + stat = partial(mse, forecast_type="0.5") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class MAE_median(ValMetric): + def __init__(self) -> None: + stat = partial(mae, forecast_type="0.5") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class MASE(ValMetric): + def __init__(self) -> None: + stat = mase + aggregate = "mean" + super().__init__(stat, aggregate) + + +class MAPE(ValMetric): + def __init__(self) -> None: + stat = mape + aggregate = "mean" + super().__init__(stat, aggregate) + + +class SMAPE(ValMetric): + def __init__(self) -> None: + stat = smape + aggregate = "mean" + super().__init__(stat, aggregate) + + +class RMSE(ValMetric): + def __init__(self) -> None: + stat = partial(rmse, forecast_type="mean") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class NRMSE(ValMetric): + def __init__(self) -> None: + stat = partial(nrmse, forecast_type="mean") + aggregate = "mean" + super().__init__(stat, aggregate) + + +class ND(ValMetric): + def __init__(self) -> None: + stat = partial(nd, forecast_type="0.5") + aggregate = "sum" + super().__init__(stat, aggregate) + + +class CRPS(ValMetric): + def __init__(self) -> None: + stat = mean_weighted_seq_quantile_loss + aggregate = "mean" + super().__init__(stat, aggregate)
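
A minimal usage sketch for the transformation pipeline above (illustrative only, not part of the diff): it assumes `OATS/models` is on `PYTHONPATH` so the `tsfm` package resolves, and uses a toy single-variate entry. Note the ordering: `AddObservedMask` must run before `ImputeTimeSeries`, since imputation overwrites the NaN positions the mask records.

```python
import numpy as np

from tsfm.transform import (
    AddObservedMask,
    Chain,
    DummyValueImputation,
    ImputeTimeSeries,
)

# Record NaN positions first, then fill them with a dummy value.
pipeline = Chain(
    [
        AddObservedMask(fields=("target",)),
        ImputeTimeSeries(
            fields=("target",),
            imputation_method=DummyValueImputation(value=0.0),
        ),
    ]
)

entry = {"target": np.array([[1.0, np.nan, 3.0]])}
entry = pipeline(entry)
# entry["observed_mask"] -> [array([[ True, False,  True]])]
# entry["target"]        -> array([[1., 0., 3.]])
```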