diff --git a/TarDiff/README.md b/TarDiff/README.md new file mode 100755 index 0000000..2a81240 --- /dev/null +++ b/TarDiff/README.md @@ -0,0 +1,100 @@ +# TarDiff: Target-Oriented Diffusion Guidance for Synthetic Electronic Health Record Time Series Generation + + + +## Introduction +![TarDiff OverView](./images/overview.png) + +Synthetic Electronic Health Record (EHR) time-series generation is crucial for advancing clinical machine learning models, as it helps address data scarcity by providing more training data. However, most existing approaches focus primarily on replicating statistical distributions and temporal dependencies of real-world data. We argue that fidelity to observed data alone does not guarantee better model performance, as common patterns may dominate, limiting the representation of rare but important conditions. This highlights the need for generate synthetic samples to improve performance of specific clinical models to fulfill their target outcomes. To address this, we propose TarDiff, a novel target-oriented diffusion framework that integrates task-specific influence guidance into the synthetic data generation process. Unlike conventional approaches that mimic training data distributions, TarDiff optimizes synthetic samples by quantifying their expected contribution to improving downstream model performance through influence functions. Specifically, we measure the reduction in task-specific loss induced by synthetic samples and embed this influence gradient into the reverse diffusion process, thereby steering the generation towards utility-optimized data. Evaluated on six publicly available EHR datasets, TarDiff achieves state-of-the-art performance, outperforming existing methods by up to 20.4% in AUPRC and 18.4% in AUROC. Our results demonstrate that TarDiff not only preserves temporal fidelity but also enhances downstream model performance, offering a robust solution to data scarcity and class imbalance in healthcare analytics. + + +## 1 · Environment + +Prepare TarDiff's environment. +``` +conda env create -f environment.yaml +conda activate tardiff +``` + +Prepare TS downstream task environment depands on the repo you used for the specific task. + +## 2 · Data Pre-processing + +You can access the raw datasets at the following links: + +- [eICU Collaborative Research Database](https://eicu-crd.mit.edu/) +- [MIMIC-III Clinical Database](https://physionet.org/content/mimiciii/1.4/) + +> **Note:** Both datasets require prior approval and credentialing before download. + +We focus exclusively on the multivariate time-series recordings available in these datasets. +To assist with preprocessing, we provide high-level extraction scripts under **`data_preprocess/`**. + + + +## 3 · Stage 1 · Train the *Base* Diffusion Model + +```bash +bash train.sh # trains TarDiff on MIMIC-III ICU-stay data +``` + +> **Edit tip:** open the example YAML in `configs/base/` and replace any placeholder data paths with your own before running. + +This step produces an unconditional diffusion model checkpoint—no guidance yet. + +--- + +## 4 · Stage 2 · Train a Downstream Task Model (Guidance Source) + +An example RNN classifier is supplied in **`classifier/`**. + +```bash +cd classifier +bash train.sh # saves weights to classifier/checkpoint/ +cd .. +``` + +Feel free to swap in any architecture that suits your task. + +--- + +## 5 · Stage 3 · Target-Guided Generation + +With **both** checkpoints ready—the diffusion backbone and the task model—start guided sampling: + +```bash +bash generation.sh # remember to update paths to both weights +``` + +The script creates a synthetic dataset tailored to the guidance task. + +--- + +## 6 · Stage 4 · Utility Evaluation — *TSTR* and *TSRTR* + +After generation, you can assess the utility of the synthetic data for the **target task** using two complementary protocols: + +| Protocol | Training Set | Test Set | Question Answered | +|:---------|:-------------|:---------|:------------------| +| **TSTR** (Train-Synthetic, Test-Real) | **Synthetic only** | **Real** | “If I train a model purely on synthetic EHRs, how well does it generalize to real patients?” | +| **TSRTR** (Train-Synthetic-Real, Test-Real) | **Synthetic + Real** (α ∈ {0.2, …, 1.0}) | **Real** | “If I augment the real training set with α× synthetic samples, does it improve model performance?” | + +--- + +### How to run TSTR and TSRTR evaluations + +You can directly reuse the training script under `classifier/` to run both evaluations: + +```bash +cd classifier +bash train.sh # Edit the training data path to point to either synthetic-only or mixed (real + synthetic) data +``` + +- For **TSTR**, set the training set path to the synthetic dataset. +- For **TSRTR**, combine the real training data with synthetic samples according to your desired α ratio, and update the path accordingly. + +The downstream model will be trained and evaluated automatically on the real validation and test sets. + +--- + +Enjoy exploring target-oriented diffusion for healthcare ML! For issues or pull requests, please open a GitHub ticket. diff --git a/TarDiff/classifier/__init__.py b/TarDiff/classifier/__init__.py new file mode 100644 index 0000000..0eca642 --- /dev/null +++ b/TarDiff/classifier/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/TarDiff/classifier/classifier_train.py b/TarDiff/classifier/classifier_train.py new file mode 100644 index 0000000..1663d66 --- /dev/null +++ b/TarDiff/classifier/classifier_train.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +from __future__ import annotations + +import argparse +from pathlib import Path +from model import RNNClassifier +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader, random_split +from tqdm import tqdm +import os +import pandas as pd +from typing import Optional, Tuple + + +class TimeSeriesDataset(Dataset): + + def __init__( + self, + data, + labels, + normalize=True, + stats=None, + eps=1e-8, + ): + + if isinstance(data, np.ndarray): + data = torch.from_numpy(data) + if isinstance(labels, np.ndarray): + labels = torch.from_numpy(labels) + assert data.ndim == 3, "data must be (N, seq_len, n_features)" + assert len(data) == len(labels) + self.data = data.float() + self.labels = labels.long() + self.normalize = normalize + self.eps = eps + + if self.normalize: + if stats is None: + # compute mean/std over all time‑steps *per feature* + mean = self.data.mean(dim=(0, 1), keepdim=True) # (1,1,F) + std = self.data.std(dim=(0, 1), keepdim=True) + else: + mean, std = stats + if isinstance(mean, np.ndarray): + mean = torch.from_numpy(mean) + if isinstance(std, np.ndarray): + std = torch.from_numpy(std) + mean, std = mean.float(), std.float() + self.register_buffer("_mean", + mean) # cached on device when .to(...) + self.register_buffer("_std", std.clamp_min(self.eps)) + + # tiny helper so buffers exist even on CPU tensors + def register_buffer(self, name: str, tensor: torch.Tensor): + object.__setattr__(self, name, tensor) + + # expose stats to reuse on other splits + @property + def stats(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + if not self.normalize: + return None + return self._mean, self._std + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + x = self.data[idx] + if self.normalize: + x = (x - self._mean) / self._std + x = x.squeeze(0) + return x, self.labels[idx] + + +# ----------------------------------------------------------------------------- +# Train / Eval helpers +# ----------------------------------------------------------------------------- + + +def _accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float: + if logits.dim() == 1: # binary with BCEWithLogits + preds = (torch.sigmoid(logits) > 0.5).long() + else: # multi‑class with CE + preds = logits.argmax(dim=1) + return (preds == labels).float().mean().item() + + +def _run_epoch(model, loader, criterion, optimizer, device, train=True): + if train: + model.train() + else: + model.eval() + total_loss = 0.0 + total_acc = 0.0 + for x, y in loader: + x, y = x.to(device), y.to(device) + if train: + optimizer.zero_grad() + logits = model(x) + loss = criterion(logits, y.float() if logits.dim() == 1 else y) + if train: + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + total_acc += _accuracy(logits, y) * x.size(0) + n = len(loader.dataset) + return total_loss / n, total_acc / n + + +# ----------------------------------------------------------------------------- +# Main – quick demo on synthetic data +# ----------------------------------------------------------------------------- + + +def main(args): + rng = np.random.default_rng(args.seed) + if os.path.exists(args.train_data) and os.path.exists(args.val_data): + X_train, y_train = pd.read_pickle(args.train_data) + X_train = X_train.transpose(0, 2, 1) + X_val, y_val = pd.read_pickle(args.val_data) + X_val = X_val.transpose(0, 2, 1) + else: + X_train = rng.standard_normal(size=(20000, 24, 7), dtype=np.float32) + y_train = rng.integers(0, + args.num_classes, + size=(20000, ), + dtype=np.int64) + X_val = rng.standard_normal(size=(5000, 24, 7), dtype=np.float32) + y_val = rng.integers(0, + args.num_classes, + size=(5000, ), + dtype=np.int64) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + train_set = TimeSeriesDataset(X_train, y_train) + val_set = TimeSeriesDataset(X_val, y_val) + + train_loader = DataLoader(train_set, + batch_size=args.batch_size, + shuffle=True) + val_loader = DataLoader(val_set, batch_size=args.batch_size) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = RNNClassifier( + input_dim=7, + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + rnn_type=args.rnn_type, + num_classes=args.num_classes, + dropout=args.dropout, + ).to(device) + + criterion = (nn.BCEWithLogitsLoss() + if args.num_classes == 1 else nn.CrossEntropyLoss()) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + best_val = 0.0 + Path(args.ckpt_dir).mkdir(parents=True, exist_ok=True) + + for epoch in tqdm(range(1, args.epochs + 1)): + tr_loss, tr_acc = _run_epoch(model, + train_loader, + criterion, + optimizer, + device, + train=True) + va_loss, va_acc = _run_epoch(model, + val_loader, + criterion, + optimizer, + device, + train=False) + print( + f"Epoch {epoch:02d} | train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f}" + ) + if va_acc > best_val: + best_val = va_acc + print(f"New best val acc: {best_val:.4f} -> saving model") + torch.save({"model_state": model.state_dict()}, + Path(args.ckpt_dir) / "best_model.pt") + print(f"Train Finished. Best val acc: {best_val:.4f}") + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + p = argparse.ArgumentParser( + description="Bidirectional LSTM/GRU time‑series classifier") + p.add_argument("--hidden_dim", type=int, default=128) + p.add_argument("--num_layers", type=int, default=2) + p.add_argument("--rnn_type", choices=["lstm", "gru"], default="lstm") + p.add_argument("--num_classes", + type=int, + default=1, + help="1 for binary, >1 for multi‑class") + p.add_argument("--batch_size", type=int, default=256) + p.add_argument("--epochs", type=int, default=40) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--train_data", type=str, default="data/train_data.npy") + p.add_argument("--val_data", type=str, default="data/val_data.npy") + p.add_argument("--ckpt_dir", type=str, default="checkpoints") + p.add_argument("--seed", type=int, default=42) + + args = p.parse_args() + main(args) diff --git a/TarDiff/classifier/model.py b/TarDiff/classifier/model.py new file mode 100644 index 0000000..851c0be --- /dev/null +++ b/TarDiff/classifier/model.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +from __future__ import annotations +import torch +from torch import nn + + +class RNNClassifier(nn.Module): + """Bidirectional LSTM/GRU classifier for fixed‑length sequences.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int = 128, + num_layers: int = 2, + rnn_type: str = "lstm", + num_classes: int = 2, + dropout: float = 0.2, + ) -> None: + super().__init__() + rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type.lower()] + self.rnn = rnn_cls( + input_size=input_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + bidirectional=True, + dropout=dropout if num_layers > 1 else 0.0, + ) + self.fc = nn.Linear(hidden_dim * 2, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, F) + rnn_out, _ = self.rnn(x) # (B, T, 2*H) + last_hidden = rnn_out[:, -1, :] # final time‑step representation + logits = self.fc(last_hidden) # (B, C) or (B, 1) + return logits.squeeze(-1) # binary → (B,) ; multi‑class stays (B, C) diff --git a/TarDiff/classifier/train.sh b/TarDiff/classifier/train.sh new file mode 100644 index 0000000..8c7c6e3 --- /dev/null +++ b/TarDiff/classifier/train.sh @@ -0,0 +1,6 @@ +# prepare guidance model train_tuple.pkl : (data, label) + +python classifier_train.py --num_classes 1 --rnn_type gru --hidden_dim 256 --train_data data/mimic_icustay/train_tuple.pkl --val_data data/mimic_icustay/val_tuple.pkl + + + diff --git a/TarDiff/configs/base/mimic_icustay_base.yaml b/TarDiff/configs/base/mimic_icustay_base.yaml new file mode 100755 index 0000000..95b8981 --- /dev/null +++ b/TarDiff/configs/base/mimic_icustay_base.yaml @@ -0,0 +1,98 @@ +seq_length: &seqlen 24 +model: + base_learning_rate: 5.e-5 # set to target_lr by starting main.py with '--scale_lr False' + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0005 + linear_end: 0.1 + num_timesteps_cond: 1 + log_every_t: 40 + timesteps: 200 + loss_type: l1 + first_stage_key: "context" + cond_stage_key: "context" + image_size: *seqlen + channels: 7 + cond_stage_trainable: True + concat_mode: False + scale_by_std: False # True + monitor: 'val/loss_simple_ema' + conditioning_key: crossattn + cond_part_drop: False + dis_loss_flag: False + pair_loss_flag: False + pair_loss_type: l2 + pair_loss_weight: 1.0 + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [1000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [ 1.] + + unet_config: + target: ldm.modules.diffusionmodules.unet1d.UNetModel + params: + image_size: *seqlen + dims: 1 + in_channels: 7 + out_channels: 7 + model_channels: 64 + attention_resolutions: [ 1, 2, 4] # 8, 4, 2 + num_res_blocks: 2 + channel_mult: [ 1,2,4,4 ] # 8,4,2,1 + num_heads: 8 + use_scale_shift_norm: True + resblock_updown: True + context_dim: 32 + repre_emb_channels: 32 + latent_unit: 1 + use_cfg: True + use_spatial_transformer: True + num_classes: 2 + + first_stage_config: # no first stage model for ts data + target: ldm.models.autoencoder.IdentityFirstStage # VQModelInterface + + cond_stage_config: + target: ldm.modules.encoders.modules.DomainUnifiedEncoder # SplitTSEqEncoder # SplitTSEqEncoder, SingleTSEncoder + params: + dim: 32 + window: *seqlen + latent_dim: 32 # 32 * 3 + num_channels: 7 + use_prototype: False + # use_cfg: True + +data: + target: ldm.data.ts_data_loader.TSClassCondTrainDataModule + params: + data_path_dict: + MIMIC_III_Readmission: data/mimic_icustay/train_tuple.pkl + window: *seqlen + val_portion: 0.1 + batch_size: 256 + num_workers: 8 + normalize: centered_pit + drop_last: True + reweight: False + input_dim: +lightning: + callbacks: + image_logger: + target: utils.callback_utils.TSLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: false + log_images_kwargs: + inpaint: false + plot_swapped_concepts: false + + + trainer: + benchmark: True + max_steps: 20 + grad_watch: False \ No newline at end of file diff --git a/TarDiff/data_preprocess/README.md b/TarDiff/data_preprocess/README.md new file mode 100644 index 0000000..097f0c1 --- /dev/null +++ b/TarDiff/data_preprocess/README.md @@ -0,0 +1 @@ +We preprocess MIMIC-III by first querying the raw **vitals** and **admissions** tables, then isolating each ICU stay (`icustay_id`) as an independent sample. For every stay we extract seven routinely recorded signals—heart-rate, systolic/diastolic blood pressure, mean arterial pressure, respiratory rate, temperature, oxygen saturation (SpO₂), and urine output—resample them to an equal 1-hour grid, and truncate or zero-pad so every sample is a fixed **24 × 7** time-series matrix covering the first 24 hours in the unit. We attach a binary in-hospital mortality label from the admissions record, stack all samples into a single array, randomly shuffle, and split 80 % / 20 % into training and test sets while reporting the class balance. This yields a clean, length-aligned dataset ready for downstream modeling without exposing any protected health information. \ No newline at end of file diff --git a/TarDiff/environment.yaml b/TarDiff/environment.yaml new file mode 100755 index 0000000..59f9d25 --- /dev/null +++ b/TarDiff/environment.yaml @@ -0,0 +1,502 @@ +name: tardiff +channels: + - pytorch + - nvidia + - defaults + - conda-forge + - https://repo.anaconda.com/pkgs/main + - https://repo.anaconda.com/pkgs/r +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - alsa-lib=1.2.8=h166bdaf_0 + - anyio=4.2.0=pyhd8ed1ab_0 + - aom=3.5.0=h27087fc_0 + - appdirs=1.4.4=pyh9f0ad1d_0 + - argon2-cffi=23.1.0=pyhd8ed1ab_0 + - argon2-cffi-bindings=21.2.0=py38h01eb140_4 + - arrow=1.3.0=pyhd8ed1ab_0 + - asttokens=2.4.1=pyhd8ed1ab_0 + - async-lru=2.0.4=pyhd8ed1ab_0 + - attr=2.5.1=h166bdaf_1 + - attrs=23.2.0=pyh71513ae_0 + - babel=2.14.0=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - beautifulsoup4=4.12.2=pyha770c72_0 + - blas=1.0=mkl + - bleach=6.1.0=pyhd8ed1ab_0 + - blosc=1.21.4=h0f2a231_0 + - boost-cpp=1.78.0=h5adbc97_2 + - bottleneck=1.3.7=py38ha9d4c09_0 + - brotli=1.1.0=hd590300_1 + - brotli-bin=1.1.0=hd590300_1 + - brotli-python=1.1.0=py38h17151c0_1 + - bzip2=1.0.8=hd590300_5 + - c-ares=1.25.0=hd590300_0 + - ca-certificates=2023.12.12=h06a4308_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cairo=1.16.0=ha61ee94_1014 + - certifi=2023.11.17=pyhd8ed1ab_0 + - cffi=1.16.0=py38h6d47a40_0 + - cfitsio=4.2.0=hd9d235c_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - click=8.1.7=unix_pyh707e725_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - comm=0.2.1=pyhd8ed1ab_0 + - cuda=11.6.1=0 + - cuda-cccl=11.6.55=hf6102b2_0 + - cuda-command-line-tools=11.6.2=0 + - cuda-compiler=11.6.2=0 + - cuda-cudart=11.6.55=he381448_0 + - cuda-cudart-dev=11.6.55=h42ad0f4_0 + - cuda-cuobjdump=11.6.124=h2eeebcb_0 + - cuda-cupti=11.6.124=h86345e5_0 + - cuda-cuxxfilt=11.6.124=hecbf4f6_0 + - cuda-driver-dev=11.6.55=0 + - cuda-gdb=12.3.101=0 + - cuda-libraries=11.6.1=0 + - cuda-libraries-dev=11.6.1=0 + - cuda-memcheck=11.8.86=0 + - cuda-nsight=12.3.101=0 + - cuda-nsight-compute=12.3.2=0 + - cuda-nvcc=11.6.124=hbba6d2d_0 + - cuda-nvdisasm=12.3.101=0 + - cuda-nvml-dev=11.6.55=haa9ef22_0 + - cuda-nvprof=12.3.101=0 + - cuda-nvprune=11.6.124=he22ec0a_0 + - cuda-nvrtc=11.6.124=h020bade_0 + - cuda-nvrtc-dev=11.6.124=h249d397_0 + - cuda-nvtx=11.6.124=h0630a44_0 + - cuda-nvvp=12.3.101=0 + - cuda-runtime=11.6.1=0 + - cuda-samples=11.6.101=h8efea70_0 + - cuda-sanitizer-api=12.3.101=0 + - cuda-toolkit=11.6.1=0 + - cuda-tools=11.6.1=0 + - cuda-visual-tools=11.6.1=0 + - cudatoolkit=11.3.1=hb98b00a_12 + - curl=8.1.2=h409715c_0 + - cycler=0.12.1=pyhd8ed1ab_0 + - dbus=1.13.6=h5008d03_3 + - debugpy=1.8.0=py38h17151c0_1 + - decorator=5.1.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - docker-pycreds=0.4.0=py_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - exceptiongroup=1.2.0=pyhd8ed1ab_2 + - executing=2.0.1=pyhd8ed1ab_0 + - expat=2.5.0=hcb278e6_1 + - ffmpeg=5.1.2=gpl_h8dda1f0_106 + - fftw=3.3.10=nompi_hf0379b8_106 + - filelock=3.13.1=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_1 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.47.2=py38h01eb140_0 + - fqdn=1.5.1=pyhd8ed1ab_0 + - freeglut=3.2.2=h9c3ff4c_1 + - freetype=2.12.1=h267a509_2 + - freexl=1.0.6=h166bdaf_1 + - fsspec=2023.12.2=pyhca7485f_0 + - gds-tools=1.8.1.2=0 + - geos=3.11.1=h27087fc_0 + - geotiff=1.7.1=h7a142b4_6 + - gettext=0.21.1=h27087fc_0 + - giflib=5.2.1=h0b41bf4_3 + - gitdb=4.0.11=pyhd8ed1ab_0 + - gitpython=3.1.41=pyhd8ed1ab_0 + - glib=2.78.1=hfc55251_0 + - glib-tools=2.78.1=hfc55251_0 + - gmp=6.3.0=h59595ed_0 + - gnutls=3.7.9=hb077bed_0 + - grad-cam=1.5.0=pyhd8ed1ab_0 + - graphite2=1.3.13=h58526e2_1001 + - gst-plugins-base=1.21.3=h4243ec0_1 + - gstreamer=1.21.3=h25f0c4b_1 + - gstreamer-orc=0.4.34=hd590300_0 + - h5py=3.8.0=nompi_py38hd5fa8ee_100 + - harfbuzz=6.0.0=h8e241bc_0 + - hdf4=4.2.15=h9772cbc_5 + - hdf5=1.12.2=nompi_h4df4325_101 + - huggingface_hub=0.20.2=pyhd8ed1ab_0 + - icu=70.1=h27087fc_0 + - idna=3.6=pyhd8ed1ab_0 + - importlib_metadata=7.0.1=hd8ed1ab_0 + - importlib_resources=6.1.1=pyhd8ed1ab_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipykernel=6.28.0=pyhd33586a_0 + - ipython=8.12.2=pyh41d4057_0 + - ipywidgets=8.1.1=pyhd8ed1ab_0 + - isoduration=20.11.0=pyhd8ed1ab_0 + - jack=1.9.22=h11f4161_0 + - jasper=2.0.33=h0ff4b12_1 + - jbig=2.1=h7f98852_2003 + - jedi=0.19.1=pyhd8ed1ab_0 + - jinja2=3.1.3=pyhd8ed1ab_0 + - jpeg=9e=h0b41bf4_3 + - json-c=0.16=hc379101_0 + - json5=0.9.14=pyhd8ed1ab_0 + - jsonpointer=2.4=py38h578d9bd_3 + - jsonschema=4.20.0=pyhd8ed1ab_0 + - jsonschema-specifications=2023.12.1=pyhd8ed1ab_0 + - jsonschema-with-format-nongpl=4.20.0=pyhd8ed1ab_0 + - jupyter=1.0.0=pyhd8ed1ab_10 + - jupyter-lsp=2.2.1=pyhd8ed1ab_0 + - jupyter_client=8.6.0=pyhd8ed1ab_0 + - jupyter_console=6.6.3=pyhd8ed1ab_0 + - jupyter_core=5.7.1=py38h578d9bd_0 + - jupyter_events=0.9.0=pyhd8ed1ab_0 + - jupyter_server=2.12.4=pyhd8ed1ab_0 + - jupyter_server_terminals=0.5.1=pyhd8ed1ab_0 + - jupyterlab=4.0.10=pyhd8ed1ab_0 + - jupyterlab_pygments=0.3.0=pyhd8ed1ab_0 + - jupyterlab_server=2.25.2=pyhd8ed1ab_0 + - jupyterlab_widgets=3.0.9=pyhd8ed1ab_0 + - kealib=1.5.0=ha7026e8_0 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.5=py38h7f3f72f_1 + - kornia=0.7.0=pyhd8ed1ab_0 + - krb5=1.20.1=h81ceb04_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.15=hfd0df8a_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=4.0.0=h27087fc_0 + - libabseil=20230802.1=cxx17_h59595ed_0 + - libaec=1.1.2=h59595ed_1 + - libblas=3.9.0=12_linux64_mkl + - libbrotlicommon=1.1.0=hd590300_1 + - libbrotlidec=1.1.0=hd590300_1 + - libbrotlienc=1.1.0=hd590300_1 + - libcap=2.67=he9d0100_0 + - libcblas=3.9.0=12_linux64_mkl + - libclang=15.0.7=default_hb11cfb5_4 + - libclang13=15.0.7=default_ha2b6cf4_4 + - libcublas=11.9.2.110=h5e84587_0 + - libcublas-dev=11.9.2.110=h5c901ab_0 + - libcufft=10.7.1.112=hf425ae0_0 + - libcufft-dev=10.7.1.112=ha5ce4c0_0 + - libcufile=1.8.1.2=0 + - libcufile-dev=1.8.1.2=0 + - libcups=2.3.3=h36d4200_3 + - libcurand=10.3.4.107=0 + - libcurand-dev=10.3.4.107=0 + - libcurl=8.1.2=h409715c_0 + - libcusolver=11.3.4.124=h33c3c4e_0 + - libcusparse=11.7.2.124=h7538f96_0 + - libcusparse-dev=11.7.2.124=hbbe9722_0 + - libdb=6.2.32=h9c3ff4c_0 + - libdeflate=1.17=h0b41bf4_0 + - libdrm=2.4.114=h166bdaf_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=hd590300_2 + - libevent=2.1.10=h28343ad_4 + - libexpat=2.5.0=hcb278e6_1 + - libffi=3.4.2=h7f98852_5 + - libflac=1.4.3=h59595ed_0 + - libgcc-ng=13.2.0=h807b86a_3 + - libgcrypt=1.10.3=hd590300_0 + - libgdal=3.6.2=h6c674c2_9 + - libgfortran-ng=13.2.0=h69a702a_3 + - libgfortran5=13.2.0=ha4646dd_3 + - libglib=2.78.1=hebfc3b9_0 + - libglu=9.0.0=he1b5a44_1001 + - libgomp=13.2.0=h807b86a_3 + - libgpg-error=1.47=h71f35ed_0 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.4=h166bdaf_0 + - libkml=1.3.0=h01aab08_1016 + - liblapack=3.9.0=12_linux64_mkl + - liblapacke=3.9.0=12_linux64_mkl + - libllvm15=15.0.7=hadd5161_1 + - libnetcdf=4.9.1=nompi_h34a3ff0_101 + - libnghttp2=1.58.0=h47da74e_0 + - libnpp=11.6.3.124=hd2722f0_0 + - libnpp-dev=11.6.3.124=h3c42840_0 + - libnsl=2.0.1=hd590300_0 + - libnvjpeg=11.6.2.124=hd473ad6_0 + - libnvjpeg-dev=11.6.2.124=hb5906b9_0 + - libogg=1.3.4=h7f98852_1 + - libopencv=4.7.0=py38h340f60e_0 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.17=h166bdaf_0 + - libpng=1.6.39=h753d276_0 + - libpq=15.2=hb675445_0 + - libprotobuf=3.21.12=hfc55251_2 + - librttopo=1.1.0=ha49c73b_12 + - libsndfile=1.2.2=hc60ed4a_1 + - libsodium=1.0.18=h36c2ea0_1 + - libspatialite=5.0.1=h221c8f1_23 + - libsqlite=3.44.2=h2797004_0 + - libssh2=1.11.0=h0841786_0 + - libstdcxx-ng=13.2.0=h7e041cc_3 + - libsystemd0=253=h8c4010b_1 + - libtasn1=4.19.0=h166bdaf_0 + - libtiff=4.5.0=h6adf6a1_2 + - libtool=2.4.7=h27087fc_0 + - libudev1=253=h0b41bf4_1 + - libunistring=0.9.10=h7f98852_0 + - libuuid=2.38.1=h0b41bf4_0 + - libuv=1.46.0=hd590300_0 + - libva=2.18.0=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libvpx=1.11.0=h9c3ff4c_3 + - libwebp=1.2.4=h1daa5a0_1 + - libwebp-base=1.2.4=h166bdaf_0 + - libxcb=1.13=h7f98852_1004 + - libxcrypt=4.4.36=hd590300_1 + - libxkbcommon=1.5.0=h79f4944_1 + - libxml2=2.10.3=hca2bb57_4 + - libzip=1.10.1=h2629f0a_3 + - libzlib=1.2.13=hd590300_5 + - lmdb=0.9.29=h2531618_0 + - lz4-c=1.9.3=h9c3ff4c_1 + - markupsafe=2.1.3=py38h01eb140_1 + - matplotlib=3.5.1=py38h578d9bd_0 + - matplotlib-base=3.5.1=py38hf4fb855_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 + - mistune=3.0.2=pyhd8ed1ab_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - mpg123=1.32.4=h59595ed_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.33=hf1915f5_2 + - mysql-libs=8.0.33=hca2cd23_2 + - nbclient=0.8.0=pyhd8ed1ab_0 + - nbconvert=7.14.1=pyhd8ed1ab_0 + - nbconvert-core=7.14.1=pyhd8ed1ab_0 + - nbconvert-pandoc=7.14.1=pyhd8ed1ab_0 + - nbformat=5.9.2=pyhd8ed1ab_0 + - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.5.8=pyhd8ed1ab_0 + - nettle=3.9.1=h7ab15ed_0 + - ninja=1.11.0=h924138e_0 + - notebook=7.0.6=pyhd8ed1ab_0 + - notebook-shim=0.2.3=pyhd8ed1ab_0 + - nsight-compute=2023.3.1.1=0 + - nspr=4.35=h27087fc_0 + - nss=3.96=h1d7d5a4_0 + - numexpr=2.8.4=py38he184ba9_0 + - numpy=1.24.4=py38h59b608b_0 + - opencv=4.7.0=py38h578d9bd_0 + - openh264=2.3.1=hcb278e6_2 + - openjpeg=2.5.0=hfec8fc6_2 + - openssl=3.1.4=hd590300_0 + - overrides=7.4.0=pyhd8ed1ab_0 + - p11-kit=0.24.1=hc5aa10d_0 + - packaging=21.3=pyhd8ed1ab_0 + - pandoc=3.1.3=h32600fe_0 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - parso=0.8.3=pyhd8ed1ab_0 + - pathtools=0.1.2=py_1 + - patsy=0.5.6=pyhd8ed1ab_0 + - pcre2=10.40=hc3806b6_0 + - pexpect=4.8.0=pyh1a96a4e_2 + - pickleshare=0.7.5=py_1003 + - pillow=9.4.0=py38hde6dc18_1 + - pip=20.3.3=py38h06a4308_0 + - pixman=0.43.0=h59595ed_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 + - platformdirs=4.1.0=pyhd8ed1ab_0 + - ply=3.11=py_1 + - pooch=1.7.0=py38h06a4308_0 + - poppler=23.03.0=h091648b_0 + - poppler-data=0.4.12=hd8ed1ab_0 + - postgresql=15.2=h3248436_0 + - proj=9.1.1=h8ffa02c_2 + - prometheus_client=0.19.0=pyhd8ed1ab_0 + - prompt-toolkit=3.0.42=pyha770c72_0 + - prompt_toolkit=3.0.42=hd8ed1ab_0 + - psutil=5.9.7=py38h01eb140_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pulseaudio=16.1=hcb278e6_3 + - pulseaudio-client=16.1=h5195f5e_3 + - pulseaudio-daemon=16.1=ha8d29e2_3 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - py-opencv=4.7.0=py38h6f1a3b6_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pygments=2.17.2=pyhd8ed1ab_0 + - pyparsing=3.1.1=pyhd8ed1ab_0 + - pyqt=5.15.7=py38ha0d8c90_3 + - pyqt5-sip=12.11.0=py38h8dc9893_3 + - pysocks=1.7.1=pyha2e5f31_6 + - python=3.8.18=hd12c33a_0_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python-fastjsonschema=2.19.1=pyhd8ed1ab_0 + - python-json-logger=2.0.7=pyhd8ed1ab_0 + - python-lmdb=1.4.1=py38hdb54ec7_1 + - python_abi=3.8=4_cp38 + - pytorch-cuda=11.6=h867d48c_1 + - pytorch-mutex=1.0=cuda + - pytz=2023.3.post1=pyhd8ed1ab_0 + - pyyaml=6.0.1=py38h01eb140_1 + - pyzmq=25.1.2=py38h34c975a_0 + - qt-main=5.15.6=h602db52_6 + - qtconsole-base=5.5.1=pyha770c72_0 + - qtpy=2.4.1=pyhd8ed1ab_0 + - readline=8.2=h5eee18b_0 + - referencing=0.32.1=pyhd8ed1ab_0 + - requests=2.31.0=pyhd8ed1ab_0 + - rfc3339-validator=0.1.4=pyhd8ed1ab_0 + - rfc3986-validator=0.1.1=pyh9f0ad1d_0 + - safetensors=0.3.3=py38h0cc4f7c_1 + - scikit-learn=1.2.1=py38h6a678d5_0 + - send2trash=1.8.2=pyh41d4057_0 + - sentry-sdk=1.39.2=pyhd8ed1ab_0 + - setproctitle=1.3.3=py38h01eb140_0 + - setuptools=68.2.2=py38h06a4308_0 + - sip=6.7.12=py38h17151c0_0 + - six=1.16.0=pyhd3eb1b0_1 + - snappy=1.1.10=h9fff704_0 + - sniffio=1.3.0=pyhd8ed1ab_0 + - soupsieve=2.5=pyhd8ed1ab_1 + - sqlite=3.41.2=h5eee18b_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - statsmodels=0.13.2=py38h6c62de6_0 + - svt-av1=1.4.1=hcb278e6_0 + - terminado=0.18.0=pyh0d859eb_0 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tiledb=2.13.2=hd532e3d_0 + - timm=0.9.12=pyhd8ed1ab_0 + - tinycss2=1.2.1=pyhd8ed1ab_0 + - tk=8.6.13=noxft_h4845f30_101 + - toml=0.10.2=pyhd8ed1ab_0 + - tomli=2.0.1=pyhd8ed1ab_0 + - torchaudio=0.13.0=py38_cu116 + - torchvision=0.14.0=py38_cu116 + - tqdm=4.66.1=pyhd8ed1ab_0 + - traitlets=5.14.1=pyhd8ed1ab_0 + - ttach=0.0.3=pyhd8ed1ab_0 + - types-python-dateutil=2.8.19.20240106=pyhd8ed1ab_0 + - typing-extensions=4.9.0=hd8ed1ab_0 + - typing_extensions=4.9.0=pyha770c72_0 + - typing_utils=0.1.0=pyhd8ed1ab_0 + - tzcode=2023d=h3f72095_0 + - unicodedata2=15.1.0=py38h01eb140_0 + - uri-template=1.3.0=pyhd8ed1ab_0 + - uriparser=0.9.7=hcb278e6_1 + - urllib3=2.1.0=pyhd8ed1ab_0 + - wandb=0.16.2=pyhd8ed1ab_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - webcolors=1.13=pyhd8ed1ab_0 + - webencodings=0.5.1=pyhd8ed1ab_2 + - websocket-client=1.7.0=pyhd8ed1ab_0 + - wheel=0.41.2=py38h06a4308_0 + - widgetsnbextension=4.0.9=pyhd8ed1ab_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xcb-util=0.4.0=h516909a_0 + - xcb-util-image=0.4.0=h166bdaf_0 + - xcb-util-keysyms=0.4.0=h516909a_0 + - xcb-util-renderutil=0.3.9=h166bdaf_0 + - xcb-util-wm=0.4.1=h516909a_0 + - xerces-c=3.2.4=h55805fa_1 + - xkeyboard-config=2.38=h0b41bf4_0 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-inputproto=2.3.2=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxi=1.7.10=h7f98852_0 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.4.5=h5eee18b_0 + - yaml=0.2.5=h7f98852_2 + - zeromq=4.3.5=h59595ed_0 + - zipp=3.17.0=pyhd8ed1ab_0 + - zlib=1.2.13=hd590300_5 + - zstd=1.5.2=h8a70e8d_1 + - pip: + - absl-py==2.0.0 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.9.3 + - async-timeout==4.0.3 + - backports-zoneinfo==0.2.1 + - cachetools==5.3.2 + - dtaidistance==2.3.12 + - einops==0.8.1 + - frozenlist==1.4.1 + - future==1.0.0 + - google-auth==2.26.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.60.0 + - imageio==2.9.0 + - importlib-metadata==6.11.0 + - joblib==1.3.2 + - lazy-loader==0.3 + - lightning-utilities==0.11.9 + - markdown==3.5.2 + - markdown-it-py==3.0.0 + - mdurl==0.1.2 + - mpmath==1.3.0 + - multidict==6.0.4 + - networkx==3.1 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.3.101 + - nvidia-nvtx-cu12==12.1.105 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - pandas==2.0.3 + - piexif==1.1.3 + - protobuf==4.25.2 + - pyarrow==14.0.2 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pydeck==0.8.1b0 + - pydeprecate==0.3.1 + - pytorch-lightning==1.4.2 + - pywavelets==1.4.1 + - regex==2023.12.25 + - requests-oauthlib==1.3.1 + - rich==13.7.0 + - rpds-py==0.16.2 + - rsa==4.9 + - sacremoses==0.1.1 + - scikit-image==0.20.0 + - scipy==1.9.1 + - smmap==5.0.1 + - sympy==1.13.3 + - taming-transformers==0.0.1 + - taming-transformers-rom1504==0.0.6 + - tenacity==8.2.3 + - tensorboard==2.14.0 + - tensorboard-data-server==0.7.2 + - tifffile==2023.7.10 + - tokenizers==0.10.3 + - toolz==0.12.0 + - torch==2.4.1 + - torchmetrics==0.7.3 + - tornado==6.4 + - triton==3.0.0 + - tzdata==2023.4 + - tzlocal==5.2 + - urwid==2.4.1 + - validators==0.22.0 + - watchdog==3.0.0 + - werkzeug==3.0.1 + - yarl==1.9.4 +prefix: /home/v-dengbowen/anaconda3/envs/disdiff_test diff --git a/TarDiff/generation.sh b/TarDiff/generation.sh new file mode 100755 index 0000000..06444e3 --- /dev/null +++ b/TarDiff/generation.sh @@ -0,0 +1 @@ +python guidance_generation.py --base configs/base/mimic_icustay_base.yaml --gpus 0, --uncond --logdir MIMIC_ICUSTAY -sl 24 diff --git a/TarDiff/guidance_generation.py b/TarDiff/guidance_generation.py new file mode 100755 index 0000000..f9ceb61 --- /dev/null +++ b/TarDiff/guidance_generation.py @@ -0,0 +1,534 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl +from tqdm import tqdm +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image +import wandb +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info +from utils.callback_utils import prepare_trainer_configs +from ldm.util import instantiate_from_config +from pathlib import Path +import matplotlib.pyplot as plt6 +import pandas as pd +from ldm.modules.guidance_scorer import GradDotCalculator +import torch.nn as nn +from torch.utils.data import TensorDataset, DataLoader +import pickle as pkl +from collections import Counter +from classifier.model import RNNClassifier +from classifier.classifier_train import TimeSeriesDataset +from ldm.modules.guidance_scorer import GradDotCalculator + + +def get_parser(**parser_kwargs): + + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument("-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir") + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=True, + nargs="?", + help="train", + ) + parser.add_argument( + "-r", + "--resume", + type=str2bool, + const=True, + default=False, + nargs="?", + help="resume and test", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "--normalize", + type=str, + const=True, + default=None, + nargs="?", + help="normalization method", + ) + parser.add_argument("-p", + "--project", + help="name of new or path to existing project") + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="/mnt/storage/ts_diff_newer", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=False, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "-dw", + "--dis_weight", + type=float, + const=True, + default=1., + nargs="?", + help="weight of disentangling loss", + ) + parser.add_argument( + "-dt", + "--dis_loss_type", + type=str, + const=True, + default=None, + nargs="?", + help="type of disentangling loss", + ) + parser.add_argument( + "-tg", + "--train_stage", + type=str, + const=True, + default='pre', + nargs="?", + help="pre / dis", + ) + parser.add_argument( + "-ds", + "--dataset_name", + type=str, + const=True, + default='elec', + nargs="?", + help="dataset name", + ) + parser.add_argument( + "-dp", + "--dataset_prefix", + type=str, + const=True, + default='/mnt/storage/tsdiff/data', + nargs="?", + help="dataset prefix", + ) + parser.add_argument( + "-cp", + "--ckpt_prefix", + type=str, + const=True, + default='/mnt/storage/tsdiff/outputs', + nargs="?", + help="ckpt prefix", + ) + parser.add_argument( + "-sp", + "--sample_path", + type=str, + const=True, + default='/mnt/storage/ts_generated/ours_amlt', + nargs="?", + help="samples prefix", + ) + + parser.add_argument( + "-pl", + "--pair_loss_type", + type=str, + const=True, + default='', + nargs="?", + help="pair loss type: cosine or l2, otherwise not used") + parser.add_argument("-sl", + "--seq_len", + type=int, + const=True, + default=24, + nargs="?", + help="sequence length") + parser.add_argument("-uc", + "--uncond", + action='store_true', + help="unconditional generation") + parser.add_argument("-si", + "--split_inv", + action='store_true', + help="split invariant encoder") + parser.add_argument("-cl", + "--ce_loss", + action='store_true', + help="cross entropy loss") + parser.add_argument("-up", + "--use_prototype", + action='store_true', + help="use prototype") + parser.add_argument("-pd", + "--part_drop", + action='store_true', + help="use partial dropout conditions") + parser.add_argument("-o", + "--orth_emb", + action='store_true', + help="use orthogonal prototype embedding") + parser.add_argument("-ma", + "--mask_assign", + action='store_true', + help="use mask assignment") + parser.add_argument("-ha", + "--hard_assign", + action='store_true', + help="use hard assignment") + parser.add_argument("-im", + "--inter_mask", + action='store_true', + help="use intermediate assignment") + parser.add_argument("-bs", + "--batch_size", + type=int, + const=True, + default=256, + nargs="?", + help="batch_size") + parser.add_argument("-ms", + "--max_step_sum", + type=int, + const=True, + default=20000, + nargs="?", + help="max training steps") + parser.add_argument("-nl", + "--num_latents", + type=int, + const=True, + default=16, + nargs="?", + help="sequence length") + parser.add_argument("-pw", + "--pair_weight", + type=float, + const=True, + default=1.0, + nargs="?", + help="pair loss weight") + parser.add_argument("-lr", + "--overwrite_learning_rate", + type=float, + const=True, + default=None, + nargs="?", + help="learning rate") + parser.add_argument("-g", + "--gen_ckpt_path", + type=str, + const=True, + default='', + nargs="?", + help="ckpt path for diffusion") + parser.add_argument("-dp", + "--downstream_pth_path", + type=str, + const=True, + default='', + nargs="?", + help="ckpt path for downstream model") + parser.add_argument("-sp", + "--save_path", + type=str, + const=True, + default='', + nargs="?", + help="path for saving generated data") + parser.add_argument("-op", + "--origin_data_path", + type=str, + const=True, + default='', + nargs="?", + help="path for origin data") + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + + if opt.name: + name = opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = cfg_name + else: + name = "" + + seed_everything(opt.seed) + + # try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + # Customize config from opt: + n_data = len(config.data['params']['data_path_dict']) + config.model['params']['image_size'] = opt.seq_len + config.model['params']['unet_config']['params']['image_size'] = opt.seq_len + config.data['params']['window'] = opt.seq_len + config.data['params']['batch_size'] = opt.batch_size + bs = opt.batch_size + if opt.max_steps: + config.lightning['trainer']['max_steps'] = opt.max_steps + max_steps = opt.max_steps + else: + max_steps = config.lightning['trainer']['max_steps'] + if opt.debug: + config.lightning['trainer']['max_steps'] = 10 + config.lightning['callbacks']['image_logger']['params'][ + 'batch_frequency'] = 5 + max_steps = 10 + if opt.overwrite_learning_rate is not None: + config.model['base_learning_rate'] = opt.overwrite_learning_rate + print( + f"Setting learning rate (overwritting config file) to {opt.overwrite_learning_rate:.2e}" + ) + base_lr = opt.overwrite_learning_rate + else: + base_lr = config.model['base_learning_rate'] + + nowname = f"{name.split('-')[-1]}_{opt.seq_len}_nl_{opt.num_latents}_lr{base_lr:.1e}_bs{opt.batch_size}_ms{int(max_steps/1000)}k" + # config.data['params']['batch_size'] = opt.batch_size + + if opt.normalize is not None: + config.data['params']['normalize'] = opt.normalize + nowname += f"_{config.data['params']['normalize']}" + else: + assert 'normalize' in config.data['params'] + nowname += f"_{config.data['params']['normalize']}" + + config.model['params']['pair_loss_flag'] = False + if opt.uncond: + config.model['params']['cond_stage_config'] = "__is_unconditional__" + config.model['params']['cond_stage_trainable'] = False + nowname += f"_uncond" + else: + config.model['params']['cond_stage_config']['params'][ + 'window'] = opt.seq_len + config.model['params']['cond_stage_config']['params'][ + 'num_latents'] = opt.num_latents + + config.model['params']['pair_loss_flag'] = False + config.model['params']['pair_loss_type'] = None + config.model['params']['pair_loss_weight'] = 0 + nowname += f"_pl-None" + + config.model['params']['cond_stage_config']['params'][ + 'split_inv'] = False + config.model['params']['unet_config']['params'][ + 'latent_unit'] = opt.num_latents + + config.model['params']['cond_stage_config']['params'][ + 'use_prototype'] = False + config.model['params']['cond_stage_config']['params'][ + 'mask_assign'] = False + config.model['params']['cond_stage_config']['params'][ + 'hard_assign'] = False + config.model['params']['unet_config']['params']['inter_mask'] = False + config.model['params']['cond_stage_config']['params'][ + 'orth_proto'] = False + + nowname += f"_seed{opt.seed}" + # nowname = nowname + logdir = os.path.join(opt.logdir, cfg_name, nowname) + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + + metrics_dir = Path(logdir) / 'metric_dict.pkl' + if metrics_dir.exists(): + print(f"Metric exists! Skipping {nowname}") + sys.exit(0) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + # default to ddp + # trainer_config["accelerator"] = "ddp" + trainer_config["accelerator"] = "gpu" + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + if not "gpus" in trainer_config: + del trainer_config["accelerator"] + cpu = True + else: + gpuinfo = trainer_config["gpus"] + print(f"Running on GPUs {gpuinfo}") + cpu = False + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + if "LatentDiffusion" in config.model['target']: + if opt.dis_loss_type != None: + config.model["params"]["dis_loss_type"] = opt.dis_loss_type + config.model["params"]["dis_weight"] = opt.dis_weight + + alpha = 0.00001 + save_path = opt.save_path_origin + + train_tuple = pd.read_pickle(opt.origin_data_path) + + generation_nums_label = dict(Counter(train_tuple[1])) + + synt_nums_label = dict(Counter(train_tuple[1])) + config.model['params']['ckpt_path'] = opt.gen_ckpt_path + model = instantiate_from_config(config.model) + model = model.cuda() + model.eval() + data = instantiate_from_config(config.data) + data.prepare_data() + data.setup() + + print("#### Data Preparation Finished #####") + + downstream_model = RNNClassifier( + input_dim=7, + hidden_dim=256, + num_layers=2, + rnn_type='gru', + num_classes=1, + ) + + downstream_model.load_state_dict( + torch.load(opt.downstream_pth_path)['model_state']) + print('#### Downstream Model Loaded #####') + + print("#### Start Generating Samples #####") + for dataset in data.norm_train_dict: + normalizer = data.normalizer_dict[dataset] + train_dataset = TimeSeriesDataset(data.transform( + train_tuple[0].transpose(0, 2, 1), normalizer), + train_tuple[1], + normalize=False) + train_dataloader = DataLoader(train_dataset, + batch_size=1, + shuffle=False) + c = GradDotCalculator(downstream_model, train_dataloader, + nn.BCEWithLogitsLoss(), alpha) + generated_data_all = None + + print(f"#### Start Generating Samples with alpha {alpha} #####") + for label_sample, total_samples in tqdm(generation_nums_label.items()): + label = torch.tensor([label_sample] * total_samples, + device='cuda', + dtype=torch.long) + samples, z_denoise_row = model.sample_log(cond=None, + batch_size=total_samples, + ddim=True, + ddim_steps=20, + eta=1., + sem_guide=True, + sem_guide_type='GDC', + label=label, + GDCalculater=c) + norm_samples = model.decode_first_stage( + samples).detach().cpu().numpy() + inv_samples = data.inverse_transform(norm_samples, + data_name=dataset) + generated_data = np.array(inv_samples).transpose(0, 2, 1) + if generated_data_all is None: + generated_data_all = generated_data + else: + generated_data_all = np.concatenate( + [generated_data_all, generated_data], axis=0) + generated_data_all = generated_data_all.transpose(0, 2, 1) + label = np.concatenate([ + np.full(count, label) + for label, count in generation_nums_label.items() + ]) + tmp_name = f'synt_tardiff_noise_rnn_train_guidance_sc{alpha}' + with open(save_path + '/' + f'{tmp_name}.pkl', 'wb') as f: + #with open(save_path / f'alpha_search/{tmp_name}.pkl', 'wb') as f: + + pkl.dump((generated_data_all, label), f) + print(f"Saved {tmp_name}.pkl in {save_path}") diff --git a/TarDiff/images/overview.png b/TarDiff/images/overview.png new file mode 100644 index 0000000..a56dd0d Binary files /dev/null and b/TarDiff/images/overview.png differ diff --git a/TarDiff/ldm/lr_scheduler.py b/TarDiff/ldm/lr_scheduler.py new file mode 100755 index 0000000..4d6f365 --- /dev/null +++ b/TarDiff/ldm/lr_scheduler.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__(self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - + self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - + self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - + self.lr_min) * (1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + + def __init__(self, + warm_up_steps, + f_min, + f_max, + f_start, + cycle_lengths, + verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len( + f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle] + ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * ( + self.f_max[cycle] - self.f_min[cycle]) * (1 + + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle] + ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f diff --git a/TarDiff/ldm/models/autoencoder.py b/TarDiff/ldm/models/autoencoder.py new file mode 100755 index 0000000..2d588b5 --- /dev/null +++ b/TarDiff/ldm/models/autoencoder.py @@ -0,0 +1,561 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from copy import copy + + +class VQModel(pl.LightningModule): + + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, + embed_dim, + beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print( + f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}." + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if any(["first_stage_model" in k for k in keys]): + old_sd = copy(sd) + sd = {} + for key in old_sd.keys(): + if "first_stage_model" in key: + sd[key.replace("first_stage_model.", "")] = old_sd[key] + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_, _, ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice( + np.arange(lower_size, upper_size + 16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + # predicted_indices=ind) + log_dict_ae.update({'train/epoch_num': self.current_epoch}) + + self.log_dict(log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + log_dict_disc.update({'train/epoch_num': self.current_epoch}) + self.log_dict(log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, + batch_idx, + suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind) + + discloss, log_dict_disc = self.loss(qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", + rec_loss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True) + self.log(f"val{suffix}/aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True) + # if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor * self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, + betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, + betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, + lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, + lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + # h = self.encoder(x) + # h = self.quant_conv(h) + # quant, emb_loss, info = self.quantize(h) + # return quant, emb_loss, info + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + + log_dict_ae["train/epoch_num"] = self.current_epoch + self.log("aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict(log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + + self.log("discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict(log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val") + + discloss, log_dict_disc = self.loss(inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, + betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/TarDiff/ldm/models/diffusion/__init__.py b/TarDiff/ldm/models/diffusion/__init__.py new file mode 100755 index 0000000..9a04545 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/TarDiff/ldm/models/diffusion/classifier.py b/TarDiff/ldm/models/diffusion/classifier.py new file mode 100755 index 0000000..abb6f56 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/classifier.py @@ -0,0 +1,324 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted( + glob(os.path.join(diffusion_path, 'configs', + '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict( + sd, + strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy( + self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print( + '#####################################################################' + ) + print(f'load from ckpt "{ckpt_path}"') + print( + '#####################################################################' + ) + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level( + x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample( + x_start=x, + t=t, + noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, + size=(h // 2, w // 2), + mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, + None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k(logits, + targets, + k=1, + reduction="mean") + log[f"{log_prefix}/acc@5"] = self.compute_top_k(logits, + targets, + k=5, + reduction="mean") + + self.log_dict(log, + prog_bar=False, + logger=True, + on_step=self.training, + on_epoch=True) + self.log('loss', + log[f"{log_prefix}/loss"], + prog_bar=True, + logger=False) + self.log('global_step', + self.global_step, + logger=False, + on_epoch=False, + prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', + lr, + on_step=True, + logger=True, + on_epoch=False, + prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input( + batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, + self.diffusion_model.num_timesteps, + (x.shape[0], ), + device=self.device).long() + else: + t = torch.full(size=(x.shape[0], ), + fill_value=t, + device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = { + t: { + 'acc@1': [], + 'acc@5': [] + } + for t in range(0, self.diffusion_model.num_timesteps, + self.diffusion_model.log_every_t) + } + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append( + self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append( + self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [{ + 'scheduler': + LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': + 'step', + 'frequency': + 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), + num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb( + pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/TarDiff/ldm/models/diffusion/ddim.py b/TarDiff/ldm/models/diffusion/ddim.py new file mode 100755 index 0000000..988b7f6 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/ddim.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.modules.diffusionmodules.util import return_wrap, extract_into_tensor + + +class DDIMSampler(object): + + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) + setattr(self, name, attr) + + def make_schedule(self, + ddim_num_steps, + ddim_discretize="uniform", + ddim_eta=0., + verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[ + 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model + .device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev, ddim_coef = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_coef', ddim_coef) + self.register_buffer('ddim_sqrt_one_minus_alphas', + np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + label=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, W = shape + size = (batch_size, C, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + cond=conditioning, + shape=size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + label=label, + **kwargs) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + label=None, + **kwargs): + device = self.model.betas.device + b = shape[0] + label = torch.ones(b, dtype=torch.long, + device=device) if label is None else torch.tensor( + label, dtype=torch.long, device=device) + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int( + min(timesteps / self.ddim_timesteps.shape[0], 1) * + self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range( + 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ + 0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b, ), step, device=device, dtype=torch.long) + + # if mask is not None: + # assert x0 is not None + # img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + # img = img_orig * mask + (1. - mask) * img + outs = self.p_sample_ddim( + x=img, + c=cond, + mask=mask, + label=label, + t=ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, + x, + c, + t, + index, + label=None, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + **kwargs): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c, label=label, **kwargs) + e_t = return_wrap( + e_t, torch.full((b, 1, 1), + self.ddim_coef[index], + device=device)) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, + t_in, + c_in, + label=label).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - + e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, + **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1), + sqrt_one_minus_alphas[index], + device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + # p.savez("data.npz", z=z, x = x, xrec = xrec, x_T = x_T, time = time, alphas = alphas, alphas_prev = alphas_prev, sqrt_one_minus_alphas = sqrt_one_minus_alphas, sigmas = sigmas.cpu().numpy(),e_t = e_t) + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, + repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/TarDiff/ldm/models/diffusion/ddpm.py b/TarDiff/ldm/models/diffusion/ddpm.py new file mode 100755 index 0000000..0381182 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1962 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.guided_ddim import GuideDDIMSampler +from ldm.modules.diffusionmodules.util import return_wrap +import copy +import os +import pandas as pd +#from IF4Med.prepare_model_data import prepare_model_dataset +from ldm.modules.guidance_scorer import GradDotCalculator + +__conditioning_keys__ = { + 'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y' +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in [ + "eps", "x0" + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, + ignore_keys=ignore_keys, + only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, + size=(self.num_timesteps, )) + self.ce_loss = nn.CrossEntropyLoss(reduction="none") + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[ + 0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', + to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + 'posterior_log_variance_clipped', + to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + 'posterior_mean_coef1', + to_torch(betas * np.sqrt(alphas_cumprod_prev) / + (1. - alphas_cumprod))) + self.register_buffer( + 'posterior_mean_coef2', + to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / + (1. - alphas_cumprod))) + self.register_buffer( + "shift_coef", + -to_torch(np.sqrt(alphas)) * (1. - self.alphas_cumprod_prev) / + torch.sqrt(1. - self.alphas_cumprod)) + self.register_buffer("ddim_coef", -self.sqrt_one_minus_alphas_cumprod) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * + to_torch(alphas) * + (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / ( + 2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + self.load_epoch = sd['epoch'] + self.load_step = sd["global_step"] + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict( + sd, + strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * + x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * noise) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_variance = extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + eps_pred = return_wrap(model_out, + extract_into_tensor(self.ddim_coef, t, x.shape)) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * + model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), + desc='Sampling t', + total=self.num_timesteps): + img = self.p_sample(img, + torch.full((b, ), + i, + device=device, + dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, + pred, + reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + eps_pred = return_wrap( + model_out, extract_into_tensor(self.shift_coef, t, x_start.shape)) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError( + f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(eps_pred, target, mean=False).mean(dim=[1, 2]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, + self.num_timesteps, (x.shape[0], ), + device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 2: + x = x[..., None] + x = rearrange(x, 'b t c -> b c t') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + + self.log("global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', + lr, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # pass + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = { + key + '_ema': loss_dict_ema[key] + for key in loss_dict_ema + } + self.log_dict(loss_dict_no_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + self.log_dict(loss_dict_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c t -> b n c t') + denoise_grid = rearrange(denoise_grid, 'b n c t -> (b n) c t') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=2, + sample=True, + return_keys=None, + **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, + return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + dis_loss_flag=False, + detach_flag=False, + train_enc_flag=False, + dis_weight=1.0, + dis_loss_type="IM", + cond_drop_prob=None, + cond_part_drop=False, + pair_loss_flag=False, + pair_loss_type=None, + pair_loss_weight=1.0, + grad_watch=False, + ce_loss_flag=False, + *args, + **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.dis_loss_flag = dis_loss_flag + self.pair_loss_flag = pair_loss_flag + self.pair_loss_type = pair_loss_type + self.pair_loss_weight = pair_loss_weight + self.ce_loss_flag = ce_loss_flag + self.detach_flag = detach_flag + self.train_enc_flag = train_enc_flag + self.dis_weight = dis_weight + self.dis_loss_type = dis_loss_type + self.cond_drop_prob = cond_drop_prob + self.cond_part_drop = cond_part_drop + try: + self.num_downs = len( + first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + if self.ce_loss_flag: + self.ce_loss = nn.CrossEntropyLoss(reduction="none") + # if grad_watch: + # import wandb + # self.model_watch = wandb.watch(self.model) + # if self.cond_stage_trainable: + # self.cond_watch = wandb.watch(self.cond_stage_model) + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps, ), + fill_value=self.num_timesteps - 1, + dtype=torch.long) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, + self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + if hasattr(self.model.diffusion_model, "scale_factor"): + del self.scale_factor + self.register_buffer('scale_factor', + self.model.diffusion_model.scale_factor) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING Pre-Trained STD-RESCALING ###") + else: + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, + linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print( + f"Training {self.__class__.__name__} as an unconditional model." + ) + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, + samples, + desc='', + force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append( + self.decode_first_stage( + zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c t -> b n c t') + # denoise_grid = rearrange(denoise_grid, 'b n c t -> (b n) c t') + # denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c, return_mask=False): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable( + self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + elif self.cond_stage_model is None: + c, mask = None, None + else: + c, mask = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + if return_mask: + return c, mask + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], + dim=-1), + dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, + Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip( + L_weighting, self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, + x, + kernel_size, + stride, + uf=1, + df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, + Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, + w) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, + kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, + x.shape[3] * uf), + **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, + kernel_size[1] * uf, Ly, Lx, + x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * + uf) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, + kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, + x.shape[3] // df), + **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, + kernel_size[1] // df, Ly, Lx, + x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // + df) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_mask=False): + x = super().get_input(batch, k) + y = batch['label'].long() + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c, mask = self.get_learned_conditioning(xc, + return_mask=True) + else: + c, mask = self.get_learned_conditioning(xc.to(self.device), + return_mask=True) + else: + c = xc + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + mask = None + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + if return_mask: + out.append(mask) + out.append(y) + return out + + @torch.no_grad() + def decode_first_stage(self, + z, + predict_cids=False, + force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, + shape=None) + z = rearrange(z, 'b t c -> b c t').contiguous() + + z = 1. / self.scale_factor * z + + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, + z, + predict_cids=False, + force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, + shape=None) + z = rearrange(z, 'b t c -> b c t').contiguous() + + z = 1. / self.scale_factor * z + + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c, label = self.get_input(batch, self.first_stage_key) + kwargs['data_key'] = batch['data_key'].to(self.device) + loss = self(x, c, label, **kwargs) + # if self.cond_stage_model.grad_hook: + # if self.training: + # self.cond_stage_model.latent_handle.remove() + return loss + + def forward(self, x, c, label, *args, **kwargs): + t = torch.randint(0, + self.num_timesteps, (x.shape[0], ), + device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c, return_mask=True) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, + t=tc, + noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, label, *args, **kwargs) + + def apply_model(self, + x_noisy, + t, + cond, + mask=None, + label=None, + cfg_scale=1, + cond_drop_prob=None, + cond_part_drop=False, + return_ids=False, + sampled_concept=None, + sampled_index=None, + sub_scale=None, + **kwargs): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond, 'mask': mask} + + if cond_drop_prob is None: + x_recon = self.model.cfg_forward(x_noisy, + t, + label=label, + cfg_scale=cfg_scale, + sampled_concept=sampled_concept, + sampled_index=sampled_index, + sub_scale=sub_scale, + **cond) + else: + x_recon = self.model.forward(x_noisy, + t, + label=label, + cfg_scale=cfg_scale, + cond_drop_prob=cond_drop_prob, + cond_part_drop=cond_part_drop, + sampled_concept=sampled_concept, + sampled_index=sampled_index, + sub_scale=sub_scale, + **cond) + + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def dis_loss(self, model_forward, x_t, t, cond, sampled_concept): + if not self.train_enc_flag: + eval_encoder = copy.deepcopy(self.cond_stage_model) + eval_encoder.requires_grad_(False) + eval_encoder.eval() + else: + eval_encoder = self.cond_stage_model + + ddim_coef = extract_into_tensor(self.ddim_coef, t, x_t.shape) + with torch.no_grad(): + eps_hat = model_forward.pred + z_start = self.predict_start_from_noise(x_t, t, eps_hat) + pred_x0_t = self.differentiable_decode_first_stage( + z_start, force_not_quantize=not self.detach_flag).detach() + + eps_new_hat = model_forward.null_pred + ddim_coef * model_forward.sub_grad + z_start_new = self.predict_start_from_noise(x_t, t, eps_new_hat) + pred_x0_new_t = self.differentiable_decode_first_stage( + z_start_new, force_not_quantize=not self.detach_flag).detach() + + pred_z = eval_encoder(pred_x0_t) + z_parts = pred_z.chunk(self.model.diffusion_model.latent_unit, + dim=1) + pred_z = torch.stack(z_parts, dim=1) + + pred_z_new = eval_encoder(pred_x0_new_t) + z_parts = pred_z_new.chunk(self.model.diffusion_model.latent_unit, + dim=1) + cond = cond.chunk(self.model.diffusion_model.latent_unit, dim=1) + pred_z_new = torch.stack(z_parts, dim=1) + cond = torch.stack(cond, dim=1) + + # with torch.no_grad(): + # norm_org = torch.norm(pred_z - cond.detach(), dim=-1) + # norm_Z = torch.norm(pred_z_new - cond.detach(), dim=-1) + norm_org = torch.norm(pred_z - cond, dim=-1) + norm_Z = torch.norm(pred_z_new - cond.detach(), dim=-1) + logits_deta = torch.norm(pred_z - pred_z_new, dim=-1) + logits = norm_org.detach() - norm_Z + + dis_loss = self.ce_loss(logits, + torch.from_numpy(sampled_concept).cuda()) + dis_loss_deta = self.ce_loss(logits_deta, + torch.from_numpy(sampled_concept).cuda()) + + if self.dis_loss_type == "IM": + dis_weight = mean_flat( + (pred_x0_t.detach() - pred_x0_new_t.detach())**2) + elif self.dis_loss_type == "Z": + dis_weight = mean_flat( + (z_start.detach() - z_start_new.detach())**2) + else: + raise NotImplementedError + + return dis_weight * self.dis_weight * (dis_loss + + dis_loss_deta), dis_weight + + def sim_mask(self, arr): + n = arr.shape[0] + mask = torch.zeros((n, n), dtype=torch.int32).to(self.device) + idx1, idx2 = torch.tril_indices(n, n, -1).to(self.device) + mask[idx1, idx2] = (arr[idx1] == arr[idx2]).int() + return mask + + def pair_loss(self, cond, data_key, sim_type="cosine"): + bs = cond.shape[0] + sim_mask = self.sim_mask(data_key) + if cond.dim() == 3: + r = cond[:, -1] + # r = cond.chunk(self.model.diffusion_model.latent_unit, dim=1)[-1] + else: + assert cond.dim() == 2 + r = cond + # r = cond[:,-1] + # r_norm = torch.norm(r.detach(), dim=1) + r_norm = torch.norm(r, dim=1) + mask = torch.tril(torch.ones_like(sim_mask), + -1).to(self.device) # lower triangular matrix + if sim_type == "cosine": + cos_sim = (r @ r.T) * mask * 0.5 / ( + r_norm.view(-1, 1) @ r_norm.view(1, -1) + 1e-8) + cos_sim = cos_sim**2 + sim_loss = -( + (sim_mask * cos_sim - torch.log(1 + torch.exp(cos_sim))) / + (bs**2) * mask).sum() + elif sim_type == "l2": + l2_dist = torch.cdist( + r, r, p=2 + ) # torch.norm(r.view(bs, 1, -1) - r.view(1, bs, -1), dim=-1) ** 2 + sim_loss = -( + (sim_mask * l2_dist - torch.log(1 + torch.exp(l2_dist))) / + (bs * (bs - 1) / 2) * mask).sum() + # sim_loss = (l2_dist * mask / (bs*(bs-1)/2)).sum() + return sim_loss + + def p_losses(self, + x_start, + condmask, + t, + label=None, + noise=None, + data_key=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + if condmask is not None: + cond, mask = condmask + else: + cond = None + mask = None + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + if self.dis_loss_flag and self.global_step > 1000: + sampled_concept = np.random.randint( + self.model.diffusion_model.latent_unit, size=x_noisy.shape[0]) + model_output = self.apply_model(x_noisy, + t, + cond, + mask, + sampled_concept=sampled_concept, + cond_drop_prob=self.cond_drop_prob, + cond_part_drop=self.cond_part_drop) + dis_loss, dis_weight = self.dis_loss(model_output, x_noisy, t, + cond, sampled_concept) + elif self.pair_loss_flag: + model_output = self.apply_model(x_noisy, + t, + cond, + mask, + cond_drop_prob=self.cond_drop_prob, + cond_part_drop=self.cond_part_drop) + pair_loss = self.pair_loss(mask, data_key, self.pair_loss_type) + else: + model_output = self.apply_model(x_noisy, + t, + cond, + mask, + label=label, + cond_drop_prob=self.cond_drop_prob, + cond_part_drop=self.cond_part_drop) + if self.ce_loss_flag: + negce_loss = -self.ce_loss(cond[:, 0], data_key).mean() + eps_pred = return_wrap( + model_output, extract_into_tensor(self.shift_coef, t, + x_start.shape)) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(eps_pred, target, mean=False).mean([1, 2]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.cpu()].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + if self.dis_loss_flag and self.global_step > 1000: + loss = self.l_simple_weight * loss.mean() + dis_loss.mean() + loss_dict.update({f'{prefix}/dis_loss': dis_loss.mean()}) + loss_dict.update({f'{prefix}/dis_weight': dis_weight.mean()}) + elif self.pair_loss_flag: + loss = self.l_simple_weight * loss.mean( + ) + self.pair_loss_weight * pair_loss + loss_dict.update({f'{prefix}/pair_loss': pair_loss}) + else: + loss = self.l_simple_weight * loss.mean() + if self.ce_loss_flag: + loss = loss + 0.1 * negce_loss + loss_dict.update({f'{prefix}/negce_loss': negce_loss}) + + loss_vlb = self.get_loss(eps_pred, target, mean=False).mean(dim=(1, 2)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f'{prefix}/epoch_num': self.current_epoch}) + loss_dict.update({f'{prefix}/step_num': self.global_step}) + + return loss, loss_dict + + def p_mean_variance(self, + x, + c, + t, + m, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + **kwargs): + t_in = t + model_out = self.apply_model(x, + t_in, + c, + m, + return_ids=return_codebook_ids, + **kwargs) + + eps_pred = return_wrap(model_out, + extract_into_tensor(self.ddim_coef, t, x.shape)) + + if score_corrector is not None: + assert self.parameterization == "eps" + eps_pred = score_corrector.modify_score(self, eps_pred, x, t, c, + **corrector_kwargs) + + if return_codebook_ids: + eps_pred, logits = eps_pred + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, + indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, + x, + c, + t, + m, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + **kwargs): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, + c=c, + t=t, + m=m, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + **kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + + # if return_codebook_ids: + # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + **kwargs): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + inter_recons = [] + inter_imgs = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: + cond[key][:batch_size] if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance( + cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), + desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b, ), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, + t=tc, + noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + **kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + inter_recons.append(x0_partial) + inter_imgs.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, inter_imgs, inter_recons + + @torch.no_grad() + def p_sample_loop(self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + **kwargs): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm( + reversed(range(0, timesteps)), desc='Sampling t', + total=timesteps) if verbose else reversed(range(0, timesteps)) + + for i in iterator: + ts = torch.full((b, ), i, device=device, dtype=torch.long) + + img = self.p_sample(img, + cond, + ts, + mask, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + **kwargs) + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: + cond[key][:batch_size] if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance( + cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs) + + @torch.no_grad() + def sample_log(self, + cond, + batch_size, + ddim, + ddim_steps=20, + sem_guide=False, + sem_guide_type='l2', + label=None, + GDCalculater=None, + **kwargs): + + if ddim: + if sem_guide and sem_guide_type == 'GDC': + guided_sampler = GuideDDIMSampler(self, + guide_type=sem_guide_type, + GDCalculater=GDCalculater) + shape = (self.channels, self.image_size) + samples, intermediates = guided_sampler.sample( + S=ddim_steps, + batch_size=batch_size, + shape=shape, + conditioning=cond, + verbose=False, + label=label, + **kwargs) + else: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size) + samples, intermediates = ddim_sampler.sample( + S=ddim_steps, + batch_size=batch_size, + shape=shape, + conditioning=cond, + verbose=False, + label=label, + **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, + batch_size=batch_size, + return_intermediates=True, + **kwargs) + + return samples, intermediates + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=8, + sample=True, + plot_reconstruction=False, + ddim_steps=20, + ddim_eta=1., + return_keys=None, + quantize_denoised=False, + inpaint=False, + plot_denoise_rows=False, + plot_progressive_rows=False, + plot_diffusion_rows=False, + plot_swapped_concepts=False, + plot_decoded_xstart=False, + plot_swapped_concepts_partial=False, + fix_noise=False, + **kwargs): + + use_ddim = ddim_steps is not None + # plot_swapped_concepts = True + + log = dict() + z, c, x, xrec, xc, mask, label = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + return_mask=True) + if fix_noise: + fixed_noise = torch.randn(x.shape, device=self.device) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x # batchsize, channel, window + if plot_reconstruction: + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), + batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack( + diffusion_row) # n_log_step, batchsize, channel, window + log["diffusion_row"] = diffusion_row # n_log_step, batchsize, channel, window + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + mask=mask, + label=None) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + with self.ema_scope("Uncond Plotting"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + cfg_scale=0, + mask=mask, + label=None) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["uncond_samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance( + self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_x0=True, + mask=mask) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_swapped_concepts: + x_samples_list = [x[None, :], + log["samples"][None, :]] # [(1, N, C, W) * 2] + nc = self.model.diffusion_model.latent_unit # number of concepts + with self.ema_scope("Plotting Swapping"): + for cdx in range(nc): + swapped_c = c.clone() + # swapped_c = torch.stack(swapped_c.chunk(nc, dim=1), dim=1) + swapped_c[:, cdx] = swapped_c[0, cdx][None, :].repeat( + c.shape[0], 1) + samples, z_denoise_row = self.sample_log( + cond=swapped_c.reshape(c.shape[0], -1), + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + x_samples_list.append(x_samples[None, :]) + log["samples_swapping"] = torch.cat(x_samples_list, + dim=0) # [nc+1, N, C, W] + with self.ema_scope("Plotting Concept Interception"): + intercept_schedule = torch.arange(0, 1.1, + 0.2).to(self.device) # (6,) + ni = len(intercept_schedule) + 1 # number of interceptions + x_samples_list = [] + input_concept_list = [] + src_c = torch.repeat_interleave( + c[None, 0].clone(), repeats=ni, + dim=0) # source concept, shaped (ni, nc, dim) + des_c = torch.repeat_interleave( + c[None, 1].clone(), repeats=ni, + dim=0) # destination concept, shaped (ni, nc, dim) + default_w = torch.zeros( + (ni, nc, 1)) # an all zero weight metrics, (ni, nc, 1) + default_w[-1, :] = 1 # set the first concept to 1.0 + for cdx in range(nc): + inter_w = default_w.clone().to(self.device) + inter_w[:-1, cdx] = intercept_schedule[:, None] + inter_c = src_c * (1 - inter_w) + des_c * inter_w + input_concept_list.append( + rearrange(inter_c, 'ni nc dim -> ni (nc dim)')) + + intered_c = torch.cat(input_concept_list, dim=0) + samples, z_denoise_row = self.sample_log( + cond=intered_c, + batch_size=intered_c.shape[0], + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + x_samples = rearrange(x_samples, + '(nc ni) c w -> nc ni c w', + ni=ni, + nc=nc) + x_samples = rearrange(x_samples, 'nc ni c w -> ni nc c w') + # x_samples_list.append(x_samples[None, :]) + log["samples_swapping_intercept"] = x_samples + if plot_swapped_concepts_partial: + x_samples_list = [] + with self.ema_scope("Plotting Swapping"): + for cdx in range(self.model.diffusion_model.latent_unit): + swapped_c = c.clone() + # swapped_c = torch.stack(swapped_c.chunk(self.model.diffusion_model.latent_unit, dim=1), dim=1) + swapped_c[:, cdx] = swapped_c[0, cdx][None, :].repeat( + c.shape[0], 1) + samples, z_denoise_row = self.sample_log( + cond=swapped_c.reshape(c.shape[0], -1), + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + sampled_index=np.array([cdx] * N)) + x_samples = self.decode_first_stage(samples) + x_samples_list.append(x_samples) + log["samples_swapping_partial"] = torch.cat(x_samples_list, + dim=0) + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + imgs, inter_imgs, inter_recons = self.progressive_denoising( + c, shape=(self.channels, self.image_size), batch_size=N) + prog_recon_row = self._get_denoise_row_from_list( + inter_recons, desc="Progressive Recon Generation") + prog_img_row = self._get_denoise_row_from_list( + inter_imgs, desc="Progressive Img Generation") + no_ddim_samples = self.decode_first_stage(imgs) + log["no_ddim_samples"] = no_ddim_samples # [N, C, W] + log["progressive_row_recon"] = rearrange( + prog_recon_row, 'b n c t -> n b c t') # [n_log_step, N, C, W] + log["progressive_row_inter"] = rearrange( + prog_img_row, 'b n c t -> n b c t') # [n_log_step, N, C, W] + + if plot_decoded_xstart: + # get diffusion row + + with self.ema_scope("Plotting PredXstart"): + z_start = z[:n_row] + diffusion_start = list() + diffusion_full = list() + for cdx in range(self.model.diffusion_model.latent_unit): + diffusion_row = list() + for t in range(self.num_timesteps): + if (t % (self.log_every_t // 2) == 0 + or t == self.num_timesteps - + 1) and t >= self.num_timesteps // 2: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, + t=t, + noise=noise) + + model_out = self.apply_model( + z_noisy, + t, + c, + return_ids=False, + sampled_concept=np.array([cdx] * n_row)) + eps_pred = model_out.pred + extract_into_tensor( + self.ddim_coef, t, + x.shape) * model_out.sub_grad + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + + diffusion_row.append( + self.decode_first_stage(x_recon)) + + if cdx == 0: + eps_pred = model_out.pred + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + diffusion_start.append( + self.decode_first_stage(x_recon)) + + eps_pred = model_out.pred + extract_into_tensor( + self.ddim_coef, t, + x.shape) * model_out.out_grad + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + diffusion_full.append( + self.decode_first_stage(x_recon)) + + diffusion_row = torch.stack( + diffusion_row) # n_log_step, n_row, C, H, W + log[f"predXstart_{cdx}"] = diffusion_row + if cdx == 0: + diffusion_start = torch.stack( + diffusion_start) # n_log_step, n_row, C, H, W + log[f"predXstart_st"] = diffusion_start + + diffusion_full = torch.stack( + diffusion_full) # n_log_step, n_row, C, H, W + log[f"predXstart_fl"] = diffusion_full + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print( + f"{self.__class__.__name__}: Also optimizing conditioner params!" + ) + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [{ + 'scheduler': + LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': + 'step', + 'frequency': + 1 + }] + return [opt], scheduler + return opt + + +class DiffusionWrapper(pl.LightningModule): + + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, 'concat', 'crossattn', 'hybrid', 'adm' + ] + + def parameters(self): + return self.diffusion_model.parameters() + + def forward(self, + x, + t, + label, + sampled_concept=None, + sampled_index=None, + c_concat: list = None, + c_crossattn: list = None, + sub_scale: list = None, + cfg_scale=1., + cond_drop_prob=0., + cond_part_drop=False, + mask=None): + + if (c_crossattn is not None) and (not None in c_crossattn): + cc = torch.cat(c_crossattn, 1) + else: + cc = None + out = self.diffusion_model(x, + t, + context=cc, + mask=mask, + y=label, + cond_drop_prob=cond_drop_prob, + cond_part_drop=cond_part_drop) + + return out + + def cfg_forward(self, + x, + t, + label=None, + sampled_concept=None, + sampled_index=None, + c_concat: list = None, + c_crossattn: list = None, + sub_scale: list = None, + cfg_scale=1., + mask=None): + + if (c_crossattn is not None) and (not None in c_crossattn): + cc = torch.cat(c_crossattn, 1) + else: + cc = None + out = self.diffusion_model(x, t, context=cc, mask=mask, y=label) + + return out diff --git a/TarDiff/ldm/models/diffusion/guided_ddim.py b/TarDiff/ldm/models/diffusion/guided_ddim.py new file mode 100755 index 0000000..784496e --- /dev/null +++ b/TarDiff/ldm/models/diffusion/guided_ddim.py @@ -0,0 +1,327 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.modules.diffusionmodules.util import return_wrap, extract_into_tensor + + +class GuideDDIMSampler(object): + + def __init__(self, + model, + schedule="linear", + guide_type='l2', + GDCalculater=None, + **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.guide_type = guide_type + self.GDC = GDCalculater + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule(self, + ddim_num_steps, + ddim_discretize="uniform", + ddim_eta=0., + verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[ + 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model + .device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev, ddim_coef = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_coef', ddim_coef) + self.register_buffer('ddim_sqrt_one_minus_alphas', + np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps) + + self.posterior_variance = self.model.posterior_variance + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + label=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, W = shape + size = (batch_size, C, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + cond=conditioning, + shape=size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + label=label, + **kwargs) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + label=None, + **kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int( + min(timesteps / self.ddim_timesteps.shape[0], 1) * + self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range( + 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ + 0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, + desc=f'DDIM Guided Sampler with dynamic scale', + total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b, ), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + x=img, + c=cond, + t=ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + label=label, + **kwargs) + img_hat, pred_x0 = outs + + # scale = extract_into_tensor(self.model.posterior_variance, ts, img.shape) + with torch.enable_grad(): + score = self.semantic_scoring(pred_x0, cond, label, ts) + img = (1 - self.GDC.gd_scale) * img_hat + score * self.GDC.gd_scale + # img = img_hat + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + def semantic_scoring(self, pred_x0, cond, label=None, t=None): + # cond: [b, n unit, latent dim] + pred_x0.requires_grad = True + pred_cond = self.model.get_learned_conditioning(pred_x0) + # cosine similarity between cond and pred_cond + if self.guide_type == 'cosine': + if cond.dim() == 2: + sim_score = torch.nn.functional.cosine_similarity( + cond[:, -32:], pred_cond[:, -32:], dim=-1) + else: # assume the last condition is the target condition + sim_score = torch.nn.functional.cosine_similarity( + cond[:, -1], pred_cond[:, -1], dim=-1) + scale = (1 / max(sim_score.mean(), 0.05))**2 + elif self.guide_type == 'l2': + if cond.dim() == 2: + sim_score = -torch.nn.functional.pairwise_distance( + cond[:, -32:], pred_cond[:, -32:], p=2) + else: + sim_score = -torch.nn.functional.pairwise_distance( + cond[:, -1], pred_cond[:, -1], p=2) + scale = 1 / max(torch.exp(sim_score.mean().detach()), 0.01) + elif self.guide_type == 'GDC': + + sim_score = self.GDC.compute_gradient(pred_x0, label) + return sim_score + score = torch.autograd.grad(sim_score.mean(), pred_x0)[0] + return score, scale + + @torch.no_grad() + def p_sample_ddim(self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + label=None, + **kwargs): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c, label=label, **kwargs) + e_t = return_wrap( + e_t, torch.full((b, 1, 1), + self.ddim_coef[index], + device=device)) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, + t_in, + c_in, + label=label).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - + e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, + **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1), + sqrt_one_minus_alphas[index], + device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, + repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/TarDiff/ldm/models/diffusion/plms.py b/TarDiff/ldm/models/diffusion/plms.py new file mode 100755 index 0000000..62ee5a7 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/plms.py @@ -0,0 +1,316 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, + ddim_num_steps, + ddim_discretize="uniform", + ddim_eta=0., + verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[ + 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model + .device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', + np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int( + min(timesteps / self.ddim_timesteps.shape[0], 1) * + self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range( + 0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ + 0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b, ), step, device=device, dtype=torch.long) + ts_next = torch.full((b, ), + time_range[min(i + 1, + len(time_range) - 1)], + device=device, + dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + old_eps=None, + t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, + c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - + e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, + **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), + alphas_prev[index], + device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), + sqrt_one_minus_alphas[index], + device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, + repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - + 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/TarDiff/ldm/models/diffusion/uni_csg.py b/TarDiff/ldm/models/diffusion/uni_csg.py new file mode 100755 index 0000000..74384b6 --- /dev/null +++ b/TarDiff/ldm/models/diffusion/uni_csg.py @@ -0,0 +1,2257 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim_time import DDIMSampler +from ldm.modules.diffusionmodules.util import return_wrap +import copy +import os +import pandas as pd + +__conditioning_keys__ = { + 'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y' +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in [ + "eps", "x0" + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, + ignore_keys=ignore_keys, + only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, + size=(self.num_timesteps, )) + self.ce_loss = nn.CrossEntropyLoss(reduction="none") + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[ + 0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', + to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + 'posterior_log_variance_clipped', + to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + 'posterior_mean_coef1', + to_torch(betas * np.sqrt(alphas_cumprod_prev) / + (1. - alphas_cumprod))) + self.register_buffer( + 'posterior_mean_coef2', + to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / + (1. - alphas_cumprod))) + self.register_buffer( + "shift_coef", + -to_torch(np.sqrt(alphas)) * (1. - self.alphas_cumprod_prev) / + torch.sqrt(1. - self.alphas_cumprod)) + self.register_buffer("ddim_coef", -self.sqrt_one_minus_alphas_cumprod) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * + to_torch(alphas) * + (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / ( + 2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + self.load_epoch = sd['epoch'] + self.load_step = sd["global_step"] + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict( + sd, + strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * + x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * noise) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_variance = extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + eps_pred = return_wrap(model_out, + extract_into_tensor(self.ddim_coef, t, x.shape)) + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * + model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), + desc='Sampling t', + total=self.num_timesteps): + img = self.p_sample(img, + torch.full((b, ), + i, + device=device, + dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, + pred, + reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + eps_pred = return_wrap( + model_out, extract_into_tensor(self.shift_coef, t, x_start.shape)) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError( + f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(eps_pred, target, mean=False).mean(dim=[1, 2]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, + self.num_timesteps, (x.shape[0], ), + device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 2: + x = x[..., None] + x = rearrange(x, 'b t c -> b c t') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + + self.log("global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', + lr, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # pass + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = { + key + '_ema': loss_dict_ema[key] + for key in loss_dict_ema + } + self.log_dict(loss_dict_no_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + self.log_dict(loss_dict_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c t -> b n c t') + denoise_grid = rearrange(denoise_grid, 'b n c t -> (b n) c t') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=2, + sample=True, + return_keys=None, + **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, + return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + dis_loss_flag=False, + detach_flag=False, + train_enc_flag=False, + dis_weight=1.0, + dis_loss_type="IM", + cond_drop_prob=None, + pair_loss_flag=False, + *args, + **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.dis_loss_flag = dis_loss_flag + self.pair_loss_flag = pair_loss_flag + self.detach_flag = detach_flag + self.train_enc_flag = train_enc_flag + self.dis_weight = dis_weight + self.dis_loss_type = dis_loss_type + self.cond_drop_prob = cond_drop_prob + try: + self.num_downs = len( + first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps, ), + fill_value=self.num_timesteps - 1, + dtype=torch.long) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, + self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + if hasattr(self.model.diffusion_model, "scale_factor"): + del self.scale_factor + self.register_buffer('scale_factor', + self.model.diffusion_model.scale_factor) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING Pre-Trained STD-RESCALING ###") + else: + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, + linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print( + f"Training {self.__class__.__name__} as an unconditional model." + ) + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, + samples, + desc='', + force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append( + self.decode_first_stage( + zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c t -> b n c t') + # denoise_grid = rearrange(denoise_grid, 'b n c t -> (b n) c t') + # denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable( + self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], + dim=-1), + dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, + Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip( + L_weighting, self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, + x, + kernel_size, + stride, + uf=1, + df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, + Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, + w) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, + kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, + x.shape[3] * uf), + **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, + kernel_size[1] * uf, Ly, Lx, + x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * + uf) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, + dilation=1, + padding=0, + stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, + kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, + x.shape[3] // df), + **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, + kernel_size[1] // df, Ly, Lx, + x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // + df) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, + z, + predict_cids=False, + force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, + shape=None) + z = rearrange(z, 'b t c -> b c t').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], + z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize) for i in range(z.shape[-1]) + ] + else: + + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, + axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, + o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, + z, + predict_cids=False, + force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, + shape=None) + z = rearrange(z, 'b t c -> b c t').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], + z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize) for i in range(z.shape[-1]) + ] + else: + + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, + axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, + o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], + z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [ + self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, + o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + kwargs['data_key'] = batch['data_key'].to(self.device) + loss = self(x, c, **kwargs) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, + self.num_timesteps, (x.shape[0], ), + device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, + t=tc, + noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, + crop_coordinates): # TODO: move to dataset + + def rescale_bbox(bbox): + x0 = torch.clamp( + (bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = torch.clamp( + (bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, + x_noisy, + t, + cond, + cfg_scale=1, + cond_drop_prob=None, + return_ids=False, + sampled_concept=None, + sampled_index=None, + sub_scale=None): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len( + cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], + z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + cond_list = [cond for i in range(z.shape[-1]) + ] # Todo make this more efficient + + # apply model by loop over crops + output_list = [ + self.model(z_list[i], t, **cond_list[i]) + for i in range(z.shape[-1]) + ] + assert not isinstance( + output_list[0], tuple + ) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view( + (o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + if cond_drop_prob is None: + x_recon = self.model.cfg_forward( + x_noisy, + t, + cfg_scale=cfg_scale, + sampled_concept=sampled_concept, + sampled_index=sampled_index, + sub_scale=sub_scale, + **cond) + else: + x_recon = self.model.forward(x_noisy, + t, + cfg_scale=cfg_scale, + cond_drop_prob=cond_drop_prob, + sampled_concept=sampled_concept, + sampled_index=sampled_index, + sub_scale=sub_scale, + **cond) + + # if isinstance(x_recon, tuple) and not return_ids: + # return x_recon[0] + # else: + # return x_recon + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, + device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, + logvar1=qt_log_variance, + mean2=0.0, + logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + # @torch.no_grad() + # def test_step(self, batch, batch_idx): + # x = super().get_input(batch, self.cond_stage_key) + # cond = self.cond_stage_model(x) + # cond = torch.stack(cond.chunk(self.model.diffusion_model.latent_unit, dim = 1), dim=1) + # return {"cond":cond.detach().cpu()} + + # @torch.no_grad() + # def test_step_end(self, batch_parts): + # return batch_parts["cond"] + + # @torch.no_grad() + # def test_epoch_end(self, test_step_outputs): + # cond_cat = torch.cat(test_step_outputs, dim=0) + # cond_dir = os.path.join(self.logdir, "dis_repre","epoch={:06}.npz".format( + # self.current_epoch)) + # os.mkdir(os.path.join(self.logdir, "dis_repre")) + # np.savez(cond_dir, latents=cond_cat.numpy(), num_samples= np.array(self.global_step)) + + # def dis_loss(self, model_forward, x_t, t, cond, sampled_concept): + # if not self.train_enc_flag: + # eval_encoder = copy.deepcopy(self.cond_stage_model) + # eval_encoder.requires_grad_(False) + # eval_encoder.eval() + # else: + # eval_encoder = self.cond_stage_model + + # ddim_coef = extract_into_tensor(self.ddim_coef, t, x_t.shape) + # with torch.no_grad(): + # eps_hat = model_forward.pred + # z_start = self.predict_start_from_noise(x_t, t, eps_hat) + # pred_x0_t = self.differentiable_decode_first_stage(z_start, force_not_quantize=not self.detach_flag) + # if self.detach_flag: + # pred_x0_t = pred_x0_t.detach() + # else: + # pass + # pred_z = eval_encoder(pred_x0_t) + # z_parts = pred_z.chunk(self.model.diffusion_model.latent_unit, dim=1) + # pred_z = torch.stack(z_parts, dim=1) + + # eps_new_hat = model_forward.null_pred + ddim_coef*model_forward.sub_grad + # z_start_new = self.predict_start_from_noise(x_t, t, eps_new_hat) + # pred_x0_new_t = self.differentiable_decode_first_stage(z_start_new, force_not_quantize=not self.detach_flag) + # if self.detach_flag: + # pred_x0_new_t = pred_x0_new_t.detach() + # else: + # pass + # pred_z_new = eval_encoder(pred_x0_new_t) + # z_parts = pred_z_new.chunk(self.model.diffusion_model.latent_unit, dim=1) + # cond = cond.chunk(self.model.diffusion_model.latent_unit, dim=1) + # pred_z_new = torch.stack(z_parts, dim=1) + # cond = torch.stack(cond, dim=1) + + # with torch.no_grad(): + # norm_org = torch.norm(pred_z - cond.detach(), dim=-1) + # norm_Z = torch.norm(pred_z_new - cond.detach(), dim=-1) + # logits_deta = torch.norm(pred_z - pred_z_new, dim = -1) + # logits = norm_org - norm_Z + + # dis_loss = self.ce_loss(logits, torch.from_numpy(sampled_concept).cuda()) + # dis_loss_deta = self.ce_loss(logits_deta, torch.from_numpy(sampled_concept).cuda()) + + # if self.dis_loss_type == "IM": + # dis_weight = mean_flat((pred_x0_t - pred_x0_new_t.detach())**2) + # elif self.dis_loss_type == "Z": + # dis_weight = mean_flat((z_start - z_start_new.detach())**2) + # else: + # raise NotImplementedError + + # return dis_weight * self.dis_weight * (dis_loss + dis_loss_deta) + + def dis_loss(self, model_forward, x_t, t, cond, sampled_concept): + if not self.train_enc_flag: + eval_encoder = copy.deepcopy(self.cond_stage_model) + eval_encoder.requires_grad_(False) + eval_encoder.eval() + else: + eval_encoder = self.cond_stage_model + + ddim_coef = extract_into_tensor(self.ddim_coef, t, x_t.shape) + with torch.no_grad(): + eps_hat = model_forward.pred + z_start = self.predict_start_from_noise(x_t, t, eps_hat) + pred_x0_t = self.differentiable_decode_first_stage( + z_start, force_not_quantize=not self.detach_flag).detach() + + eps_new_hat = model_forward.null_pred + ddim_coef * model_forward.sub_grad + z_start_new = self.predict_start_from_noise(x_t, t, eps_new_hat) + pred_x0_new_t = self.differentiable_decode_first_stage( + z_start_new, force_not_quantize=not self.detach_flag).detach() + + pred_z = eval_encoder(pred_x0_t) + z_parts = pred_z.chunk(self.model.diffusion_model.latent_unit, + dim=1) + pred_z = torch.stack(z_parts, dim=1) + + pred_z_new = eval_encoder(pred_x0_new_t) + z_parts = pred_z_new.chunk(self.model.diffusion_model.latent_unit, + dim=1) + cond = cond.chunk(self.model.diffusion_model.latent_unit, dim=1) + pred_z_new = torch.stack(z_parts, dim=1) + cond = torch.stack(cond, dim=1) + + # with torch.no_grad(): + # norm_org = torch.norm(pred_z - cond.detach(), dim=-1) + # norm_Z = torch.norm(pred_z_new - cond.detach(), dim=-1) + norm_org = torch.norm(pred_z - cond, dim=-1) + norm_Z = torch.norm(pred_z_new - cond.detach(), dim=-1) + logits_deta = torch.norm(pred_z - pred_z_new, dim=-1) + logits = norm_org.detach() - norm_Z + + dis_loss = self.ce_loss(logits, + torch.from_numpy(sampled_concept).cuda()) + dis_loss_deta = self.ce_loss(logits_deta, + torch.from_numpy(sampled_concept).cuda()) + + if self.dis_loss_type == "IM": + dis_weight = mean_flat( + (pred_x0_t.detach() - pred_x0_new_t.detach())**2) + elif self.dis_loss_type == "Z": + dis_weight = mean_flat( + (z_start.detach() - z_start_new.detach())**2) + else: + raise NotImplementedError + + return dis_weight * self.dis_weight * (dis_loss + + dis_loss_deta), dis_weight + + def sim_mask(self, arr): + n = arr.shape[0] + mask = torch.zeros((n, n), dtype=torch.int32).to(self.device) + idx1, idx2 = torch.tril_indices(n, n, -1).to(self.device) + mask[idx1, idx2] = (arr[idx1] == arr[idx2]).int() + return mask + + def pair_loss(self, cond, data_key): + bs = cond.shape[0] + sim_mask = self.sim_mask(data_key) + if cond.dim() == 2: + r = cond.chunk(self.model.diffusion_model.latent_unit, dim=1)[-1] + else: + r = cond[:, -1] + # r_norm = torch.norm(r.detach(), dim=1) + r_norm = torch.norm(r, dim=1) + mask = torch.tril(torch.ones_like(sim_mask), + -1).to(self.device) # lower triangular matrix + cos_sim = (r @ r.T) * mask * 0.5 / ( + r_norm.view(-1, 1) @ r_norm.view(1, -1) + 1e-8) + cos_sim = cos_sim**2 + # cos_sim = torch.abs(cos_sim) + # cos_sim = (r@r.T.detach()) * mask * 0.5 / (r_norm.view(-1, 1) @ r_norm.view(1, -1) + 1e-8) + # sim_loss = -(((sim_mask-1) * cos_sim - torch.log(1 + torch.exp(cos_sim))) / (bs**2) * mask).sum() + sim_loss = -((sim_mask * cos_sim - torch.log(1 + torch.exp(cos_sim))) / + (bs**2) * mask).sum() + return sim_loss + + def p_losses(self, x_start, cond, t, noise=None, data_key=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + if self.dis_loss_flag and self.global_step > 1000: + sampled_concept = np.random.randint( + self.model.diffusion_model.latent_unit, size=x_noisy.shape[0]) + model_output = self.apply_model(x_noisy, + t, + cond, + sampled_concept=sampled_concept, + cond_drop_prob=self.cond_drop_prob) + dis_loss, dis_weight = self.dis_loss(model_output, x_noisy, t, + cond, sampled_concept) + elif self.pair_loss_flag: + model_output = self.apply_model(x_noisy, + t, + cond, + cond_drop_prob=self.cond_drop_prob) + pair_loss = self.pair_loss(cond, data_key) + else: + model_output = self.apply_model(x_noisy, + t, + cond, + cond_drop_prob=self.cond_drop_prob) + + eps_pred = return_wrap( + model_output, extract_into_tensor(self.shift_coef, t, + x_start.shape)) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(eps_pred, target, mean=False).mean([1, 2]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.cpu()].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + if self.dis_loss_flag and self.global_step > 1000: + loss = self.l_simple_weight * loss.mean() + dis_loss.mean() + loss_dict.update({f'{prefix}/dis_loss': dis_loss.mean()}) + loss_dict.update({f'{prefix}/dis_weight': dis_weight.mean()}) + elif self.pair_loss_flag: + loss = self.l_simple_weight * loss.mean() + pair_loss + loss_dict.update({f'{prefix}/pair_loss': pair_loss}) + else: + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(eps_pred, target, mean=False).mean(dim=(1, 2)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f'{prefix}/epoch_num': self.current_epoch}) + loss_dict.update({f'{prefix}/step_num': self.global_step}) + + return loss, loss_dict + + def p_mean_variance(self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + **kwargs): + t_in = t + model_out = self.apply_model(x, + t_in, + c, + return_ids=return_codebook_ids, + **kwargs) + + eps_pred = return_wrap(model_out, + extract_into_tensor(self.ddim_coef, t, x.shape)) + + if score_corrector is not None: + assert self.parameterization == "eps" + eps_pred = score_corrector.modify_score(self, eps_pred, x, t, c, + **corrector_kwargs) + + if return_codebook_ids: + eps_pred, logits = eps_pred + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=eps_pred) + elif self.parameterization == "x0": + x_recon = eps_pred + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, + indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + **kwargs): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + **kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + + # if return_codebook_ids: + # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + **kwargs): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + inter_recons = [] + inter_imgs = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: + cond[key][:batch_size] if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance( + cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), + desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b, ), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, + t=tc, + noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + **kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + inter_recons.append(x0_partial) + inter_imgs.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, inter_imgs, inter_recons + + @torch.no_grad() + def p_sample_loop(self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + **kwargs): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm( + reversed(range(0, timesteps)), desc='Sampling t', + total=timesteps) if verbose else reversed(range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2: + 3] # spatial size has to match + + for i in iterator: + ts = torch.full((b, ), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, + t=tc, + noise=torch.randn_like(cond)) + + img = self.p_sample(img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + **kwargs) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: + cond[key][:batch_size] if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance( + cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size) + samples, intermediates = ddim_sampler.sample(S=ddim_steps, + batch_size=batch_size, + shape=shape, + conditioning=cond, + verbose=False, + **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, + batch_size=batch_size, + return_intermediates=True, + **kwargs) + + return samples, intermediates + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=8, + sample=True, + plot_reconstruction=False, + ddim_steps=200, + ddim_eta=1., + return_keys=None, + quantize_denoised=False, + inpaint=False, + plot_denoise_rows=False, + plot_progressive_rows=False, + plot_diffusion_rows=False, + plot_swapped_concepts=False, + plot_decoded_xstart=False, + plot_swapped_concepts_partial=False, + fix_noise=False, + **kwargs): + + use_ddim = ddim_steps is not None + # plot_swapped_concepts = True + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + if fix_noise: + fixed_noise = torch.randn(x.shape, device=self.device) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x # batchsize, channel, window + if plot_reconstruction: + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), + batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack( + diffusion_row) # n_log_step, batchsize, channel, window + # diffusion_grid = rearrange(diffusion_row, 'n b c t -> b n c t') # when drawing, mix all the channels + # diffusion_grid = rearrange(diffusion_grid, 'b n c t -> (b n) c t') + # diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + # log["diffusion_row"] = diffusion_grid # batchsize, n_log_step, channel, window + log["diffusion_row"] = diffusion_row # n_log_step, batchsize, channel, window + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + with self.ema_scope("Uncond Plotting"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + cfg_scale=0) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["uncond_samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance( + self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_x0=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_swapped_concepts: + x_samples_list = [x[None, :], + log["samples"][None, :]] # [(1, N, C, W) * 2] + nc = self.model.diffusion_model.latent_unit # number of concepts + with self.ema_scope("Plotting Swapping"): + for cdx in range(nc): + swapped_c = c.clone() + swapped_c = torch.stack(swapped_c.chunk(nc, dim=1), dim=1) + swapped_c[:, cdx] = swapped_c[0, cdx][None, :].repeat( + c.shape[0], 1) + samples, z_denoise_row = self.sample_log( + cond=swapped_c.reshape(c.shape[0], -1), + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + x_samples_list.append(x_samples[None, :]) + log["samples_swapping"] = torch.cat(x_samples_list, + dim=0) # [nc+1, N, C, W] + with self.ema_scope("Plotting Concept Interception"): + intercept_schedule = torch.arange(0, 1.1, + 0.2).to(self.device) # (6,) + ni = len(intercept_schedule) + 1 # number of interceptions + x_samples_list = [] + input_concept_list = [] + src_c = torch.repeat_interleave( + torch.stack(c[None, 0].clone().chunk(nc, dim=1), dim=1), + repeats=ni, + dim=0) # source concept, shaped (ni, nc, dim) + des_c = torch.repeat_interleave( + torch.stack(c[None, 1].clone().chunk(nc, dim=1), dim=1), + repeats=ni, + dim=0) # destination concept, shaped (ni, nc, dim) + default_w = torch.zeros( + (ni, nc, 1)) # an all zero weight metrics, (ni, nc, 1) + default_w[-1, :] = 1 # set the first concept to 1.0 + for cdx in range(nc): + inter_w = default_w.clone().to(self.device) + inter_w[:-1, cdx] = intercept_schedule[:, None] + inter_c = src_c * (1 - inter_w) + des_c * inter_w + input_concept_list.append( + rearrange(inter_c, 'ni nc dim -> ni (nc dim)')) + + intered_c = torch.cat(input_concept_list, dim=0) + samples, z_denoise_row = self.sample_log( + cond=intered_c, + batch_size=intered_c.shape[0], + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + x_samples = rearrange(x_samples, + '(nc ni) c w -> nc ni c w', + ni=ni, + nc=nc) + x_samples = rearrange(x_samples, 'nc ni c w -> ni nc c w') + # x_samples_list.append(x_samples[None, :]) + log["samples_swapping_intercept"] = x_samples + if plot_swapped_concepts_partial: + x_samples_list = [] + with self.ema_scope("Plotting Swapping"): + for cdx in range(self.model.diffusion_model.latent_unit): + swapped_c = c.clone() + swapped_c = torch.stack(swapped_c.chunk( + self.model.diffusion_model.latent_unit, dim=1), + dim=1) + swapped_c[:, cdx] = swapped_c[0, cdx][None, :].repeat( + c.shape[0], 1) + samples, z_denoise_row = self.sample_log( + cond=swapped_c.reshape(c.shape[0], -1), + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + sampled_index=np.array([cdx] * N)) + x_samples = self.decode_first_stage(samples) + x_samples_list.append(x_samples) + log["samples_swapping_partial"] = torch.cat(x_samples_list, + dim=0) + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + imgs, inter_imgs, inter_recons = self.progressive_denoising( + c, shape=(self.channels, self.image_size), batch_size=N) + prog_recon_row = self._get_denoise_row_from_list( + inter_recons, desc="Progressive Recon Generation") + prog_img_row = self._get_denoise_row_from_list( + inter_imgs, desc="Progressive Img Generation") + no_ddim_samples = self.decode_first_stage(imgs) + log["no_ddim_samples"] = no_ddim_samples # [N, C, W] + # log["progressive_row_recon"] = prog_recon_row # [N, n_log_step, C, W] + # log["progressive_row_inter"] = prog_img_row # [N, n_log_step, C, W] + log["progressive_row_recon"] = rearrange( + prog_recon_row, 'b n c t -> n b c t') # [n_log_step, N, C, W] + log["progressive_row_inter"] = rearrange( + prog_img_row, 'b n c t -> n b c t') # [n_log_step, N, C, W] + + if plot_decoded_xstart: + # get diffusion row + + with self.ema_scope("Plotting PredXstart"): + z_start = z[:n_row] + diffusion_start = list() + diffusion_full = list() + for cdx in range(self.model.diffusion_model.latent_unit): + diffusion_row = list() + for t in range(self.num_timesteps): + if (t % (self.log_every_t // 2) == 0 + or t == self.num_timesteps - + 1) and t >= self.num_timesteps // 2: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, + t=t, + noise=noise) + + model_out = self.apply_model( + z_noisy, + t, + c, + return_ids=False, + sampled_concept=np.array([cdx] * n_row)) + eps_pred = model_out.pred + extract_into_tensor( + self.ddim_coef, t, + x.shape) * model_out.sub_grad + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + + diffusion_row.append( + self.decode_first_stage(x_recon)) + + if cdx == 0: + eps_pred = model_out.pred + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + diffusion_start.append( + self.decode_first_stage(x_recon)) + + eps_pred = model_out.pred + extract_into_tensor( + self.ddim_coef, t, + x.shape) * model_out.out_grad + x_recon = self.predict_start_from_noise( + z_noisy, t=t, noise=eps_pred) + diffusion_full.append( + self.decode_first_stage(x_recon)) + + diffusion_row = torch.stack( + diffusion_row) # n_log_step, n_row, C, H, W + # diffusion_grid = rearrange(diffusion_row, 'n b c t -> b n c t') + # diffusion_grid = rearrange(diffusion_grid, 'b n c t -> (b n) c t') + # diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + # log[f"predXstart_{cdx}"] = diffusion_grid + log[f"predXstart_{cdx}"] = diffusion_row + if cdx == 0: + diffusion_start = torch.stack( + diffusion_start) # n_log_step, n_row, C, H, W + # diffusion_grid = rearrange(diffusion_start, 'n b c t -> b n c t') + # diffusion_grid = rearrange(diffusion_grid, 'b n c t -> (b n) c t') + # diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_start.shape[0]) + # log[f"predXstart_st"] = diffusion_grid + log[f"predXstart_st"] = diffusion_start + + diffusion_full = torch.stack( + diffusion_full) # n_log_step, n_row, C, H, W + # diffusion_grid = rearrange(diffusion_full, 'n b c t -> b n c t') + # diffusion_grid = rearrange(diffusion_grid, 'b n c t -> (b n) c t') + # diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_full.shape[0]) + # log[f"predXstart_fl"] = diffusion_grid + log[f"predXstart_fl"] = diffusion_full + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print( + f"{self.__class__.__name__}: Also optimizing conditioner params!" + ) + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [{ + 'scheduler': + LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': + 'step', + 'frequency': + 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, 'concat', 'crossattn', 'hybrid', 'adm' + ] + + def parameters(self): + return self.diffusion_model.parameters() + + def forward(self, + x, + t, + sampled_concept=None, + sampled_index=None, + c_concat: list = None, + c_crossattn: list = None, + sub_scale: list = None, + cfg_scale=1., + cond_drop_prob=0.): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, + t, + context=cc, + cond_drop_prob=cond_drop_prob) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + model_kwargs = {} + if sampled_concept is not None: + model_kwargs["sampled_concept"] = sampled_concept + if sampled_index is not None: + model_kwargs["sampled_index"] = sampled_index + if sub_scale is not None: + model_kwargs["sub_scale"] = sub_scale + model_kwargs["cfg_scale"] = cfg_scale + + out = self.diffusion_model.forward(x, + t, + context=cc, + cond_drop_prob=cond_drop_prob, + **model_kwargs) + else: + raise NotImplementedError() + return out + + def cfg_forward(self, + x, + t, + sampled_concept=None, + sampled_index=None, + c_concat: list = None, + c_crossattn: list = None, + sub_scale: list = None, + cfg_scale=1.): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + model_kwargs = {} + if sampled_concept is not None: + model_kwargs["sampled_concept"] = sampled_concept + if sampled_index is not None: + model_kwargs["sampled_index"] = sampled_index + if sub_scale is not None: + model_kwargs["sub_scale"] = sub_scale + model_kwargs["cfg_scale"] = cfg_scale + + out = self.diffusion_model.forward_with_cfg(x, + t, + context=cc, + **model_kwargs) + else: + raise NotImplementedError() + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label( + dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, + (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/TarDiff/ldm/modules/attention.py b/TarDiff/ldm/modules/attention.py new file mode 100755 index 0000000..b7a6c38 --- /dev/null +++ b/TarDiff/ldm/modules/attention.py @@ -0,0 +1,386 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + + +class LinearAttention(nn.Module): + + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, + 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads=self.heads, + qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, + 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, + h=h, + w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0., + hard_assign=False, + inter_mask=False): + super().__init__() + self.hard_assign = hard_assign + self.inter_mask = inter_mask + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(query_dim, inner_dim, bias=False) + self.to_v = nn.Linear(query_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + if len(context.shape) == 2: + k = self.to_k(context)[:, None] + v = self.to_v(context)[:, None] + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + if self.hard_assign: + sim = sim + (1 - mask) * max_neg_value + elif self.inter_mask: + mask_of_mask = torch.where(mask > 0, torch.zeros_like(mask), + torch.ones_like(mask)) + max_neg_value = -torch.finfo(mask.dtype).max + mask = mask_of_mask * max_neg_value + mask + sim = sim + mask + else: + sim = sim + mask + # sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=False, + hard_assign=False, + inter_mask=False): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + hard_assign=hard_assign, + inter_mask=inter_mask) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + return checkpoint(self._forward, (x, context, mask), self.parameters(), + self.checkpoint) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock(inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim) + for d in range(depth) + ]) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in + + +class Spatial1DTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + hard_assign=False, + inter_mask=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock(inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + hard_assign=hard_assign, + inter_mask=inter_mask) for d in range(depth) + ]) + + self.proj_out = zero_module( + nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None, mask=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c w -> b w c') + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask) + x = rearrange(x, 'b w c -> b c w', w=w) + x = self.proj_out(x) + return x + x_in diff --git a/TarDiff/ldm/modules/diffusionmodules/__init__.py b/TarDiff/ldm/modules/diffusionmodules/__init__.py new file mode 100755 index 0000000..9a04545 --- /dev/null +++ b/TarDiff/ldm/modules/diffusionmodules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/TarDiff/ldm/modules/diffusionmodules/model.py b/TarDiff/ldm/modules/diffusionmodules/model.py new file mode 100755 index 0000000..7299a5f --- /dev/null +++ b/TarDiff/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,991 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True) + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, + scale_factor=2.0, + mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", + "none"], f'attn_type {attn_type} unknown' + print( + f"making attention of type '{attn_type}' with {in_channels} in_channels" + ) + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], + dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2 * + z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1, ) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True) + ]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2**(self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + + def __init__(self, + factor, + in_channels, + mid_channels, + out_channels, + depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=(int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + + def __init__(self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + + def __init__(self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + + def __init__(self, + in_size, + out_size, + in_channels, + out_channels, + ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler(factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate(x, + mode=self.mode, + align_corners=False, + scale_factor=scale_factor) + return x + + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d(in_channels, + n_channels, + kernel_size=3, + stride=1, + padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock(in_channels=ch_in, + out_channels=m * n_channels, + dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, 'b c h w -> b (h w) c') + return z diff --git a/TarDiff/ldm/modules/diffusionmodules/unet1d.py b/TarDiff/ldm/modules/diffusionmodules/unet1d.py new file mode 100755 index 0000000..7cd05cb --- /dev/null +++ b/TarDiff/ldm/modules/diffusionmodules/unet1d.py @@ -0,0 +1,985 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import Spatial1DTransformer +from ldm.util import default +from .util import Return, Return_grad_cfg + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return th.ones(shape, device=device, dtype=th.bool) + elif prob == 0: + return th.zeros(shape, device=device, dtype=th.bool) + else: + return th.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, mask=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, Spatial1DTransformer): + x = layer(x, context, mask=mask) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, + self.channels, + self.out_channels, + 3, + padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, + self.out_channels, + kernel_size=ks, + stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + cond_emb_channels=None, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + if cond_emb_channels is not None: + self.cond_emb_layers = nn.Sequential( + nn.SiLU(), + linear( + cond_emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + else: + self.cond_emb_layers = None + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, + self.out_channels, + self.out_channels, + 3, + padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, + channels, + self.out_channels, + 3, + padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, + 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), + self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x, ), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, + dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, + k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + repre_emb_channels=32, + latent_unit=6, + use_cfg=True, + cond_drop_prob=0.5, + hard_assign=False, + inter_mask=False): + super().__init__() + # if use_spatial_transformer: + # assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.use_cfg = use_cfg + self.cond_drop_prob = cond_drop_prob + self.latent_unit = latent_unit + self.latent_dim = repre_emb_channels + self.hard_assign = hard_assign + self.inter_mask = inter_mask + assert not (self.hard_assign and self.inter_mask + ), 'hard_assign and inter_mask cannot be used together' + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + if self.use_cfg: + self.cond_emb_channels = repre_emb_channels * latent_unit if self.use_cfg else None + # self.null_classes_emb = nn.Parameter(th.randn(self.cond_emb_channels)) + # self.null_classes_emb = nn.Parameter(th.randn(latent_unit, repre_emb_channels)) + self.null_classes_emb = nn.Parameter( + th.randn(1, repre_emb_channels)) + else: + self.cond_emb_channels = None + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ]) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else + Spatial1DTransformer(ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + hard_assign=hard_assign, + inter_mask=inter_mask)) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else Spatial1DTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + hard_assign=hard_assign, + inter_mask=inter_mask), + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else + Spatial1DTransformer(ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + hard_assign=hard_assign, + inter_mask=inter_mask)) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + cond_emb_channels=self.cond_emb_channels, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module( + conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def _forward(self, + x, + timesteps=None, + context=None, + mask=None, + y=None, + cond_drop_prob=0, + cond_part_drop=False, + **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn/adaln + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + # context = None + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + bs, device = x.shape[0], x.device + cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) + if context is not None: + c_num = context.shape[1] + + if cond_drop_prob > 0: + if not cond_part_drop: + keep_mask = prob_mask_like((bs, 1, 1), + 1 - cond_drop_prob, + device=device) + else: + keep_mask = prob_mask_like((bs, c_num, 1), + 1 - cond_drop_prob, + device=device) + null_classes_emb = repeat(self.null_classes_emb, + '1 d -> b n d', + b=bs, + n=c_num) + # null_classes_emb = repeat(self.null_classes_emb, 'n d -> b n d', b = bs) + + # context_emb = th.where(rearrange(keep_mask, 'b -> b 1'), context, null_classes_emb) + context_emb = context * keep_mask + ( + ~keep_mask) * null_classes_emb + + elif "sampled_concept" in kwargs.keys() and cond_part_drop: + sampled_concept = th.from_numpy( + kwargs["sampled_concept"]).unsqueeze(1).to(x.device) + null_chunk = repeat(self.null_classes_emb, 'd -> b d', + b=bs).chunk(self.latent_unit, dim=1) + context_chunk = context.chunk(self.latent_unit, dim=1) + dropped_chunk = [ + th.where(sampled_concept == i, context_chunk[i], + null_chunk[i]) for i in range(self.latent_unit) + ] + context_emb = th.cat(dropped_chunk, dim=1) + else: + context_emb = context + if context_emb.dim() == 2: + context_emb = th.concat(context_emb[:, None].chunk( + self.latent_unit, dim=-1), + dim=1) + else: + context_emb = None + + hs = [] + t_emb = timestep_embedding(timesteps, + self.model_channels, + repeat_only=False) + emb = self.time_embed(t_emb) + if self.num_classes is not None: + assert y.shape == (x.shape[0], ) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + k = 0 + for module in self.input_blocks: + h = module(h, emb, context_emb, mask=mask) + hs.append(h) + if k == 5: + a = 1 + k += 1 + h = self.middle_block(h, emb, context_emb, mask=mask) + for module in self.output_blocks: + + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context_emb, mask=mask) + h = h.type(x.dtype) + pred = self.out(h) + if self.predict_codebook_ids: + return Return(pred=self.id_predictor(h)) + else: + return Return(pred=pred) + + def forward(self, + x, + timesteps=None, + context=None, + mask=None, + y=None, + cond_drop_prob=0, + cond_part_drop=False, + **kwargs): + if "sampled_concept" in kwargs.keys(): + model_out = self._forward(x=x, + timesteps=timesteps, + context=context, + mask=mask, + y=y, + cond_drop_prob=0., + cond_part_drop=False, + **kwargs) + part_context_out = self._forward(x=x, + timesteps=timesteps, + context=context, + mask=mask, + y=y, + cond_drop_prob=0., + cond_part_drop=True, + **kwargs) + null_context_out = self._forward(x=x, + timesteps=timesteps, + context=context, + mask=mask, + y=y, + cond_drop_prob=1., + **kwargs) + cfg_grad = model_out.pred - null_context_out.pred + sub_grad = part_context_out.pred - null_context_out.pred + scaled_out = null_context_out.pred + cfg_grad + out = Return_grad_cfg(pred=scaled_out, + out_grad=cfg_grad, + sub_grad=sub_grad, + null_pred=null_context_out.pred) + else: + out = self._forward(x, + timesteps, + context, + mask, + y, + cond_drop_prob, + cond_part_drop=cond_part_drop, + **kwargs) + return out + + def forward_with_cfg(self, + x, + timesteps=None, + context=None, + y=None, + cfg_scale=None, + **kwargs): + model_out = self._forward(x=x, + timesteps=timesteps, + context=context, + y=y, + cond_drop_prob=0., + **kwargs) + if cfg_scale == 1 and "sampled_concept" not in kwargs.keys(): + return model_out + + null_context_out = self._forward(x=x, + timesteps=timesteps, + context=context, + y=y, + cond_drop_prob=1., + **kwargs) + cfg_grad = model_out.pred - null_context_out.pred + scaled_out = null_context_out.pred + cfg_scale * cfg_grad + + if "sampled_concept" in kwargs.keys(): + part_context_out = self._forward(x=x, + timesteps=timesteps, + context=context, + y=y, + cond_drop_prob=0., + **kwargs) + sub_grad = part_context_out.pred - null_context_out.pred + return Return_grad_cfg(pred=scaled_out, + out_grad=cfg_grad, + sub_grad=sub_grad, + null_pred=null_context_out.pred) + else: + return Return(pred=scaled_out) diff --git a/TarDiff/ldm/modules/diffusionmodules/util.py b/TarDiff/ldm/modules/diffusionmodules/util.py new file mode 100755 index 0000000..0d5d992 --- /dev/null +++ b/TarDiff/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,328 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config +from typing import NamedTuple + + +class Return(NamedTuple): + pred: torch.Tensor + + +class Return_grad(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + + +class Return_grad_full(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + sub_grad: torch.Tensor + + +class Return_grad_cfg(NamedTuple): + pred: torch.Tensor + out_grad: torch.Tensor + sub_grad: torch.Tensor + null_pred: torch.Tensor + + +def return_wrap(inp, coef): + if isinstance(inp, Return) or isinstance(inp, Return_grad_cfg): + return inp.pred + elif isinstance(inp, Return_grad) or isinstance(inp, Return_grad_full): + # return inp.out_grad + return inp.pred + coef * inp.out_grad + + +def make_beta_schedule(schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if schedule == "linear": + betas = (torch.linspace(linear_start**0.5, + linear_end**0.5, + n_timestep, + dtype=torch.float64)**2) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + + cosine_s) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, + linear_end, + n_timestep, + dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, + linear_end, + n_timestep, + dtype=torch.float64)**0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, + num_ddim_timesteps, + num_ddpm_timesteps, + verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), + num_ddim_timesteps))**2).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, + ddim_timesteps, + eta, + verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + + alphacums[ddim_timesteps[:-1]].tolist()) + ddim_coef = -np.sqrt(1. - alphas) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print( + f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}' + ) + print( + f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}' + ) + return sigmas, alphas, alphas_prev, ddim_coef + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1, ) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [ + x.detach().requires_grad_(True) for x in ctx.input_tensors + ] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config( + c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1, ) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/TarDiff/ldm/modules/distributions/__init__.py b/TarDiff/ldm/modules/distributions/__init__.py new file mode 100755 index 0000000..0eca642 --- /dev/null +++ b/TarDiff/ldm/modules/distributions/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/TarDiff/ldm/modules/distributions/distributions.py b/TarDiff/ldm/modules/distributions/distributions.py new file mode 100755 index 0000000..ebf797b --- /dev/null +++ b/TarDiff/ldm/modules/distributions/distributions.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import numpy as np + + +class AbstractDistribution: + + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def kl_splits(self, latent_unit=6): + mean_splits = self.mean.chunk(latent_unit, dim=-1) + var_splits = self.var.chunk(latent_unit, dim=-1) + logvar_splits = self.logvar.chunk(latent_unit, dim=-1) + kl_loss = 0 + for mean, var, logvar in zip(mean_splits, var_splits, logvar_splits): + kl_split = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, + dim=-1) + kl_loss += torch.sum(kl_split) / kl_split.shape[0] + return kl_loss / latent_unit + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2)**2) * torch.exp(-logvar2)) diff --git a/TarDiff/ldm/modules/ema.py b/TarDiff/ldm/modules/ema.py new file mode 100755 index 0000000..2683487 --- /dev/null +++ b/TarDiff/ldm/modules/ema.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +from torch import nn + + +class LitEma(nn.Module): + + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + 'num_updates', + torch.tensor(0, dtype=torch.int) + if use_num_upates else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, + (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as( + m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * + (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_( + shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/TarDiff/ldm/modules/encoders/__init__.py b/TarDiff/ldm/modules/encoders/__init__.py new file mode 100755 index 0000000..0eca642 --- /dev/null +++ b/TarDiff/ldm/modules/encoders/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/TarDiff/ldm/modules/encoders/modules.py b/TarDiff/ldm/modules/encoders/modules.py new file mode 100755 index 0000000..6f923ee --- /dev/null +++ b/TarDiff/ldm/modules/encoders/modules.py @@ -0,0 +1,2160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import torch.nn as nn +from functools import partial +# import clip +from einops import rearrange, repeat +import kornia +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +import copy +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class ClassEmbedder(nn.Module): + + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + + def __init__(self, + n_embed, + n_layer, + vocab_size, + max_seq_len=77, + device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder( + dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + + def __init__(self, + n_embed, + n_layer, + vocab_size=30522, + max_seq_len=77, + device="cuda", + use_tokenizer=True, + embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, + max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder( + dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text) #.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in [ + 'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area' + ] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, + mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print( + f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.' + ) + self.channel_mapper = nn.Conv2d(in_channels, + out_channels, + 1, + bias=bias) + + def forward(self, x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +# class FrozenCLIPTextEmbedder(nn.Module): +# """ +# Uses the CLIP transformer encoder for text. +# """ +# def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): +# super().__init__() +# self.model, _ = clip.load(version, jit=False, device="cpu") +# self.device = device +# self.max_length = max_length +# self.n_repeat = n_repeat +# self.normalize = normalize + +# def freeze(self): +# self.model = self.model.eval() +# for param in self.parameters(): +# param.requires_grad = False + +# def forward(self, text): +# tokens = clip.tokenize(text).to(self.device) +# z = self.model.encode_text(tokens) +# if self.normalize: +# z = z / torch.linalg.norm(z, dim=1, keepdim=True) +# return z + +# def encode(self, text): +# z = self(text) +# if z.ndim==2: +# z = z[:, None, :] +# z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) +# return z + +# class FrozenClipImageEmbedder(nn.Module): +# """ +# Uses the CLIP image encoder. +# """ +# def __init__( +# self, +# model, +# jit=False, +# device='cuda' if torch.cuda.is_available() else 'cpu', +# antialias=False, +# ): +# super().__init__() +# self.model, _ = clip.load(name=model, device=device, jit=jit) + +# self.antialias = antialias + +# self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) +# self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + +# def preprocess(self, x): +# # normalize to [0,1] +# x = kornia.geometry.resize(x, (224, 224), +# interpolation='bicubic',align_corners=True, +# antialias=self.antialias) +# x = (x + 1.) / 2. +# # renormalize according to clip +# x = kornia.enhance.normalize(x, self.mean, self.std) +# return x + +# def forward(self, x): +# # x is assumed to be in range [-1,1] +# return self.model.encode_image(self.preprocess(x)) + + +class ResBlock(nn.Module): + + def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): + super(ResBlock, self).__init__() + + if mid_channels is None: + mid_channels = out_channels + + layers = [ + nn.ReLU(), + nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1), + nn.ReLU(), + nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + ] + if bn: + layers.insert(2, nn.BatchNorm2d(out_channels)) + self.convs = nn.Sequential(*layers) + + def forward(self, x): + return x + self.convs(x) + + +class View(nn.Module): + + def __init__(self, size): + super(View, self).__init__() + self.size = size + + def forward(self, tensor): + return tensor.view(self.size) + + +class Encoder4(nn.Module): + + def __init__(self, d, bn=True, num_channels=3, latent_dim=192): + super(Encoder4, self).__init__() + self.latent_dim = latent_dim + self.encoder = nn.Sequential( + nn.Conv2d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + View((-1, 128 * 4 * 4)), # batch_size x 2048 + nn.Linear(2048, self.latent_dim)) + + def forward(self, x): + return self.encoder(x) + + +class ResBlockTime(nn.Module): + + def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): + super(ResBlockTime, self).__init__() + + if mid_channels is None: + mid_channels = out_channels + + layers = [ + nn.ReLU(), + nn.Conv1d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1), + nn.ReLU(), + nn.Conv1d(mid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + ] + if bn: + layers.insert(2, nn.BatchNorm1d(out_channels)) + self.convs = nn.Sequential(*layers) + + def forward(self, x): + return x + self.convs(x) + + +class Encoder4Time(nn.Module): + + def __init__(self, d, w, bn=True, num_channels=3, latent_dim=192): + super(Encoder4Time, self).__init__() + self.latent_dim = latent_dim + flatten_dim = int(d * w / 16) + self.encoder = nn.Sequential( + nn.Conv1d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=bn), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.latent_dim)) + + def forward(self, x): + return self.encoder(x) + + +class Encoder3Time(nn.Module): + + def __init__(self, d, w, bn=True, num_channels=3, latent_dim=192): + super(Encoder3Time, self).__init__() + self.latent_dim = latent_dim + flatten_dim = int(d * w / 8) + self.encoder = nn.Sequential( + nn.Conv1d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=bn), + nn.BatchNorm1d(d), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.latent_dim)) + + def forward(self, x): + return self.encoder(x) + + +class Encoder3TimeLN(nn.Module): + + def __init__(self, d, w, bn=True, num_channels=3, latent_dim=192): + super(Encoder3TimeLN, self).__init__() + self.latent_dim = latent_dim + flatten_dim = int(d * w / 8) + self.encoder = nn.Sequential( + nn.Conv1d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.LayerNorm([d, int(w / 2)]), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.LayerNorm([d, int(w / 4)]), + nn.ReLU(inplace=True), + nn.Conv1d(d, d, kernel_size=4, stride=2, padding=1), + nn.LayerNorm([d, int(w / 8)]), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=False), + nn.LayerNorm([d, int(w / 8)]), + nn.ReLU(inplace=True), + ResBlockTime(d, d, bn=False), + nn.LayerNorm([d, int(w / 8)]), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.latent_dim)) + + def forward(self, x): + return self.encoder(x) + + +class Encoder4_vae(nn.Module): + + def __init__(self, d, bn=True, num_channels=3, latent_dim=192): + super(Encoder4_vae, self).__init__() + self.latent_dim = latent_dim + self.encoder = nn.Sequential( + nn.Conv2d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + View((-1, 128 * 4 * 4)), # batch_size x 2048 + nn.Linear(2048, 2 * self.latent_dim)) + + def forward(self, x): + moments = self.encoder(x) + self.posteriors = DiagonalGaussianDistribution(moments) + return self.posteriors.sample() + + def kl_loss(self, latent_unit): + kl_loss_splits = self.posteriors.kl_splits(latent_unit=latent_unit) + return kl_loss_splits + + +class Encoder256(nn.Module): + + def __init__(self, d, bn=True, num_channels=3, latent_dim=192): + super(Encoder256, self).__init__() + self.latent_dim = latent_dim + self.encoder = nn.Sequential( + nn.Conv2d(num_channels, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True), + ResBlock(d, d, bn=bn), + View((-1, 128 * 4 * 4)), # batch_size x 2048 + nn.Linear(2048, self.latent_dim)) + + def forward(self, x): + return self.encoder(x) + + +from math import pi, log +from functools import wraps + +import torch +from torch import nn, einsum +import torch.nn.functional as F +import numpy as np + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cache_fn(f): + cache = None + + @wraps(f) + def cached_fn(*args, _cache=True, **kwargs): + if not _cache: + return f(*args, **kwargs) + nonlocal cache + if cache is not None: + return cache + cache = f(*args, **kwargs) + return cache + + return cached_fn + + +# helper classes + + +def fourier_encode(x, max_freq, num_bands=4): + x = x.unsqueeze(-1) + device, dtype, orig_x = x.device, x.dtype, x + + scales = torch.linspace(1., + max_freq / 2, + num_bands, + device=device, + dtype=dtype) + scales = scales[(*((None, ) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat([x.sin(), x.cos()], dim=-1) + x = torch.cat((x, orig_x), dim=-1) + return x + + +class PreNorm(nn.Module): + + def __init__(self, dim, fn, context_dim=None): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + self.norm_context = nn.LayerNorm(context_dim) if exists( + context_dim) else None + + def forward(self, x, **kwargs): + x = self.norm(x) + + if exists(self.norm_context): + context = kwargs['context'] + normed_context = self.norm_context(context) + kwargs.update(context=normed_context) + + return self.fn(x, **kwargs) + + +class GEGLU(nn.Module): + + def forward(self, x): + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) + + +class FeedForward(nn.Module): + + def __init__(self, dim, mult=4): + super().__init__() + self.net = nn.Sequential(nn.Linear(dim, dim * mult * 2), GEGLU(), + nn.Linear(dim * mult, dim)) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, query_dim) + + def forward(self, x, context=None, mask=None, hard_assign=True): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim=-1) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + if hard_assign: + sim = sim + (1 - mask) * max_neg_value + else: + sim = sim + mask + # sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +# class Attention1D(nn.Module): +# def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64): +# super().__init__() +# inner_dim = dim_head * heads +# context_dim = default(context_dim, query_dim) +# self.scale = dim_head ** -0.5 +# self.heads = heads +# self.norm = nn.BatchNorm1d(query_dim) +# self.to_q = nn.Linear(query_dim, inner_dim, bias = False) +# self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) +# self.to_out = nn.Linear(inner_dim, query_dim) + +# def forward(self, x, context = None, mask = None): +# h = self.heads + +# q = self.to_q(self.norm(x)) +# context = default(context, x) +# k, v = self.to_kv(context).chunk(2, dim = -1) + +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) + +# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + +# if exists(mask): +# mask = rearrange(mask, 'b ... -> b (...)') +# max_neg_value = -torch.finfo(sim.dtype).max +# mask = repeat(mask, 'b j -> (b h) () j', h = h) +# sim.masked_fill_(~mask, max_neg_value) + +# # attention, what we cannot get enough of +# attn = sim.softmax(dim = -1) + +# out = einsum('b i j, b j d -> b i d', attn, v) +# out = rearrange(out, '(b h) n d -> b n (h d)', h = h) +# return self.to_out(out) + + +class MLP_head(nn.Module): + + def __init__(self, z_dim, hidden_dim, num_cls): + super().__init__() + self.net = nn.Sequential(nn.Linear(z_dim, hidden_dim), nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), nn.GELU(), + nn.Linear(hidden_dim, num_cls)) + + def forward(self, x): + return self.net(x) + + +# main class +class PerceiverEncoder(nn.Module): + + def __init__(self, + *, + index_num=32, + depth=4, + dim=32, + z_index_dim=10, + latent_dim=32, + cross_heads=1, + latent_heads=3, + cross_dim_head=32, + latent_dim_head=32, + weight_tie_layers=False, + max_freq=10, + num_freq_bands=6): + super().__init__() + self.num_latents = z_index_dim + self.components = z_index_dim + self.max_freq = max_freq + self.num_freq_bands = num_freq_bands + self.depth = depth + + self.encoder = nn.Sequential( + nn.Conv2d(3, + latent_dim // 2, + kernel_size=4, + stride=2, + padding=1, + bias=False), + nn.BatchNorm2d(latent_dim // 2), + nn.ReLU(inplace=True), + nn.Conv2d(latent_dim // 2, + latent_dim, + kernel_size=4, + stride=2, + padding=1, + bias=False), + nn.BatchNorm2d(latent_dim), + nn.ReLU(inplace=True), + ResBlock(latent_dim, latent_dim, bn=True), + nn.BatchNorm2d(latent_dim), + ResBlock(latent_dim, latent_dim, bn=True), + ) + + self.latents = nn.Parameter(torch.randn(self.num_latents, latent_dim), + True) + self.cs_layers = nn.ModuleList([]) + for i in range(depth): + self.cs_layers.append( + nn.ModuleList([ + PreNorm(latent_dim, + Attention(latent_dim, + dim + 26, + heads=cross_heads, + dim_head=cross_dim_head), + context_dim=dim + 26), + PreNorm(latent_dim, FeedForward(latent_dim)) + ])) + + get_latent_attn = lambda: PreNorm( + dim + 26, + Attention(dim + 26, heads=latent_heads, dim_head=latent_dim_head)) + get_latent_ff = lambda: PreNorm(dim + 26, FeedForward(dim + 26)) + get_latent_attn, get_latent_ff = map(cache_fn, + (get_latent_attn, get_latent_ff)) + + self.layers = nn.ModuleList([]) + cache_args = {'_cache': weight_tie_layers} + + for i in range(depth - 1): + self.layers.append( + nn.ModuleList([ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args) + ])) + self.fc_layer = nn.Linear(dim, index_num) + + def forward(self, data, mask=None): + data = self.encoder(data) + data = data.reshape(*data.shape[:2], -1).permute(0, 2, 1) + b, *axis, device = *data.shape, data.device + + # calculate fourier encoded positions in the range of [-1, 1], for all axis + + axis_pos = list( + map( + lambda size: torch.linspace(-1., 1., steps=size, device=device + ), + (int(np.sqrt(axis[0])), int(np.sqrt(axis[0]))))) + pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) + enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') + enc_pos = repeat(enc_pos, '... -> b ...', b=b) + + data = torch.cat((data, enc_pos.reshape(b, -1, enc_pos.shape[-1])), + dim=-1) + x0 = repeat(self.latents, 'n d -> b n d', b=b) + for i in range(self.depth): + cross_attn, cross_ff = self.cs_layers[i] + + # cross attention only happens once for Perceiver IO + + x = cross_attn(x0, context=data, mask=mask) + x0 + x0 = cross_ff(x) + x + + if i != self.depth - 1: + self_attn, self_ff = self.layers[i] + x_d = self_attn(data) + data + data = self_ff(x_d) + x_d + + return self.fc_layer(x0).reshape(x0.shape[0], -1) + + +def swish(x): + return x * torch.sigmoid(x) + + +class View(nn.Module): + + def __init__(self, size): + super(View, self).__init__() + self.size = size + + def forward(self, tensor): + return tensor.view(self.size) + + +class MLP_layer(nn.Module): + + def __init__(self, z_dim=512, latent_dim=256): + super(MLP_layer, self).__init__() + self.net = nn.Sequential(nn.Linear(z_dim, latent_dim), nn.GELU(), + nn.Linear(latent_dim, latent_dim)) + + def forward(self, x): + return self.net(x) + + +class MLP_layers(nn.Module): + + def __init__(self, z_dim=512, latent_dim=256, num_latents=16): + super(MLP_layers, self).__init__() + self.nets = nn.ModuleList([ + MLP_layer(z_dim=z_dim, latent_dim=latent_dim) + for i in range(num_latents) + ]) + + def forward(self, x): + out = [] + for sub_net in self.nets: + out.append(sub_net(x)[:, None, :]) + return torch.cat(out, dim=1) + + +class PerceiverDecoder(nn.Module): + + def __init__( + self, + *, + depth=6, + index_num=10, + dim=256, + z_index_dim=64, + latent_dim=256, + cross_heads=1, + cross_dim_head=128, + latent_heads=6, + latent_dim_head=128, + fourier_encode_data=False, + weight_tie_layers=False, + max_freq=10, + num_freq_bands=6, + ): + super().__init__() + num_latents = z_index_dim + self.components = z_index_dim + self.max_freq = max_freq + self.num_freq_bands = num_freq_bands + self.fourier_encode_data = fourier_encode_data + self.latents = nn.Parameter(torch.randn(num_latents, latent_dim), True) + + self.depth = depth + if depth != 0: + get_latent_attn = lambda: PreNorm( + dim, + Attention(dim, heads=latent_heads, dim_head=latent_dim_head)) + get_latent_ff = lambda: PreNorm(dim, FeedForward(dim)) + get_latent_attn, get_latent_ff = map( + cache_fn, (get_latent_attn, get_latent_ff)) + + self.slayers = nn.ModuleList([]) + cache_args = {'_cache': weight_tie_layers} + + for i in range(depth - 1): + self.slayers.append( + nn.ModuleList([ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args) + ])) + + self.cs_layers = nn.ModuleList([]) + for i in range(depth): + self.cs_layers.append( + nn.ModuleList([ + PreNorm(latent_dim, + Attention(latent_dim, + dim, + heads=cross_heads, + dim_head=cross_dim_head), + context_dim=dim), + PreNorm(latent_dim, FeedForward(latent_dim)) + ])) + self.fc_layer = nn.Linear(dim, index_num) + + if depth != 0: + get_latent_attn = lambda: PreNorm( + latent_dim, + Attention( + latent_dim, heads=latent_heads, dim_head=latent_dim_head)) + get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim) + ) + get_latent_attn, get_latent_ff = map( + cache_fn, (get_latent_attn, get_latent_ff)) + + self.layers = nn.ModuleList([]) + cache_args = {'_cache': weight_tie_layers} + + for i in range(depth): + self.layers.append( + nn.ModuleList([ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args) + ])) + + def forward(self, data, mask=None): + b, *axis, device = *data.shape, data.device + if self.fourier_encode_data: + # calculate fourier encoded positions in the range of [-1, 1], for all axis + + axis_pos = list( + map( + lambda size: torch.linspace( + -1., 1., steps=size, device=device), + (int(np.sqrt(axis[0])), int(np.sqrt(axis[0]))))) + pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) + enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') + enc_pos = repeat(enc_pos, '... -> b ...', b=b) + + data = torch.cat((data, enc_pos.reshape(b, -1, enc_pos.shape[-1])), + dim=-1) + + x = repeat(self.latents, 'n d -> b n d', b=b) + cp_vals = data + for i in range(self.depth): + + cross_attn, cross_ff = self.cs_layers[i] + x = cross_attn(x, context=cp_vals, mask=mask) + x + x = cross_ff(x) + x + + self_attn, self_ff = self.layers[i] + x = self_attn(x) + x + x = self_ff(x) + x + + if i != self.depth - 1: + self_attn, self_ff = self.slayers[i] + cp_vals = self_attn(cp_vals) + cp_vals + cp_vals = self_ff(cp_vals) + cp_vals + + return self.fc_layer(x) + + +class SplitTSEncoder(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + dim_out = latent_dim + flatten_dim = int(dim * window / 8) + self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), + requires_grad=True) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + ) + + self.invariant_encoder = nn.Sequential( + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out * 2)) + + self.specific_ffn = nn.Sequential( + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + def forward(self, x): + b = x.shape[0] + latents = repeat(self.latents, 'n d -> b n d', b=b) + h = self.share_encoder(x) + invariant_out = self.invariant_encoder(h) + sh = self.specific_ffn(h)[:, None] # b, 1, d + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, context=latents) + sh + sh = ff(sh) + sh + + sh = sh.squeeze(1) # b, 1, d --> b, d + out = torch.cat((invariant_out, sh), dim=1) + return out + + +class DomainTSEqEncoder(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), + requires_grad=True) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.invariant_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + def forward(self, x): + b = x.shape[0] + latents = repeat(self.latents, 'n d -> b n d', b=b) + h = self.share_encoder(x) + invariant_out = self.invariant_encoder(h) + invariant_out = invariant_out[:, None] + sh = self.specific_ffn(h)[:, None] # b, 1, d + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, context=latents) + sh + sh = ff(sh) + sh + out = torch.cat((invariant_out, sh), dim=1) # b, 2, d + return out + + +class SplitTSEqEncoder(nn.Module): + ''' + The input are encoded into two seperated but identical functioning parts. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.encoder1 = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + self.encoder2 = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + def forward(self, x): + h = self.share_encoder(x) + out1 = self.encoder1(h) + out2 = self.encoder2(h) + out = torch.cat((out1[:, None], out2[:, None]), dim=1) # b, 2, d + return out + + +class SingleTSEncoder(nn.Module): + ''' + The input are encoded into one embedding. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.encoder1 = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + def forward(self, x): + h = self.share_encoder(x) + out1 = self.encoder1(h) + return out1[:, None] + + +class OnlyPrototypeEncoder(nn.Module): + ''' + The input are encoded into one embedding. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True, + orth_emb=False, + mask_assign=False, + hard_assign=False): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_assign = mask_assign + self.hard_assign = hard_assign + self.orth_emb = orth_emb + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), + requires_grad=True) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + def forward(self, x): + b = x.shape[0] + if self.orth_emb: + # latents = torch_expm((self.latents - self.latents.transpose(0, 1)).unsqueeze(0)) + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + h = self.share_encoder(x) + sh = self.specific_ffn(h)[:, None] # b, 1, d + + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn( + sh, context=latents, mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh + + # out = sh # b, 1, d + + # if self.mask_assign: + return sh, mask + # else: + # return out, None + + +class ProtoAssignEncoder(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + dropout=0., + emb_dropout=0., + bn=True, + mask_assign=False, + hard_assign=False, + orth_emb=False): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_assign = mask_assign + self.hard_assign = hard_assign + self.orth_emb = orth_emb + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), + requires_grad=True) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.invariant_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + def forward(self, x): + b = x.shape[0] + if self.orth_emb: + # latents = torch_expm((self.latents - self.latents.transpose(0, 1)).unsqueeze(0)) + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + h = self.share_encoder(x) + invariant_out = self.invariant_encoder(h) + invariant_out = invariant_out[:, None] + sh = self.specific_ffn(h)[:, None] # b, 1, d + + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn( + sh, context=latents, mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh + + out = torch.cat((invariant_out, sh), dim=1) # b, 2, d + return out, mask + + +class DomainUnifiedEncoder(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + split_inv=True, + use_prototype=True, + mask_assign=False, + hard_assign=False, + orth_proto=False, + grad_hook=False, + **kwargs): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_assign = mask_assign + self.hard_assign = hard_assign + self.orth_proto = orth_proto + self.split_inv = split_inv + self.use_prototype = use_prototype + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + if self.split_inv: + self.invariant_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + if self.use_prototype: + # self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), requires_grad=True) + self.latents = nn.Parameter(torch.empty( + num_latents, self.latent_dim), + requires_grad=True) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + else: + self.specific_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + else: + if self.use_prototype: + # self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), requires_grad=True) + self.latents = nn.Parameter(torch.empty( + num_latents, self.latent_dim), + requires_grad=True) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + else: + self.out_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.grad_hook = grad_hook + + def forward(self, x): + # if self.grad_hook: + # handle = self.latents.register_hook(self.hook_func) + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + if self.split_inv: + invariant_out = self.invariant_encoder(h) + invariant_out = invariant_out[:, None] + + if self.use_prototype: + sh = self.specific_ffn(h)[:, None] # b, 1, d + + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, + context=latents, + mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh # b, 1, d + + out = torch.cat((invariant_out, sh), dim=1) # b, 2, d + else: + spec_out = self.specific_encoder(h)[:, None] + out = torch.cat((invariant_out, spec_out), dim=1) # b, 2, d + else: + if self.use_prototype: + sh = self.specific_ffn(h)[:, None] + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, + context=latents, + mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh # b, 1, d + + out = sh # b, 1, d + else: + out = self.out_encoder(h)[:, None] # b, 1, d + # if self.grad_hook: + # handle.remove() + return out, mask + + +class DomainUnifiedEncoderHook(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + split_inv=True, + use_prototype=True, + mask_assign=False, + hard_assign=False, + orth_proto=False, + grad_hook=False): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_assign = mask_assign + self.hard_assign = hard_assign + self.orth_proto = orth_proto + self.split_inv = split_inv + self.use_prototype = use_prototype + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + if self.split_inv: + self.invariant_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + + if self.use_prototype: + self.latents = nn.Parameter(torch.empty( + num_latents, self.latent_dim), + requires_grad=True) + nn.init.orthogonal_(self.latents) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + else: + self.specific_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + else: + if self.use_prototype: + self.latents = nn.Parameter(torch.empty( + num_latents, self.latent_dim), + requires_grad=True) + nn.init.orthogonal_(self.latents) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + else: + self.out_encoder = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + nn.BatchNorm1d(dim), + nn.ReLU(inplace=True), + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.grad_hook = grad_hook + if grad_hook: + self.hook_grads = [] + self.out_grads = [] + self.latent_grads = [] + + def cap_grad(grad): + self.hook_grads.append(grad.clone()) + return grad + + self.hook_func = cap_grad + self.latents.register_hook(self.hook_func) + self.init_latents = self.latents.detach() + + def forward(self, x): + # if self.grad_hook: + # def cap_grad(grad): + # self.latent_grads.append(grad.clone()) + # return grad + # if self.grad_hook: + # handle = self.latents.register_hook(self.hook_func) + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + if self.split_inv: + invariant_out = self.invariant_encoder(h) + invariant_out = invariant_out[:, None] + + if self.use_prototype: + sh = self.specific_ffn(h)[:, None] # b, 1, d + # self.hook_latents = torch.nn.functional.linear(torch.eye(self.latents.shape[0]).to(self.latents.device), self.latents.T) # self.latents + 0 # torch.mul(self.latents, 1) # self.latents * 1 + # if self.training: + # self.latent_handle = self.hook_latents.register_hook(cap_grad) + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, + context=latents, + mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh # b, 1, d + + out = torch.cat((invariant_out, sh), dim=1) # b, 2, d + else: + spec_out = self.specific_encoder(h)[:, None] + out = torch.cat((invariant_out, spec_out), dim=1) # b, 2, d + else: + if self.use_prototype: + sh = self.specific_ffn(h)[:, None] + # self.hook_latents = torch.mul(self.latents, 1) # self.latents * 1 + # if self.training: + # self.latent_handle = self.hook_latents.register_hook(cap_grad) + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + if self.mask_assign: + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + else: + mask = None + + for attn, ff in self.specific_encoder_layers: + sh = attn(sh, + context=latents, + mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh # b, 1, d + + out = sh # b, 1, d + else: + out = self.out_encoder(h)[:, None] # b, 1, d + + # if self.grad_hook: + # def cap_grad(grad): + # self.out_grads.append(grad.clone()) + # return grad + # if self.training: + # out.register_hook(cap_grad) + return out, mask + + +class DomainProtoMaskEncoder(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + split_inv=True, + use_prototype=True, + mask_assign=False, + hard_assign=False, + orth_proto=False, + grad_hook=False): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_assign = mask_assign + self.hard_assign = hard_assign + self.orth_proto = orth_proto + self.split_inv = split_inv + self.use_prototype = use_prototype + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + + self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), + requires_grad=True) + nn.init.orthogonal_(self.latents) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, + dim_out), # nn.Linear(flatten_dim, self.num_latents) + ) + self.sigmoid = nn.Sigmoid() + self.mask_attn_layer = PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim) + self.mask_ff_layer = PreNorm(dim_out, + nn.Linear(dim_out, self.num_latents)) + + self.specific_encoder_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.specific_encoder_layers.append( + nn.ModuleList([ + PreNorm(dim_out, + Attention(query_dim=self.latent_dim, + context_dim=self.latent_dim, + heads=num_heads, + dim_head=dim_head), + context_dim=self.latent_dim), + PreNorm(dim_out, nn.Linear(dim_out, dim_out)) + ])) + + self.init_latents = self.latents.detach() + + def forward(self, x): + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + sh = self.specific_ffn(h)[:, None] + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + if self.mask_assign: + mask_h = self.mask_ffn(h)[:, None] + mask_sh = self.mask_attn_layer(mask_h, context=latents) + mask_h + mask_logit = self.mask_ff_layer(mask_sh).squeeze(1) + + # mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + + for attn, ff in self.specific_encoder_layers: + sh = attn( + sh, context=latents, mask=mask, + hard_assign=self.hard_assign) + sh + sh = ff(sh) + sh # b, 1, d + + out = sh # b, 1, d + + return out, mask + + +class DomainUnifiedPrototyper(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + hard_assign=False, + orth_proto=False, + grad_hook=False, + **kwargs): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.hard_assign = hard_assign + self.orth_proto = orth_proto + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + # self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), requires_grad=True) + self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), + requires_grad=False) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + # self.specific_ffn = nn.Sequential( + # ResBlockTime(dim, dim, bn=bn), + # View((-1, flatten_dim)), # batch_size x 2048 + # nn.Linear(flatten_dim, dim_out) + # ) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + + self.grad_hook = grad_hook + + def forward(self, x): + # if self.grad_hook: + # handle = self.latents.register_hook(self.hook_func) + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + + out = latents # mask + return out, mask + + +class DomainEmbProtoMask(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + split_inv=False, + hard_assign=False, + orth_proto=False, + grad_hook=False, + **kwargs): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.hard_assign = hard_assign + self.orth_proto = orth_proto + self.split_inv = split_inv + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + # self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim), requires_grad=True) + self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), + requires_grad=False) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + + self.grad_hook = grad_hook + + def forward(self, x): + # if self.grad_hook: + # handle = self.latents.register_hook(self.hook_func) + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + mask_logit = self.mask_ffn(h) + if self.hard_assign: # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + else: + mask = mask_logit # soft assign + + if self.split_inv: + sh = self.specific_ffn(h)[:, None] # b, 1, d + emb_mask = torch.ones(b, 1).to(x.device).float() + out = torch.cat((sh, latents), dim=1) # latents # mask + out_mask = torch.cat((emb_mask, mask), dim=1) + else: + out = latents # mask + out_mask = mask + return out, out_mask + + +class DomainEmbProtoAssignMask(nn.Module): + ''' + The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. + The length of the two part are equal in this implementation. + ''' + + def __init__(self, + dim, + window, + num_heads=1, + num_layers=1, + num_latents=16, + num_channels=3, + latent_dim=32, + dim_head=64, + bn=True, + split_inv=False, + mask_method='soft', + orth_proto=False, + grad_hook=False, + **kwargs): + super().__init__() + self.num_latents = num_latents + self.latent_dim = latent_dim + self.mask_method = mask_method + self.orth_proto = orth_proto + self.split_inv = split_inv + dim_out = latent_dim + flatten_dim = int(dim * window / 4) + self.share_encoder = nn.Sequential( + nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), + nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True)) + self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), + requires_grad=False) + nn.init.orthogonal_(self.latents) + self.init_latents = copy.deepcopy(self.latents.detach()) + self.specific_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, dim_out)) + self.mask_ffn = nn.Sequential( + ResBlockTime(dim, dim, bn=bn), + View((-1, flatten_dim)), # batch_size x 2048 + nn.Linear(flatten_dim, self.num_latents), + ) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU() + + self.grad_hook = grad_hook + + def forward(self, x): + b = x.shape[0] + h = self.share_encoder(x) + mask = None + + if self.orth_proto: + q, r = torch.linalg.qr(self.latents.T) + latents = repeat(q.T, 'n d -> b n d', b=b) + else: + latents = repeat(self.latents, 'n d -> b n d', b=b) + mask_logit = self.mask_ffn(h) + if self.mask_method == 'hard': # hard assign + mask_prob = self.sigmoid(mask_logit) + mask = mask_prob > 0.5 # torch.bernoulli(mask_logit) + mask = (mask.float() - mask_prob).detach() + mask_prob + elif self.mask_method == 'soft': + mask = mask_logit # soft assign + elif self.mask_method == 'inter': + # mask_prob = self.sigmoid(mask_logit) + # mask = (mask_logi-0.5) * 2 + mask = self.relu(mask_logit) + # mask_of_mask = torch.where(mask > 0, torch.zeros_like(mask), torch.ones_like(mask)) + # max_neg_value = -torch.finfo(mask.dtype).max + # # mask.masked_fill_(mask_of_mask, max_neg_value) + # mask = mask_of_mask * max_neg_value + mask + + if self.split_inv: + sh = self.specific_ffn(h)[:, None] # b, 1, d + emb_mask = torch.ones(b, 1).to(x.device).float() + out = torch.cat((sh, latents), dim=1) # latents # mask + out_mask = torch.cat((emb_mask, mask), dim=1) + else: + out = latents # mask + out_mask = mask + return out, out_mask diff --git a/TarDiff/ldm/modules/guidance_scorer.py b/TarDiff/ldm/modules/guidance_scorer.py new file mode 100755 index 0000000..8330cd8 --- /dev/null +++ b/TarDiff/ldm/modules/guidance_scorer.py @@ -0,0 +1,469 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import TensorDataset, DataLoader +from tqdm import tqdm +import time +from scipy.stats import truncnorm +import math + +import torch.nn.functional as F + +import numpy as np + + +class GradDotCalculatorformer: + + def __init__(self, + model, + train_loader, + criterion, + device='cuda', + normalize='l2', + gd_scale=1, + mmd_scale=1, + pos=True, + neg=True): + """ + Initialize calculator with cached training gradients sum + + Args: + model: trained PyTorch model + train_loader: DataLoader containing training data + criterion: loss function + device: computation device + """ + self.model = model + self.criterion = criterion + self.device = device + self.model.to(device) + self.criterion.to(device) + self.normalize = normalize + # Cache the sum of training gradients + self.train_loader = train_loader + self.positive_guidance = pos + self.negative_guidance = neg + self.cached_train_grads = self._compute_train_grad_sum(train_loader) + self.gd_scale = gd_scale + self.mmd_scale = mmd_scale + + def compute_mmd_grad(self, + x: torch.Tensor, + kernel='linear') -> torch.Tensor: + + y = torch.stack([x[0] for x in self.train_loader]).to(self.device) + x_flat = x.reshape(x.size(0), -1) # (B1, D*T) + y_flat = y.reshape(y.size(0), -1).float() # (B2, D*T) + if kernel == 'linear': + K_xx = torch.mm(x_flat, x_flat.t()) # (B1, B1) + K_yy = torch.mm(y_flat, y_flat.t()) # (B2, B2) + K_xy = torch.mm(x_flat, y_flat.t()) # (B1, B2) + elif kernel == 'rbf': + gamma = 1.0 / x_flat.size(-1) # 带宽参数 + pairwise_xx = torch.cdist(x_flat, x_flat, p=2) # (B1, B1) + K_xx = torch.exp(-gamma * pairwise_xx**2) + pairwise_yy = torch.cdist(y_flat, y_flat, p=2) # (B2, B2) + K_yy = torch.exp(-gamma * pairwise_yy**2) + pairwise_xy = torch.cdist(x_flat, y_flat, p=2) # (B1, B2) + K_xy = torch.exp(-gamma * pairwise_xy**2) + else: + raise ValueError("Unsupported kernel type") + + m = x_flat.size(0) + n = y_flat.size(0) + mmd = (K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()) + mmd_grad = torch.autograd.grad(mmd, x) + + return self._normalize_gradients(mmd_grad) + + def _normalize_gradients(self, grads): + """ + Normalize gradients using specified method + """ + if self.normalize == 'l2': + # Compute total L2 norm across all parameters + total_norm = torch.sqrt(sum((g**2).sum() for g in grads)) + return [g / (total_norm + 1e-6) for g in grads] + elif self.normalize == 'l1': + # Compute total L1 norm across all parameters + total_norm = sum(g.abs().sum() for g in grads) + return [g / (total_norm + 1e-6) for g in grads] + else: + return grads + + def _compute_train_grad_sum(self, train_loader): + """Compute and cache the sum of all training samples' gradients""" + + # Initialize gradient sum + params = list(self.model.parameters()) + total_loss = torch.tensor( + 0.0, + device=self.device, + ) + # Accumulate loss for all training samples + train_loader = DataLoader(train_loader, batch_size=1, shuffle=False) + for train_inputs, train_labels in tqdm(train_loader): + if (train_labels == 0 and not self.negative_guidance) or ( + train_labels == 1 and not self.positive_guidance): + continue + + train_inputs = train_inputs.to(self.device).to(torch.float32) + train_labels = train_labels.to(self.device).to(torch.long) + + mask = torch.full((train_inputs.shape[0], train_inputs.shape[1]), + True, + dtype=bool, + device=self.device) + outputs = self.model(train_inputs, mask, None, None) + total_loss += self.criterion(outputs, train_labels) + + # Get gradient of total loss + start_time = time.time() + grad_sum = torch.autograd.grad(total_loss, params, allow_unused=True) + filtered = [(p, g) for p, g in zip(params, grad_sum) if g is not None] + if not filtered: + raise ValueError("No parameter received gradient!") + filtered_params, filtered_grads = zip(*filtered) + print(f"Gradient computation time: {time.time() - start_time:.2f}s") + torch.cuda.empty_cache() + return self._normalize_gradients(filtered_grads) + #return filtered_grads + def compute_gradient(self, test_sample, test_label): + """ + Compute gradient of grad-dot using cached training gradients + + Args: + test_sample: single test input tensor + test_label: ground truth label for test sample + """ + self.model.eval() + test_sample = test_sample.transpose(2, 1) + # Prepare test sample + test_sample = test_sample.detach().requires_grad_(True) + test_sample = test_sample.to(self.device) + #test_label = torch.tensor([test_label]).to(self.device) + test_label = test_label.to(self.device) + # Get test gradient + mask = torch.full((test_sample.shape[0], test_sample.shape[1]), + True, + dtype=bool, + device=self.device) + test_output = self.model(test_sample, mask, None, None) + test_loss = self.criterion(test_output, test_label) + test_grads = torch.autograd.grad(test_loss, + self.model.parameters(), + create_graph=True, + allow_unused=True) + filtered = [(p, g) for p, g in zip(self.model.parameters(), test_grads) + if g is not None] + if not filtered: + raise ValueError("No parameter received gradient!") + filtered_params, filtered_grads = zip(*filtered) + test_grads = self._normalize_gradients(filtered_grads) + # Compute single dot product with cached gradients sum + total_dot = sum( + (test_g * cached_g).sum() + for test_g, cached_g in zip(test_grads, self.cached_train_grads)) + # Get gradient w.r.t test_sample + grad_wrt_sample = torch.autograd.grad(total_dot, + test_sample, + create_graph=False)[0] + #mmd_grad=self.compute_mmd_grad(test_sample)[0].squeeze(0).transpose(2,1) + gd_grad = grad_wrt_sample.squeeze(0).transpose(2, 1) + + torch.cuda.empty_cache() + dynamic_scale = self.gd_scale + + return dynamic_scale * gd_grad + + def compute_influence(self, test_sample, test_label, t): + """ + Compute gradient of grad-dot using cached training gradients + + Args: + test_sample: single test input tensor + test_label: ground truth label for test sample + """ + self.model.eval() + + test_sample = test_sample.transpose(2, 1) + # Prepare test sample + test_sample = test_sample.detach().requires_grad_(True) + test_sample = test_sample.to(self.device) + #test_label = torch.tensor([test_label]).to(self.device) + test_label = test_label.to(self.device) + # Get test gradient + mask = torch.full((test_sample.shape[0], test_sample.shape[1]), + True, + dtype=bool, + device=self.device) + test_output = self.model(test_sample, mask, None, None, t_idx=t) + test_loss = self.criterion(test_output, test_label) + test_grads = torch.autograd.grad(test_loss, + self.model.parameters(), + create_graph=True, + allow_unused=True) + filtered = [(p, g) for p, g in zip(self.model.parameters(), test_grads) + if g is not None] + if not filtered: + raise ValueError("No parameter received gradient!") + filtered_params, filtered_grads = zip(*filtered) + test_grads = self._normalize_gradients(filtered_grads) + # Compute single dot product with cached gradients sum + total_dot = sum( + (test_g * cached_g).sum() + for test_g, cached_g in zip(test_grads, self.cached_train_grads)) + torch.cuda.empty_cache() + total_dot = total_dot.detach().cpu().numpy() + return total_dot + + def compute_per_class_grad_norm_stats(self): + """ + For each sample in the training set: + 1. Compute the gradient of the loss w.r.t. all model parameters; + 2. Collapse each parameter-wise gradient to its L2 norm, then obtain a + single scalar gradient norm for that sample; + 3. Group these norms by class label and report mean ± std for every class. + + Returns + ------- + stats : dict + keys → class labels + values → (mean_norm, std_norm) + """ + self.model.eval() + grad_norms = {} + + # Iterate one sample at a time so we can attribute a single norm per example. + data_loader = DataLoader(self.train_loader, + batch_size=1, + shuffle=False) + + for inputs, labels in tqdm(data_loader, + desc="Computing per-class grad norms"): + inputs = inputs.to(self.device).to(torch.float32) + labels = labels.to(self.device) + + # Create a full-True mask because the backbone expects it + mask = torch.full((inputs.shape[0], inputs.shape[1]), + True, + dtype=torch.bool, + device=self.device) + + outputs = self.model(inputs, mask, None, None) + loss = self.criterion(outputs, labels) + + # Gradients w.r.t. all parameters + grads = torch.autograd.grad(loss, + self.model.parameters(), + create_graph=False, + allow_unused=True) + + # Discard parameters that produced None (e.g. frozen layers) + filtered_grads = [g for g in grads if g is not None] + if not filtered_grads: + continue + + # L2 norm over all parameters: sqrt(Σ‖g‖₂²) + grad_norm_sq = sum(torch.sum(g**2) for g in filtered_grads) + grad_norm = torch.sqrt(grad_norm_sq) + + # Batch size is 1 here, so labels is a scalar + label_val = labels.item() if labels.numel( + ) == 1 else labels.tolist()[0] + grad_norms.setdefault(label_val, []).append(grad_norm.item()) + + stats = {} + for label_val, norm_list in grad_norms.items(): + mean, std = np.mean(norm_list), np.std(norm_list) + stats[label_val] = (mean, std) + print(f"Class {label_val}: {mean:.4f} ± {std:.4f}") + return stats + + def compute_classifier_guidance(self, test_sample, test_label): + """ + Vectorised computation of ∇ₓ log p(y|x) for a batch. + + Parameters + ---------- + test_sample : Tensor [B, D, T] + Raw batch (before transpose). D = channels, T = timesteps. + test_label : Tensor [B] or [B, 1] + Ground-truth class indices. + + Returns + ------- + guidance_grad : Tensor [B, D, T] + Input-space gradients scaled by self.gd_scale. + """ + self.model.eval() + + # Model expects (B, T, D); enable gradient on inputs + test_sample = (test_sample.transpose( + 2, 1).detach().requires_grad_(True).to(self.device)) + test_label = test_label.to(self.device) + if test_label.ndim > 1: + test_label = test_label.squeeze(-1) + + B = test_sample.shape[0] + + mask = torch.full((B, test_sample.shape[1]), + True, + dtype=torch.bool, + device=self.device) + logits = self.model(test_sample, mask, None, None) # [B, C] + log_probs = F.log_softmax(logits, dim=1) # [B, C] + + # Select log-probability of the true class for each sample + log_selected = log_probs.gather(1, + test_label.view(-1, + 1)).squeeze(1) # [B] + + # Sum over batch to obtain a scalar; gradients remain sample-wise + loss = log_selected.sum() + grad = torch.autograd.grad(loss, + test_sample, + create_graph=False, + retain_graph=False, + allow_unused=True)[0] + + return grad.transpose(2, 1) * self.gd_scale + + +class GradDotCalculator: + + def __init__(self, + model, + train_loader, + criterion, + gd_scale=1, + device='cuda', + normalize='l2'): + """ + Compute and cache the gradient sum over the *entire* training set, + then provide fast “grad-dot” calculations for test samples. + + Parameters + ---------- + model : torch.nn.Module + A pre-trained network (e.g., RNN). + train_loader : DataLoader + Loader that iterates over the full training dataset. + criterion : callable + Loss function used for both training and evaluation. + gd_scale : float, default=1 + Optional scaling factor applied to the final input-space gradient. + device : {'cuda', 'cpu'}, default='cuda' + Computation device. + normalize : {'l2', 'l1', None}, default='l2' + How to normalise parameter-space gradients before dot-product. + """ + self.model = model.to(device) + self.criterion = criterion.to(device) + self.device = device + self.normalize = normalize + self.gd_scale = gd_scale + self.cached_train_grads = self._compute_train_grad_sum(train_loader) + + def _normalize_gradients(self, grads): + """ + Apply the requested normalisation to a list of parameter gradients. + + Parameters + ---------- + grads : list[Tensor] + One tensor per parameter. + + Returns + ------- + list[Tensor] + Normalised gradients (or originals if `normalize` is None). + """ + if self.normalize == 'l2': + total_norm = torch.sqrt(sum((g**2).sum() for g in grads)) + return [g / (total_norm + 1e-6) for g in grads] + elif self.normalize == 'l1': + total_norm = sum(g.abs().sum() for g in grads) + return [g / (total_norm + 1e-6) for g in grads] + else: + return grads + + def _compute_train_grad_sum(self, train_loader): + """ + Accumulate the loss over the *whole* training set, then take a single + backward pass to obtain Σ∇θ L. This treats every sample as if it were + in one giant batch. Beware of memory usage on very large datasets. + + Returns + ------- + list[Tensor] + Normalised gradient sum, one tensor per trainable parameter. + """ + params = list(self.model.parameters()) + total_loss = torch.tensor(0.0, device=self.device) + + for inputs, labels in tqdm(train_loader, + desc="Compute Train Grad Sum"): + inputs = inputs.to(self.device) + labels = labels.to(self.device).to(torch.float32) + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + total_loss += loss + + grad_sum = torch.autograd.grad(total_loss, params, allow_unused=True) + filtered = [(p, g) for p, g in zip(params, grad_sum) if g is not None] + _, filtered_grads = zip(*filtered) + grad_sum = self._normalize_gradients(filtered_grads) + torch.cuda.empty_cache() + return grad_sum + + def compute_gradient(self, test_sample, test_label): + """ + Compute the input-space gradient induced by a “grad-dot” score: + + ⟨∇θ L_test, Σ∇θ L_train⟩ + + Steps + ----- + 1. Back-prop through the test loss to obtain ∇θ L_test. + 2. Dot-product with the cached training-set gradient sum. + 3. Back-prop that scalar w.r.t. the *input* to get ∇ₓ score. + + Parameters + ---------- + test_sample : Tensor [B, D, T] + Input time-series (channels-first, will be transposed internally). + test_label : Tensor + Corresponding ground-truth labels. + + Returns + ------- + Tensor [B, D, T] + Input-space gradient scaled by `gd_scale`. + """ + # Model expects (B, T, D) + test_sample = test_sample.transpose(2, 1).to(self.device) + test_sample.requires_grad_(True) + + test_label = test_label.to(self.device).to(torch.float32) + + with torch.backends.cudnn.flags(enabled=False): + test_output = self.model(test_sample) + test_loss = self.criterion(test_output, test_label) + + test_grads = torch.autograd.grad(test_loss, + self.model.parameters(), + create_graph=True) + + filtered = [(p, g) for p, g in zip(self.model.parameters(), test_grads) + if g is not None] + _, test_grads = zip(*filtered) + test_grads = self._normalize_gradients(test_grads) + + total_dot = sum((tg * cg).sum() + for tg, cg in zip(test_grads, self.cached_train_grads)) + + grad_wrt_sample = torch.autograd.grad(total_dot, test_sample)[0] + return grad_wrt_sample.transpose(2, 1) * self.gd_scale diff --git a/TarDiff/ldm/modules/losses/__init__.py b/TarDiff/ldm/modules/losses/__init__.py new file mode 100755 index 0000000..6744a52 --- /dev/null +++ b/TarDiff/ldm/modules/losses/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator +from ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator \ No newline at end of file diff --git a/TarDiff/ldm/modules/losses/contperceptual.py b/TarDiff/ldm/modules/losses/contperceptual.py new file mode 100755 index 0000000..5c29bfc --- /dev/null +++ b/TarDiff/ldm/modules/losses/contperceptual.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + + def __init__(self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, + last_layer, + retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, + last_layer, + retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, + self.last_layer[0], + retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, + self.last_layer[0], + retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - + reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), + reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum( + weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, + global_step, + threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator( + reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), + dim=1)) + + disc_factor = adopt_weight(self.disc_factor, + global_step, + threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/TarDiff/ldm/modules/losses/vqperceptual.py b/TarDiff/ldm/modules/losses/vqperceptual.py new file mode 100755 index 0000000..7f7b46d --- /dev/null +++ b/TarDiff/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, + n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +def l1(x, y): + return torch.abs(x - y) + + +def l2(x, y): + return torch.pow((x - y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + + def __init__(self, + disc_start, + codebook_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_ndf=64, + disc_loss="hinge", + n_classes=None, + perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError( + f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, + last_layer, + retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, + last_layer, + retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, + self.last_layer[0], + retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, + self.last_layer[0], + retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, + codebook_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + predicted_indices=None): + if codebook_loss == None: + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), + reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), + reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, + global_step, + threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean( + ) + + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity( + predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator( + reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), + dim=1)) + + disc_factor = adopt_weight(self.disc_factor, + global_step, + threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/TarDiff/ldm/util.py b/TarDiff/ldm/util.py new file mode 100755 index 0000000..c2d8caf --- /dev/null +++ b/TarDiff/ldm/util.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont +import ssl + +ssl._create_default_https_context = ssl._create_unverified_context + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] + for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print( + f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params." + ) + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch(func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [[func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc))] + else: + step = (int(len(data) / n_proc + 1) if len(data) % + n_proc != 0 else int(len(data) / n_proc)) + arguments = [[func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i:i + step] + for i in range(0, len(data), step)])] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/TarDiff/train.sh b/TarDiff/train.sh new file mode 100755 index 0000000..d3017e4 --- /dev/null +++ b/TarDiff/train.sh @@ -0,0 +1,2 @@ +python train.py --base configs/base/mimic_icustay_base.yaml --gpus 0, --uncond --logdir ts_diff_uncond_testing/mimic_icustay_base -sl 24 --batch_size 128 --max_steps 20000 -lr 0.0001 -s 42 + diff --git a/TarDiff/train_main.py b/TarDiff/train_main.py new file mode 100755 index 0000000..684ca6c --- /dev/null +++ b/TarDiff/train_main.py @@ -0,0 +1,458 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import argparse, os, sys, datetime +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from utils.callback_utils import prepare_trainer_configs +from ldm.util import instantiate_from_config +from pathlib import Path + + +def get_parser(**parser_kwargs): + + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument("-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir") + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=True, + nargs="?", + help="train", + ) + parser.add_argument( + "-r", + "--resume", + type=str2bool, + const=True, + default=False, + nargs="?", + help="resume and test", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "--normalize", + type=str, + const=True, + default=None, + nargs="?", + help="normalization method", + ) + parser.add_argument("-p", + "--project", + help="name of new or path to existing project") + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="/mnt/storage/ts_diff_newer", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=False, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "-dw", + "--dis_weight", + type=float, + const=True, + default=1., + nargs="?", + help="weight of disentangling loss", + ) + parser.add_argument( + "-dt", + "--dis_loss_type", + type=str, + const=True, + default=None, + nargs="?", + help="type of disentangling loss", + ) + parser.add_argument( + "-tg", + "--train_stage", + type=str, + const=True, + default='pre', + nargs="?", + help="pre / dis", + ) + parser.add_argument( + "-ds", + "--dataset_name", + type=str, + const=True, + default='elec', + nargs="?", + help="dataset name", + ) + parser.add_argument( + "-dp", + "--dataset_prefix", + type=str, + const=True, + default='/mnt/storage/tsdiff/data', + nargs="?", + help="dataset prefix", + ) + parser.add_argument( + "-cp", + "--ckpt_prefix", + type=str, + const=True, + default='/mnt/storage/tsdiff/outputs', + nargs="?", + help="ckpt prefix", + ) + parser.add_argument( + "-sp", + "--sample_path", + type=str, + const=True, + default='/mnt/storage/ts_generated/ours_amlt', + nargs="?", + help="samples prefix", + ) + + parser.add_argument( + "-pl", + "--pair_loss_type", + type=str, + const=True, + default='', + nargs="?", + help="pair loss type: cosine or l2, otherwise not used") + parser.add_argument("-sl", + "--seq_len", + type=int, + const=True, + default=24, + nargs="?", + help="sequence length") + parser.add_argument("-uc", + "--uncond", + action='store_true', + help="unconditional generation") + parser.add_argument("-si", + "--split_inv", + action='store_true', + help="split invariant encoder") + parser.add_argument("-cl", + "--ce_loss", + action='store_true', + help="cross entropy loss") + parser.add_argument("-up", + "--use_prototype", + action='store_true', + help="use prototype") + parser.add_argument("-pd", + "--part_drop", + action='store_true', + help="use partial dropout conditions") + parser.add_argument("-o", + "--orth_emb", + action='store_true', + help="use orthogonal prototype embedding") + parser.add_argument("-ma", + "--mask_assign", + action='store_true', + help="use mask assignment") + parser.add_argument("-ha", + "--hard_assign", + action='store_true', + help="use hard assignment") + parser.add_argument("-im", + "--inter_mask", + action='store_true', + help="use intermediate assignment") + parser.add_argument("-bs", + "--batch_size", + type=int, + const=True, + default=256, + nargs="?", + help="batch_size") + parser.add_argument("-ms", + "--max_step_sum", + type=int, + const=True, + default=20000, + nargs="?", + help="max training steps") + parser.add_argument("-nl", + "--num_latents", + type=int, + const=True, + default=16, + nargs="?", + help="sequence length") + parser.add_argument("-pw", + "--pair_weight", + type=float, + const=True, + default=1.0, + nargs="?", + help="pair loss weight") + parser.add_argument("-lr", + "--overwrite_learning_rate", + type=float, + const=True, + default=None, + nargs="?", + help="learning rate") + + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +if __name__ == "__main__": + + data_root = '/home/v-dengbowen/mount/' + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + + if opt.name: + name = opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = cfg_name + else: + name = "" + + seed_everything(opt.seed) + + # try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + # Customize config from opt: + n_data = len(config.data['params']['data_path_dict']) + config.model['params']['image_size'] = opt.seq_len + config.model['params']['unet_config']['params']['image_size'] = opt.seq_len + config.data['params']['window'] = opt.seq_len + config.data['params']['batch_size'] = opt.batch_size + bs = opt.batch_size + if opt.max_steps: + config.lightning['trainer']['max_steps'] = opt.max_steps + max_steps = opt.max_steps + else: + max_steps = config.lightning['trainer']['max_steps'] + if opt.debug: + config.lightning['trainer']['max_steps'] = 10 + config.lightning['callbacks']['image_logger']['params'][ + 'batch_frequency'] = 5 + max_steps = 10 + if opt.overwrite_learning_rate is not None: + config.model['base_learning_rate'] = opt.overwrite_learning_rate + print( + f"Setting learning rate (overwritting config file) to {opt.overwrite_learning_rate:.2e}" + ) + base_lr = opt.overwrite_learning_rate + else: + base_lr = config.model['base_learning_rate'] + + nowname = f"{name.split('-')[-1]}_{opt.seq_len}_nl_{opt.num_latents}_lr{base_lr:.1e}_bs{opt.batch_size}_ms{int(max_steps/1000)}k" + + if opt.normalize is not None: + config.data['params']['normalize'] = opt.normalize + nowname += f"_{config.data['params']['normalize']}" + else: + assert 'normalize' in config.data['params'] + nowname += f"_{config.data['params']['normalize']}" + + config.model['params']['pair_loss_flag'] = False + if opt.uncond: + config.model['params']['cond_stage_config'] = "__is_unconditional__" + config.model['params']['cond_stage_trainable'] = False + nowname += f"_uncond" + + nowname += f"_seed{opt.seed}" + # nowname = nowname + logdir = os.path.join(opt.logdir, cfg_name, nowname) + if not os.path.exists(logdir): + os.makedirs(logdir) + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + # default to ddp + trainer_config["accelerator"] = "gpu" + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + if not "gpus" in trainer_config: + del trainer_config["accelerator"] + cpu = True + else: + gpuinfo = trainer_config["gpus"] + print(f"Running on GPUs {gpuinfo}") + cpu = False + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + if "LatentDiffusion" in config.model['target']: + if opt.dis_loss_type != None: + config.model["params"]["dis_loss_type"] = opt.dis_loss_type + config.model["params"]["dis_weight"] = opt.dis_weight + + if opt.resume: + ckpt_path = logdir + '/' + 'checkpoints' + "/" + 'last.ckpt' + config.model['params']['ckpt_path'] = ckpt_path + + print(f"Loading model........") + model = instantiate_from_config(config.model) + # trainer and callbacks + trainer_kwargs = prepare_trainer_configs(nowname, logdir, opt, + lightning_config, ckptdir, model, + now, cfgdir, config, trainer_opt) + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir ### + + # data + for k, v in config.data.params.data_path_dict.items(): + config.data.params.data_path_dict[k] = v.replace( + '/mnt/storage/', data_root) + + print("Preparing data.......") + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data Preparation Finished #####") + + # print(f"Train: {data.train_shape}, Validation: {data.val_shape}, Test: {data.test_shape}") + # for k in data.datasets: + # print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + # bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" + .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, + base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb + pudb.set_trace() + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + # run + print(f"Starting training................") + if opt.train: + try: + trainer.logger.experiment.config.update(opt) + trainer.fit(model, data) + except Exception: + melk() + raise + print("Training finished!") diff --git a/TarDiff/utils/__init__.py b/TarDiff/utils/__init__.py new file mode 100755 index 0000000..0eca642 --- /dev/null +++ b/TarDiff/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. \ No newline at end of file diff --git a/TarDiff/utils/callback_utils.py b/TarDiff/utils/callback_utils.py new file mode 100755 index 0000000..f4c05d0 --- /dev/null +++ b/TarDiff/utils/callback_utils.py @@ -0,0 +1,479 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image +import wandb +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.utilities.distributed import rank_zero_only +# from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.util import instantiate_from_config +from pathlib import Path +import matplotlib.pyplot as plt +import pandas as pd + + +class SetupCallback(Callback): + + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, + lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config[ + 'callbacks']: + os.makedirs(os.path.join(self.ckptdir, + 'trainstep_checkpoints'), + exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save( + self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save( + OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, + "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +def plot_naming(k, bi, ni): # the ni-th row, bi-th column + name = '' + if k == 'progressive_row_recon': + name = f"sample{bi} recon at diffstep{1000 - ni*200}" + if k == 'progressive_row_inter': + name = f"sample{bi} inter at diffstep{1000 - ni*200}" + if k == 'samples_swapping': + if ni == 0: + name = f'input {bi} (encoded source)' + elif ni == 1: + name = f'sample {bi} (reconstruction)' + else: + if bi == 0: + name = f'concept {ni-1} source' + else: + name = f'sample {bi} swapped concept {ni-1}' + # else: + # name = f"sample {ni} swapped concept {bi-1}" + if k == 'samples_swapping_intercept': + if ni == 0: + name = f'concept {bi} intercept source' + if ni == 6: + name = f'concept {bi} intercept target' + if k == 'diffusion_row': + name = f'sample{bi} diffstep {ni*200}' + # name = f'scale {bi}' + return name + + +class TSLogger(Callback): + + def __init__(self, + batch_frequency, + max_images, + clamp=False, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + # pl.loggers.TestTubeLogger: self._testtube, + pl.loggers.CSVLogger: + self._testtube, + } + self.log_steps = [ + 2**n for n in range(int(np.log2(self.batch_freq)) + 1) + ] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + # pl_module.logger.experiment.add_image( + # tag, grid, + # global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, + save_dir, + split, + images, + global_step, + current_epoch, + batch_idx, + key_list, + dm, + logger=None): + root = os.path.join(save_dir, "images", split) + image_dict = {} + for k in images: # assume inverse normalization has been applied + # if k in ["samples_swapping", "samples_swapping_partial"]: + # grid = torchvision.utils.make_grid(images[k], nrow=8) + # else: + grid = images[k] + # grid = torchvision.utils.make_grid(images[k], nrow=8) + # if self.rescale: + # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + # grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + + grid = grid.numpy() # shape: num_samples, channels, window + # for i in range(grid.shape[0]): + # grid[i] = dm.inverse_transform(grid[i], data_name=key_list[i]) # TODO: should apply different inverse transform for samples_swapping and swapping_intercept + # grid = (grid * 255).astype(np.uint8) + if len(grid.shape) == 3: + b, c, w = grid.shape # batchsize, channels, window + for i in range(b): + grid[i] = dm.inverse_transform(grid[i], + data_name=key_list[i]) + fig, axs = plt.subplots(c, b, + figsize=(b * 4, + c * 4)) # c rows, b columns + for bi in range(b): # transposed plotting + if c == 1: # typically 1 x 8 + axs[bi].plot(grid[bi, 0]) + else: + for ci in range(c): + axs[ci, bi].plot(grid[bi, ci]) + elif len( + grid.shape + ) == 4: # compare across rows, so batchsize as num of columns + n, b, c, w = grid.shape + if k == 'samples_swapping_intercept': + for i in range(n - 1): + grid[i] = dm.inverse_transform( + grid[i], data_name=key_list[0] + ) # always the first one except for the last row (the second one) + grid[n - 1] = dm.inverse_transform( + grid[n - 1], data_name=key_list[1] + ) # always the first one except for the last row (the second one) + else: + for i in range(b): + grid[:, + i] = dm.inverse_transform(grid[:, i], + data_name=key_list[i]) + fig, axs = plt.subplots(n, b, + figsize=(b * 4, + n * 4)) # n rows, b columns + for bi in range(b): + if n == 1: + for ci in range(c): + axs[bi].plot(grid[0, bi, ci]) + axs[bi].set_title(plot_naming(k, bi, n)) + # axs[ni].plot(grid[0, ni]) + else: + for ni in range(n): + for ci in range(c): + axs[ni, bi].plot(grid[ni, bi, ci]) + axs[ni, bi].set_title(plot_naming(k, bi, ni)) + # axs[bi, ni].plot(grid[bi, ni]) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( + k, global_step, current_epoch, batch_idx) + # path = os.path.join(root, filename) + # os.makedirs(os.path.split(path)[0], exist_ok=True) + plt.suptitle(filename) + # logger.log_graph(k,fig, step=global_step) + # image_dict[k] = wandb.Image(fig) + image_dict[k] = wandb.Image(fig) + # plt.savefig(path, transparent=False) + plt.close() + # Image.fromarray(grid).save(path) + logger.experiment.log(image_dict, step=global_step) + + def log_img(self, pl_module, batch, batch_idx, split="train", n_row=8): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + # self.log_steps = [1000, check_idx] + if (self.check_frequency(check_idx) + and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, + n_row=n_row, + split=split, + **self.log_images_kwargs) + key_list = pl_module.trainer.datamodule.key_list + batch_key_list = [] + for i in range(n_row): + batch_key_list.append( + key_list[batch['data_key'][i].detach().cpu().numpy()]) + + for k in images: + if k != "samples_swapping" and k != "samples_swapping_partial": # TODO: change to swapping intercept + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp( + images[k], -1., 1. + ) # should clamp to [0,1]? or modify data loader to [-1,1] + else: + images[k] = torch.clamp(images[k], -2., 2.) + + self.log_local(pl_module.logger.save_dir, + split, + images, + pl_module.global_step, + pl_module.current_epoch, + batch_idx, + batch_key_list, + pl_module.trainer.datamodule, + logger=pl_module.logger) + + logger_log_images = self.logger_log_images.get( + logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or + (check_idx in self.log_steps)) and (check_idx > 0 + or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, + dataloader_idx): + # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and (pl_module.global_step > 0 + or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, + batch_idx, dataloader_idx): + # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm + and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +def prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, + model, now, cfgdir, config, trainer_opt): + trainer_kwargs = dict() + + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": f"{nowname}_{now}", + "save_dir": logdir, + "offline": opt.debug, + "id": f"{nowname}_{now}", + "project": "DisDiff-Time", + # "config": OmegaConf.to_container(config, resolve=True) + # "log_model": "all" + } + } + } + default_logger_cfg = default_logger_cfgs["wandb"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = OmegaConf.create() + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + if lightning_config.trainer.grad_watch: + trainer_kwargs["logger"].watch(model) + if not opt.uncond: + trainer_kwargs["logger"].watch(model.cond_stage_model, + log='parameters', + log_freq=1000) + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}-{val/loss_simple_ema:.4f}", + "verbose": True, + "save_last": True, + "auto_insert_metric_name": False + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + default_modelckpt_cfg["params"]["mode"] = "min" + if default_modelckpt_cfg["params"]["monitor"] == "train/step_num": + default_modelckpt_cfg["params"]["every_n_train_steps"] = 2000 + default_modelckpt_cfg["params"]["every_n_epochs"] = None + default_modelckpt_cfg["params"]["filename"] = "{step:09}" + default_modelckpt_cfg["params"]["mode"] = "max" + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config( + modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "utils.callback_utils.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "learning_rate_logger": { + "target": "pytorch_lightning.callbacks.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "utils.callback_utils.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.' + ) + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': { + "target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr( + trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params[ + 'ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [ + instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg + ] + return trainer_kwargs diff --git a/diffusion/ldm/modules/guidance_scorer.py b/diffusion/ldm/modules/guidance_scorer.py index 77c97bc..8330cd8 100644 --- a/diffusion/ldm/modules/guidance_scorer.py +++ b/diffusion/ldm/modules/guidance_scorer.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import TensorDataset, DataLoader from tqdm import tqdm import time -from scipy.stats import truncnorm # [Added code: 引入truncnorm用于采样] +from scipy.stats import truncnorm import math import torch.nn.functional as F @@ -104,9 +103,6 @@ def _compute_train_grad_sum(self, train_loader): device=self.device, ) # Accumulate loss for all training samples - # data=torch.tensor(train_loader[0]) - - # label=torch.tensor(train_loader[1]) train_loader = DataLoader(train_loader, batch_size=1, shuffle=False) for train_inputs, train_labels in tqdm(train_loader): if (train_labels == 0 and not self.negative_guidance) or ( @@ -182,58 +178,6 @@ def compute_gradient(self, test_sample, test_label): return dynamic_scale * gd_grad - def compute_noise_gd(self, test_sample, test_label, t): - """ - Compute gradient of grad-dot using cached training gradients - - Args: - test_sample: single test input tensor - test_label: ground truth label for test sample - """ - self.model.eval() - test_sample = test_sample.transpose(2, 1) - # Prepare test sample - test_sample = test_sample.detach().requires_grad_(True) - test_sample = test_sample.to(self.device) - #test_label = torch.tensor([test_label]).to(self.device) - test_label = test_label.to(self.device) - # Get test gradient - mask = torch.full((test_sample.shape[0], test_sample.shape[1]), - True, - dtype=bool, - device=self.device) - ddim_timesteps = torch.tensor([ - 0, 10, 20, 31, 43, 55, 67, 80, 93, 106, 119, 132, 145, 157, 169, - 179, 187, 193, 197, 199 - ], - device=t.device) - - t_ = t.unsqueeze(1) # [B, 1] - diff = torch.abs(ddim_timesteps - t_) # [B, 20] - t_idx = torch.argmin(diff, dim=1) - test_output = self.model(test_sample, mask, None, None, t_idx=t_idx) - test_loss = self.criterion(test_output, test_label) - test_grads = torch.autograd.grad(test_loss, - self.model.parameters(), - create_graph=True, - allow_unused=True) - filtered = [(p, g) for p, g in zip(self.model.parameters(), test_grads) - if g is not None] - if not filtered: - raise ValueError("No parameter received gradient!") - filtered_params, filtered_grads = zip(*filtered) - test_grads = self._normalize_gradients(filtered_grads) - total_dot = sum((test_g * cached_g).sum() for test_g, cached_g in zip( - test_grads[:74], self.cached_train_grads[:74])) - # Get gradient w.r.t test_sample - grad_wrt_sample = torch.autograd.grad(total_dot, - test_sample, - create_graph=False)[0] - #mmd_grad=self.compute_mmd_grad(test_sample)[0].squeeze(0).transpose(2,1) - gd_grad = grad_wrt_sample.squeeze(0).transpose(2, 1) - torch.cuda.empty_cache() - return gd_grad * 1e7 - def compute_influence(self, test_sample, test_label, t): """ Compute gradient of grad-dot using cached training gradients diff --git a/supplementary/mimiciii_prepare.md b/supplementary/mimiciii_prepare.md index e69de29..097f0c1 100644 --- a/supplementary/mimiciii_prepare.md +++ b/supplementary/mimiciii_prepare.md @@ -0,0 +1 @@ +We preprocess MIMIC-III by first querying the raw **vitals** and **admissions** tables, then isolating each ICU stay (`icustay_id`) as an independent sample. For every stay we extract seven routinely recorded signals—heart-rate, systolic/diastolic blood pressure, mean arterial pressure, respiratory rate, temperature, oxygen saturation (SpO₂), and urine output—resample them to an equal 1-hour grid, and truncate or zero-pad so every sample is a fixed **24 × 7** time-series matrix covering the first 24 hours in the unit. We attach a binary in-hospital mortality label from the admissions record, stack all samples into a single array, randomly shuffle, and split 80 % / 20 % into training and test sets while reporting the class balance. This yields a clean, length-aligned dataset ready for downstream modeling without exposing any protected health information. \ No newline at end of file