TimeDP/README.md: 8 changes (4 additions, 4 deletions)
@@ -68,22 +68,22 @@ Use `main_train.py` for model training and `visualize.py` for domain prompt visualization
## Usage
Training TimeDP:
```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
```

Training without PAM:
```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0
```

Training without domain prompts (unconditional generation model):
```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
```

Visualization of domain prompts:
```bash
-python visualize.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
+python visualize.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
```


TimeDP/configs/multi_domain_timedp.yaml: 1 change (1 addition, 0 deletions)
@@ -84,6 +84,7 @@ data:
normalize: centered_pit
drop_last: True
reweight: True
+input_channels: 1

lightning:
callbacks:
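The new `input_channels` key sits under `data.params`, so it is forwarded to the data module constructor when the config is instantiated. Below is a minimal, illustrative sketch of that plumbing; the helper is an assumption modeled on the common target/params convention, not the repository's exact code, and the dataset path is a placeholder:

```python
# Illustrative sketch: how a key added under data.params (here input_channels)
# ends up as a constructor keyword argument. The helper below is an assumption
# modeled on the usual target/params convention, not the repo's exact code.
import importlib

def instantiate_from_config_sketch(config: dict):
    """Resolve config['target'] to a class and construct it with config['params']."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

data_config = {
    "target": "ldm.data.tsg_dataset.TSGDataModule",
    "params": {
        "data_path_dict": {"electricity": "./data/electricity_168.npy"},  # placeholder path
        "window": 168,
        "batch_size": 128,
        "input_channels": 1,  # the key added by this change
    },
}
# datamodule = instantiate_from_config_sketch(data_config)
```

Because every entry under `params` becomes a keyword argument, `TSGDataModule` must either accept `input_channels` explicitly (as the change below does) or absorb it via `**kwargs`.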
TimeDP/ldm/data/tsg_dataset.py: 3 changes (2 additions, 1 deletion)
@@ -66,7 +66,7 @@ class TSGDataModule(pl.LightningDataModule):
Data module for unified time series generation task.
Slicing is also done with this module. So the train/val is i.i.d within train dataset.
'''
-def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False, **kwargs):
+def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False, input_channels=1, **kwargs):
super().__init__()
self.data_path_dict = data_path_dict # {data_name: data_path}
self.data_dict = {}
@@ -88,6 +88,7 @@ def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=Tr
self.val_dataset = None
self.test_dataset = None
self.reweight = reweight
+self.input_channels = input_channels
# self.transform = None
self.kwargs = kwargs
self.key_list = []
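For context, a standalone call to the extended constructor might look like the sketch below; the dataset key and `.npy` path are placeholders for illustration, not values shipped with the repository:

```python
# Hypothetical direct instantiation of the extended TSGDataModule constructor.
# The dataset key and file path are placeholders, not repo-provided values.
from ldm.data.tsg_dataset import TSGDataModule

datamodule = TSGDataModule(
    data_path_dict={"electricity": "./data/electricity_168.npy"},
    window=168,
    batch_size=128,
    normalize="centered_pit",
    reweight=True,
    input_channels=1,  # new keyword introduced by this change; defaults to 1
)
```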
TimeDP/utils/init_utils.py: 2 changes (1 addition, 1 deletion)
@@ -116,7 +116,7 @@ def init_model_data_trainer(parser):

# data
for k, v in config.data.params.data_path_dict.items():
-config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', opt.seq_len)
+config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', str(opt.seq_len))
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
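The added `str()` cast matters because `str.replace` only accepts string arguments, while the parsed `-sl`/`--seq_len` value (168 in the README commands) is presumably an integer after argument parsing. A minimal illustration follows; the template string is made up for the example and is not the repository's actual path pattern:

```python
# Why str(opt.seq_len) is needed: str.replace rejects a non-string replacement.
# The template below is a made-up example, not the repository's path pattern.
template = "{DATA_ROOT}/train_{SEQ_LEN}.npy"
seq_len = 168  # e.g. the parsed -sl / --seq_len value

# template.replace("{SEQ_LEN}", seq_len)
# -> TypeError: replace() argument 2 must be str, not int

path = template.replace("{SEQ_LEN}", str(seq_len))
print(path)  # {DATA_ROOT}/train_168.npy
```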