diff --git a/TimeDP/README.md b/TimeDP/README.md
index be45734..c0452d9 100644
--- a/TimeDP/README.md
+++ b/TimeDP/README.md
@@ -68,22 +68,22 @@ Use `main_train.py` for model training and `visualize.py` for domain prompt visu
 ## Usage
 
 Training TimeDP:
 ```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
 ```
 
 Training without PAM:
 ```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0
 ```
 
 Training without domain prompts (unconditional generation model):
 ```bash
-python main_train.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
+python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
 ```
 
 Visualization of domain prompts:
 ```bash
-python visualize.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
+python visualize.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/ -sl 168 --batch_size 128 -lr 0.0001 -s 0 --uncond
 ```
 
diff --git a/TimeDP/configs/multi_domain_timedp.yaml b/TimeDP/configs/multi_domain_timedp.yaml
index 30123e8..b053cd9 100644
--- a/TimeDP/configs/multi_domain_timedp.yaml
+++ b/TimeDP/configs/multi_domain_timedp.yaml
@@ -84,6 +84,7 @@ data:
     normalize: centered_pit
     drop_last: True
     reweight: True
+    input_channels: 1
 
 lightning:
   callbacks:
diff --git a/TimeDP/ldm/data/tsg_dataset.py b/TimeDP/ldm/data/tsg_dataset.py
index 6931a74..50767f9 100644
--- a/TimeDP/ldm/data/tsg_dataset.py
+++ b/TimeDP/ldm/data/tsg_dataset.py
@@ -66,7 +66,7 @@ class TSGDataModule(pl.LightningDataModule):
     Data module for unified time series generation task.
     Slicing is also done with this module. So the train/val is i.i.d within train dataset.
     '''
-    def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False, **kwargs):
+    def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False, input_channels=1, **kwargs):
         super().__init__()
         self.data_path_dict = data_path_dict # {data_name: data_path}
         self.data_dict = {}
@@ -88,6 +88,7 @@ def __init__(self, data_path_dict, window=96, val_portion=0.1, as_tensor:bool=Tr
         self.val_dataset = None
         self.test_dataset = None
         self.reweight = reweight
+        self.input_channels = input_channels
         # self.transform = None
         self.kwargs = kwargs
         self.key_list = []
diff --git a/TimeDP/utils/init_utils.py b/TimeDP/utils/init_utils.py
index 132fcc7..d5af3c1 100644
--- a/TimeDP/utils/init_utils.py
+++ b/TimeDP/utils/init_utils.py
@@ -116,7 +116,7 @@ def init_model_data_trainer(parser):
 
     # data
     for k, v in config.data.params.data_path_dict.items():
-        config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', opt.seq_len)
+        config.data.params.data_path_dict[k] = v.replace('{DATA_ROOT}', data_root).replace('{SEQ_LEN}', str(opt.seq_len))
     data = instantiate_from_config(config.data)
     # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     # calling these ourselves should not be necessary but it is.
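
Note on the `init_utils.py` hunk: `str.replace` requires both arguments to be strings, so passing `opt.seq_len` (an integer when `-sl 168` is parsed as an int) raises a `TypeError` before the data module is ever built. The `input_channels` key added to the YAML reaches `TSGDataModule.__init__` because `instantiate_from_config` forwards the `params` mapping as keyword arguments, as in other LDM-derived codebases. A minimal sketch of the failure and the fix; the `template` value below is illustrative, not taken from the repo's configs:

```python
# Minimal repro of the bug fixed above: str.replace() rejects a non-string
# replacement value. The template path here is illustrative only.
template = '{DATA_ROOT}/energy_{SEQ_LEN}.npy'
seq_len = 168  # matches `-sl 168` in the README commands

# Before the patch:
#   template.replace('{SEQ_LEN}', seq_len)
#   -> TypeError: replace() argument 2 must be str, not int

# After the patch: cast to str first, as init_utils.py now does.
print(template.replace('{SEQ_LEN}', str(seq_len)))
# -> {DATA_ROOT}/energy_168.npy
```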