diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1cf776a4e --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +**/__pycache__/** +**/.ipynb_checkpoints/** +.DS_Store + +logs/ +wandb/ +taming/ +src/ +=0.7.5 +*.ckpt +*.log +*.png diff --git a/analysis.py b/analysis.py new file mode 100644 index 000000000..7d4fd7e4a --- /dev/null +++ b/analysis.py @@ -0,0 +1,221 @@ +import argparse +import os +import torch +import torch.nn as nn +import torch.distributed as dist +import torchvision +import torchvision.transforms.functional as F +import numpy as np +import pandas as pd +import lpips +import clip + +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder +from torchvision.transforms import transforms +from PIL import Image +from pytorch_wavelets import DWTForward +from segment_anything import SamPredictor, sam_model_registry +from tqdm import tqdm +from omegaconf import OmegaConf + +from ldm.models.autoencoder import AutoencoderKL + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ['SLURM_NTASKS']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + return + + torch.cuda.set_device(args.gpu) + print(f'| distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}') + dist.init_process_group( + backend='nccl', + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + +def is_main_process(): + return dist.get_rank() == 0 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('VAE Analysis', add_help=False) + parser.add_argument('--batch_size', default=10, type=int) + parser.add_argument('--data_path', default='/BS/var/nobackup/imagenet-1k/', type=str) + parser.add_argument('--resos', default=256, type=int) + + parser.add_argument('--device', default='cuda', help='device to use for training / testing') + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + args = parser.parse_args() + + # multi-node and multi-GPU evaluation + init_distributed_mode(args) + + # data loading + transform = transforms.Compose([ + transforms.Resize(args.resos), + transforms.CenterCrop((args.resos, args.resos)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + dataset = ImageFolder(root=os.path.join(args.data_path, 'val'), transform=transform) + + val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False) + image_val_loader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=val_sampler, + num_workers=4, + pin_memory=True, + ) + + # build the model + config = OmegaConf.load('configs/autoencoder/autoencoder_kl_16x16x16.yaml') + vae = AutoencoderKL.load_from_checkpoint( + 'logs/ch1-1-2-2-4_baseline/last.ckpt', + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) + vae = vae.to(args.device) + vae.eval() + for p in vae.parameters(): p.requires_grad_(False) + print(f'prepare finished.') + + # evaluate the reconstruction loss [-1, 1] range + total_loss = {key: 0.0 for key in 
['reconstruction', 'low_frequency', 'high_frequency', 'perceptual', 'clip_semantic', 'sam_semantic']} + total_images = 0 + batch_imgs_to_save = 10 + visualize_imgs = [] if is_main_process() else None + + # for evaluation metrics + dwt = DWTForward(J=1, wave='haar', mode='zero').to(args.device) + l2_loss = nn.MSELoss(reduction='mean') + + lpips_loss = lpips.LPIPS(net='vgg').to(args.device).eval() + + clip, preprocess_clip = clip.load('ViT-B/32', device=args.device) + clip.eval() + + sam = sam_model_registry['vit_b'](checkpoint='/BS/var/work/segment-anything/sam_vit_b_01ec64.pth') + sam = sam.to(args.device).eval() + sam_predictor = SamPredictor(sam) + + with torch.no_grad(): + for imgs, labels in tqdm(image_val_loader, disable=not is_main_process(), desc='Processing images', leave=True): + imgs = imgs.to(args.device) + rec_imgs, _, = vae(imgs) + + # first level DWT + ll1, hs = dwt(imgs) + lh1, hl1, hh1 = hs[0][:, 0], hs[0][:, 1], hs[0][:, 2] + rec_ll1, rec_hs = dwt(rec_imgs) + rec_lh1, rec_hl1, rec_hh1 = rec_hs[0][:, 0], rec_hs[0][:, 1], rec_hs[0][:, 2] + + # preprocess for CLIP, which expects input of size (224, 224) + # more efficient than applying preprocess_clip() sample-by-sample, but introduce slight discrepancies + # due to differences between PIL and tensor-based Resize implementations in torchvision + imgs_clip = torch.clamp((imgs + 1) / 2, min=0, max=1) + imgs_clip = F.resize(imgs_clip, 224, F.InterpolationMode.BICUBIC) + imgs_clip = F.normalize(imgs_clip, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + rec_imgs_clip = torch.clamp((rec_imgs + 1) / 2, min=0, max=1) + rec_imgs_clip = F.resize(rec_imgs_clip, 224, F.InterpolationMode.BICUBIC) + rec_imgs_clip = F.normalize(rec_imgs_clip, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + + # imgs_clip = torch.clamp((imgs + 1) / 2, min=0, max=1).cpu() + # imgs_clip = torch.stack([preprocess_clip(torchvision.transforms.ToPILImage()(img)) for img in imgs_clip]).to(args.device) + # rec_imgs_clip = torch.clamp((rec_imgs + 1) / 2, min=0, max=1).cpu() + # rec_imgs_clip = torch.stack([preprocess_clip(torchvision.transforms.ToPILImage()(img)) for img in rec_imgs_clip]).to(args.device) + + features_clip = clip.encode_image(imgs_clip) + rec_features_clip = clip.encode_image(rec_imgs_clip) + + # preprocess for SAM, which expects input of size (1024, 1024) + # slightly differs from apply_image(), which uses uint8 NumPy arrays + imgs_sam = torch.clamp((imgs + 1) / 2, min=0, max=1).mul_(255) + imgs_sam = sam_predictor.transform.apply_image_torch(imgs_sam) + rec_imgs_sam = torch.clamp((rec_imgs + 1) / 2, min=0, max=1).mul_(255) + rec_imgs_sam = sam_predictor.transform.apply_image_torch(rec_imgs_sam) + + sam_predictor.set_torch_image(imgs_sam, (256, 256)) + features_sam = sam_predictor.features + features_sam = features_sam.reshape(imgs.shape[0], -1) + sam_predictor.set_torch_image(rec_imgs_sam, (256, 256)) + rec_features_sam = sam_predictor.features + rec_features_sam = rec_features_sam.reshape(imgs.shape[0], -1) + + batch_losses = { + 'reconstruction': l2_loss(rec_imgs, imgs).item(), + 'low_frequency': l2_loss(rec_ll1, ll1).item(), + 'high_frequency': (l2_loss(rec_lh1, lh1).item() + l2_loss(rec_hl1, hl1).item() + l2_loss(rec_hh1, hh1).item()) / 3, + 'perceptual': lpips_loss(rec_imgs, imgs).mean().item(), + 'clip_semantic': 1 - nn.functional.cosine_similarity(features_clip, rec_features_clip).mean().item(), + 'sam_semantic': 1 - nn.functional.cosine_similarity(features_sam, 
rec_features_sam).mean().item(), + } + + for key in total_loss: + total_loss[key] += batch_losses[key] * imgs.shape[0] + + total_images += imgs.shape[0] + + if is_main_process() and len(visualize_imgs) < batch_imgs_to_save: + visualize_imgs.append(rec_imgs[:4].cpu()) + + # aggregate losses across all distributed processes + for key in total_loss: + t = torch.tensor(total_loss[key], device=args.device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM) + total_loss[key] = t.item() + + t = torch.tensor(total_images, device=args.device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM) + total_images = t.item() + + if is_main_process(): + save_dir = '/BS/var/work/analysis_figures' + + # visualize some reconstructed images + visualize_imgs = torch.cat(visualize_imgs, dim=0) + visualize_imgs = torch.clamp((visualize_imgs + 1) / 2, min=0, max=1) + visualize_imgs = torchvision.utils.make_grid(visualize_imgs, nrow=8, padding=0) + visualize_imgs = visualize_imgs.permute(1, 2, 0).mul_(255).numpy() + visualize_imgs = Image.fromarray(visualize_imgs.astype(np.uint8)) + visualize_imgs.save(f'{save_dir}/recon_kl-vae-f16c16-ldm-from-scratch.png') + + # compute average loss per component + avg_loss = {key: total_loss[key] / total_images for key in total_loss} + + # save results + csv_path = f'{save_dir}/loss_metrics.csv' + + new_row = { + 'Model': 'KL-VAE-f16c16-LDM-From-Scratch', + 'Dataset': 'ImageNet', + 'Reconstruction Loss': avg_loss['reconstruction'], + 'Low Frequency Loss': avg_loss['low_frequency'], + 'High Frequency Loss': avg_loss['high_frequency'], + 'Perceptual Loss': avg_loss['perceptual'], + 'CLIP Semantic Loss': avg_loss['clip_semantic'], + 'SAM Semantic Loss': avg_loss['sam_semantic'], + } + + if os.path.exists(csv_path): + df = pd.read_csv(csv_path) + df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) + else: + df = pd.DataFrame([new_row]) + + df.to_csv(csv_path, index=False) diff --git a/analysis_wav.py b/analysis_wav.py new file mode 100644 index 000000000..6353ade7c --- /dev/null +++ b/analysis_wav.py @@ -0,0 +1,221 @@ +import argparse +import os +import torch +import torch.nn as nn +import torch.distributed as dist +import torchvision +import torchvision.transforms.functional as F +import numpy as np +import pandas as pd +import lpips +import clip + +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder +from torchvision.transforms import transforms +from PIL import Image +from pytorch_wavelets import DWTForward +from segment_anything import SamPredictor, sam_model_registry +from tqdm import tqdm +from omegaconf import OmegaConf + +from ldm.models.autoencoder_wav import AutoencoderKL + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ['SLURM_NTASKS']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + return + + torch.cuda.set_device(args.gpu) + print(f'| distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}') + dist.init_process_group( + backend='nccl', + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + +def is_main_process(): + return dist.get_rank() == 0 + + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser('VAE Analysis', add_help=False) + parser.add_argument('--batch_size', default=10, type=int) + parser.add_argument('--data_path', default='/BS/var/nobackup/imagenet-1k/', type=str) + parser.add_argument('--resos', default=256, type=int) + + parser.add_argument('--device', default='cuda', help='device to use for training / testing') + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + args = parser.parse_args() + + # multi-node and multi-GPU evaluation + init_distributed_mode(args) + + # data loading + transform = transforms.Compose([ + transforms.Resize(args.resos), + transforms.CenterCrop((args.resos, args.resos)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + dataset = ImageFolder(root=os.path.join(args.data_path, 'val'), transform=transform) + + val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False) + image_val_loader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=val_sampler, + num_workers=4, + pin_memory=True, + ) + + # build the model + config = OmegaConf.load('configs/autoencoder/autoencoder_kl-wav_16x16x16.yaml') + vae = AutoencoderKL.load_from_checkpoint( + 'logs/ch1-2-4_baseline/last.ckpt', + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) + vae = vae.to(args.device) + vae.eval() + for p in vae.parameters(): p.requires_grad_(False) + print(f'prepare finished.') + + # evaluate the reconstruction loss [-1, 1] range + total_loss = {key: 0.0 for key in ['reconstruction', 'low_frequency', 'high_frequency', 'perceptual', 'clip_semantic', 'sam_semantic']} + total_images = 0 + batch_imgs_to_save = 10 + visualize_imgs = [] if is_main_process() else None + + # for evaluation metrics + dwt = DWTForward(J=1, wave='haar', mode='zero').to(args.device) + l2_loss = nn.MSELoss(reduction='mean') + + lpips_loss = lpips.LPIPS(net='vgg').to(args.device).eval() + + clip, preprocess_clip = clip.load('ViT-B/32', device=args.device) + clip.eval() + + sam = sam_model_registry['vit_b'](checkpoint='/BS/var/work/segment-anything/sam_vit_b_01ec64.pth') + sam = sam.to(args.device).eval() + sam_predictor = SamPredictor(sam) + + with torch.no_grad(): + for imgs, labels in tqdm(image_val_loader, disable=not is_main_process(), desc='Processing images', leave=True): + imgs = imgs.to(args.device) + rec_imgs, _, = vae(imgs) + + # first level DWT + ll1, hs = dwt(imgs) + lh1, hl1, hh1 = hs[0][:, 0], hs[0][:, 1], hs[0][:, 2] + rec_ll1, rec_hs = dwt(rec_imgs) + rec_lh1, rec_hl1, rec_hh1 = rec_hs[0][:, 0], rec_hs[0][:, 1], rec_hs[0][:, 2] + + # preprocess for CLIP, which expects input of size (224, 224) + # more efficient than applying preprocess_clip() sample-by-sample, but introduce slight discrepancies + # due to differences between PIL and tensor-based Resize implementations in torchvision + imgs_clip = torch.clamp((imgs + 1) / 2, min=0, max=1) + imgs_clip = F.resize(imgs_clip, 224, F.InterpolationMode.BICUBIC) + imgs_clip = F.normalize(imgs_clip, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + rec_imgs_clip = torch.clamp((rec_imgs + 1) / 2, min=0, max=1) + rec_imgs_clip = F.resize(rec_imgs_clip, 224, F.InterpolationMode.BICUBIC) + rec_imgs_clip = F.normalize(rec_imgs_clip, mean=(0.48145466, 0.4578275, 0.40821073), 
std=(0.26862954, 0.26130258, 0.27577711)) + + # imgs_clip = torch.clamp((imgs + 1) / 2, min=0, max=1).cpu() + # imgs_clip = torch.stack([preprocess_clip(torchvision.transforms.ToPILImage()(img)) for img in imgs_clip]).to(args.device) + # rec_imgs_clip = torch.clamp((rec_imgs + 1) / 2, min=0, max=1).cpu() + # rec_imgs_clip = torch.stack([preprocess_clip(torchvision.transforms.ToPILImage()(img)) for img in rec_imgs_clip]).to(args.device) + + features_clip = clip.encode_image(imgs_clip) + rec_features_clip = clip.encode_image(rec_imgs_clip) + + # preprocess for SAM, which expects input of size (1024, 1024) + # slightly differs from apply_image(), which uses uint8 NumPy arrays + imgs_sam = torch.clamp((imgs + 1) / 2, min=0, max=1).mul_(255) + imgs_sam = sam_predictor.transform.apply_image_torch(imgs_sam) + rec_imgs_sam = torch.clamp((rec_imgs + 1) / 2, min=0, max=1).mul_(255) + rec_imgs_sam = sam_predictor.transform.apply_image_torch(rec_imgs_sam) + + sam_predictor.set_torch_image(imgs_sam, (256, 256)) + features_sam = sam_predictor.features + features_sam = features_sam.reshape(imgs.shape[0], -1) + sam_predictor.set_torch_image(rec_imgs_sam, (256, 256)) + rec_features_sam = sam_predictor.features + rec_features_sam = rec_features_sam.reshape(imgs.shape[0], -1) + + batch_losses = { + 'reconstruction': l2_loss(rec_imgs, imgs).item(), + 'low_frequency': l2_loss(rec_ll1, ll1).item(), + 'high_frequency': (l2_loss(rec_lh1, lh1).item() + l2_loss(rec_hl1, hl1).item() + l2_loss(rec_hh1, hh1).item()) / 3, + 'perceptual': lpips_loss(rec_imgs, imgs).mean().item(), + 'clip_semantic': 1 - nn.functional.cosine_similarity(features_clip, rec_features_clip).mean().item(), + 'sam_semantic': 1 - nn.functional.cosine_similarity(features_sam, rec_features_sam).mean().item(), + } + + for key in total_loss: + total_loss[key] += batch_losses[key] * imgs.shape[0] + + total_images += imgs.shape[0] + + if is_main_process() and len(visualize_imgs) < batch_imgs_to_save: + visualize_imgs.append(rec_imgs[:4].cpu()) + + # aggregate losses across all distributed processes + for key in total_loss: + t = torch.tensor(total_loss[key], device=args.device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM) + total_loss[key] = t.item() + + t = torch.tensor(total_images, device=args.device) + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM) + total_images = t.item() + + if is_main_process(): + save_dir = '/BS/var/work/analysis_figures' + + # visualize some reconstructed images + visualize_imgs = torch.cat(visualize_imgs, dim=0) + visualize_imgs = torch.clamp((visualize_imgs + 1) / 2, min=0, max=1) + visualize_imgs = torchvision.utils.make_grid(visualize_imgs, nrow=8, padding=0) + visualize_imgs = visualize_imgs.permute(1, 2, 0).mul_(255).numpy() + visualize_imgs = Image.fromarray(visualize_imgs.astype(np.uint8)) + visualize_imgs.save(f'{save_dir}/recon_kl-vae-wav-f16c16.png') + + # compute average loss per component + avg_loss = {key: total_loss[key] / total_images for key in total_loss} + + # save results + csv_path = f'{save_dir}/loss_metrics.csv' + + new_row = { + 'Model': 'KL-VAE-WAV-f16c16', + 'Dataset': 'ImageNet', + 'Reconstruction Loss': avg_loss['reconstruction'], + 'Low Frequency Loss': avg_loss['low_frequency'], + 'High Frequency Loss': avg_loss['high_frequency'], + 'Perceptual Loss': avg_loss['perceptual'], + 'CLIP Semantic Loss': avg_loss['clip_semantic'], + 'SAM Semantic Loss': avg_loss['sam_semantic'], + } + + if os.path.exists(csv_path): + df = pd.read_csv(csv_path) + df = 
pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) + else: + df = pd.DataFrame([new_row]) + + df.to_csv(csv_path, index=False) diff --git a/configs/autoencoder/autoencoder_kl-wav_16x16x16.yaml b/configs/autoencoder/autoencoder_kl-wav_16x16x16.yaml new file mode 100644 index 000000000..b4d68c7ad --- /dev/null +++ b/configs/autoencoder/autoencoder_kl-wav_16x16x16.yaml @@ -0,0 +1,54 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder_wav.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 16 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 16 + resolution: 64 + in_channels: 48 + out_ch: 48 + ch: 128 + ch_mult: [ 1,2,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 32 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 1 diff --git a/configs/autoencoder/autoencoder_kl_16x16x16.yaml b/configs/autoencoder/autoencoder_kl_16x16x16.yaml index 5f1d10ec7..b9e875df2 100644 --- a/configs/autoencoder/autoencoder_kl_16x16x16.yaml +++ b/configs/autoencoder/autoencoder_kl_16x16x16.yaml @@ -27,7 +27,7 @@ model: data: target: main.DataModuleFromConfig params: - batch_size: 12 + batch_size: 16 wrap: True train: target: ldm.data.imagenet.ImageNetSRTrain @@ -51,4 +51,4 @@ lightning: trainer: benchmark: True - accumulate_grad_batches: 2 + accumulate_grad_batches: 1 diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py index 1c473f9c6..14156c55f 100644 --- a/ldm/data/imagenet.py +++ b/ldm/data/imagenet.py @@ -148,21 +148,25 @@ def __init__(self, process_images=True, data_root=None, **kwargs): super().__init__(**kwargs) def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") + # if self.data_root: + # self.root = os.path.join(self.data_root, self.NAME) + # else: + # cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + # self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + # self.datadir = os.path.join(self.root, "data") + # self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.root = '/BS/var/nobackup/imagenet-1k/' + self.datadir = os.path.join(self.root, 'train') + self.txt_filelist = os.path.join(self.root, 'train-txt', 'filelist.txt') self.expected_length = 1281167 self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) - if not tdu.is_prepared(self.root): + + # if not tdu.is_prepared(self.root): + if not tdu.is_prepared(os.path.join(self.root, 'train-txt')): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) - + datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) @@ -183,7 +187,7 @@ def _prepare(self): 
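The frequency-band metric computed in analysis.py / analysis_wav.py above can be summarized in a small standalone helper. This is an illustrative sketch only (the function name and the explicit orientation indexing are mine, not part of the diff): a one-level Haar DWT splits each image into a low-pass band (LL) and three high-pass orientations (LH/HL/HH), and reconstruction MSE is reported separately for the two groups. (The scripts slice dim 1, the colour axis, rather than the orientation axis; averaged over the three slices the aggregate value comes out the same.)

import torch
import torch.nn.functional as F
from pytorch_wavelets import DWTForward

def frequency_band_mse(x: torch.Tensor, x_rec: torch.Tensor):
    # x, x_rec: (N, C, H, W) tensors in [-1, 1]
    dwt = DWTForward(J=1, wave='haar', mode='zero').to(x.device)
    ll, highs = dwt(x)              # ll: (N, C, H/2, W/2); highs[0]: (N, C, 3, H/2, W/2)
    rec_ll, rec_highs = dwt(x_rec)
    low = F.mse_loss(rec_ll, ll)
    # average MSE over the LH/HL/HH orientations (dim 2 of the high-pass tensor)
    high = sum(F.mse_loss(rec_highs[0][:, :, k], highs[0][:, :, k]) for k in range(3)) / 3
    return low.item(), high.item()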
os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) - + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) @@ -191,7 +195,8 @@ def _prepare(self): with open(self.txt_filelist, "w") as f: f.write(filelist) - tdu.mark_prepared(self.root) + # tdu.mark_prepared(self.root) + tdu.mark_prepared(os.path.join(self.root, 'train-txt')) class ImageNetValidation(ImageNetBase): @@ -214,17 +219,22 @@ def __init__(self, process_images=True, data_root=None, **kwargs): super().__init__(**kwargs) def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") + # if self.data_root: + # self.root = os.path.join(self.data_root, self.NAME) + # else: + # cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + # self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + # self.datadir = os.path.join(self.root, "data") + # self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.root = '/BS/var/nobackup/imagenet-1k/' + self.datadir = os.path.join(self.root, 'val') + self.txt_filelist = os.path.join(self.root, 'val-txt', 'filelist.txt') self.expected_length = 50000 self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) - if not tdu.is_prepared(self.root): + + # if not tdu.is_prepared(self.root): + if not tdu.is_prepared(os.path.join(self.root, 'val-txt')): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -265,8 +275,8 @@ def _prepare(self): with open(self.txt_filelist, "w") as f: f.write(filelist) - tdu.mark_prepared(self.root) - + # tdu.mark_prepared(self.root) + tdu.mark_prepared(os.path.join(self.root, 'val-txt')) class ImageNetSR(Dataset): diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py index 6a9c4f454..e85d77303 100644 --- a/ldm/models/autoencoder.py +++ b/ldm/models/autoencoder.py @@ -1,7 +1,12 @@ import torch +import torchvision import pytorch_lightning as pl import torch.nn.functional as F +import numpy as np +import wandb + from contextlib import contextmanager +from PIL import Image from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer @@ -305,10 +310,13 @@ def __init__(self, if colorize_nlabels is not None: assert type(colorize_nlabels)==int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + # if monitor is not None: + # self.monitor = monitor + # if ckpt_path is not None: + # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + # activates manual optimization for multiple optimizers + self.automatic_optimization = False def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu")["state_dict"] @@ -348,40 +356,80 @@ def get_input(self, batch, k): x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() return x - def training_step(self, batch, batch_idx, optimizer_idx): + def training_step(self, batch, batch_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - 
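The training_step hunk that continues below replaces the old optimizer_idx-based branches with Lightning manual optimization (note the `self.automatic_optimization = False` added in `__init__` above): each batch now steps the autoencoder and the discriminator explicitly, and, as the new comments note, two optimizer steps per iteration is why `self.global_step // 2` is handed to the loss so the discriminator warm-up threshold keeps its old meaning. A minimal, self-contained sketch of the pattern, using a hypothetical module with dummy losses rather than the repo's class:

import torch
import pytorch_lightning as pl

class ManualTwoOptimizerModule(pl.LightningModule):
    """Hypothetical minimal module illustrating manual optimization with two optimizers."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False      # step several optimizers by hand
        self.net = torch.nn.Linear(8, 8)         # stands in for encoder + decoder
        self.disc = torch.nn.Linear(8, 1)        # stands in for the discriminator

    def training_step(self, batch, batch_idx):
        opt_ae, opt_disc = self.optimizers()
        x = batch
        rec = self.net(x)

        # generator / autoencoder update
        ae_loss = torch.nn.functional.mse_loss(rec, x)
        opt_ae.zero_grad()
        self.manual_backward(ae_loss)
        opt_ae.step()

        # discriminator update on the same batch (detach so no generator gradients flow)
        disc_loss = torch.nn.functional.softplus(-self.disc(rec.detach())).mean()
        opt_disc.zero_grad()
        self.manual_backward(disc_loss)
        opt_disc.step()

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.net.parameters(), lr=1e-4, betas=(0.5, 0.9)),
            torch.optim.Adam(self.disc.parameters(), lr=1e-4, betas=(0.5, 0.9)),
        )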
if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss + opt_ae, opt_disc = self.optimizers() + + # train encoder+decoder+logvar + # adjust global step to match LR scheduler step: two optimizers -> two steps per iteration + ae_loss, log_dict_ae = self.loss( + inputs, reconstructions, posterior, 0, self.global_step // 2, + last_layer=self.get_last_layer(), split="train" + ) + + opt_ae.zero_grad() + self.manual_backward(ae_loss) + opt_ae.step() + + # train the discriminator + # adjust global step to match LR scheduler step: two optimizers -> two steps per iteration + disc_loss, log_dict_disc = self.loss( + inputs, reconstructions, posterior, 1, self.global_step // 2, + last_layer=self.get_last_layer(), split="train" + ) + + opt_disc.zero_grad() + self.manual_backward(disc_loss) + opt_disc.step() + + # log + for key, value in log_dict_ae.items(): + self.log(key, value, sync_dist=True) + for key, value in log_dict_disc.items(): + self.log(key, value, sync_dist=True) def validation_step(self, batch, batch_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") - - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict + + # VAE + aeloss, log_dict_ae = self.loss( + inputs, reconstructions, posterior, 0, self.global_step // 2, + last_layer=self.get_last_layer(), split="val", + ) + + # Discriminator + discloss, log_dict_disc = self.loss( + inputs, reconstructions, posterior, 1, self.global_step // 2, + last_layer=self.get_last_layer(), split="val", + ) + + # log + for key, value in log_dict_ae.items(): + self.log(key, value, sync_dist=True) + for key, value in log_dict_disc.items(): + self.log(key, value, sync_dist=True) + + # only plot one batch + if batch_idx == 0: + self.imgs = inputs + self.rec_imgs = reconstructions + self.posterior = posterior + + def on_validation_epoch_end(self): + # reference images + if self.current_epoch == 0: + imgs = self.get_image_grid(self.imgs) + self.logger.experiment.log({'val_image/orig_imgs': wandb.Image(imgs)}) + # reconstructed images + if self.current_epoch % 5 == 0: + rec_imgs = self.get_image_grid(self.rec_imgs) + sample_imgs = self.decode(torch.randn_like(self.posterior.sample())) + sample_imgs = self.get_image_grid(sample_imgs) + self.logger.experiment.log({'val_image/recon_imgs': wandb.Image(rec_imgs)}) + self.logger.experiment.log({'val_image/sample_imgs': wandb.Image(sample_imgs)}) def 
configure_optimizers(self): lr = self.learning_rate @@ -397,6 +445,15 @@ def configure_optimizers(self): def get_last_layer(self): return self.decoder.conv_out.weight + def get_image_grid(self, imgs: torch.Tensor, nrow: int = 8, padding: int = 0) -> Image.Image: + """ + Get a grid of images suitable for saving or visualization. + """ + imgs = torch.clamp((imgs.detach() + 1) / 2, min=0, max=1) # normalize to [0, 1] + imgs = torchvision.utils.make_grid(imgs, nrow=nrow, padding=padding) + imgs = imgs.permute(1, 2, 0).mul_(255).cpu().numpy() + return Image.fromarray(imgs.astype(np.uint8)) + @torch.no_grad() def log_images(self, batch, only_inputs=False, **kwargs): log = dict() diff --git a/ldm/models/autoencoder_wav.py b/ldm/models/autoencoder_wav.py new file mode 100644 index 000000000..5f6bdc28e --- /dev/null +++ b/ldm/models/autoencoder_wav.py @@ -0,0 +1,600 @@ +import torch +import torch.nn as nn +import torchvision +import pytorch_lightning as pl +import torch.nn.functional as F +import numpy as np +import wandb + +from contextlib import contextmanager +from PIL import Image +from pytorch_wavelets import DWTForward, DWTInverse + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + 
for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, 
on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + # if monitor is not None: + # self.monitor = monitor + # if ckpt_path is not None: + # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + # wavelet decomposition + self.dwt = DWTForward(J=2, wave='haar', mode='zero') + self.idwt = DWTInverse(wave='haar', mode='zero') + + self.downsample_wav = DownsampleWav(in_channels=9, out_channels=36) # (9, 128, 128) -> (36, 64, 64) + self.upsample_wav = UpsampleWav(in_channels=36, out_channels=9) # (36, 64, 64) -> (9, 128, 128) + + # activates manual optimization for multiple optimizers + self.automatic_optimization = False + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode_wav(self, x): + ll2, (h1, h2) = self.dwt(x) + h1 = h1.view(h1.shape[0], -1, h1.shape[3], h1.shape[4]) # -> (N, C * 3, H, W) + h2 = h2.view(h2.shape[0], -1, h2.shape[3], h2.shape[4]) # -> (N, C * 3, H, W) + h1, h2, ll2 = h1 / 2, h2 / 4, ll2 / 4 # normalization + h1 = self.downsample_wav(h1) + coeff = torch.cat([h1, h2, ll2], dim=1) + return coeff + + def decode_wav(self, x): + rec_h1, rec_h2, rec_ll2 = x[:, 0:36], x[:, 36:45], x[:, 45:48] + rec_h1 = self.upsample_wav(rec_h1) + rec_h1, rec_h2, rec_ll2 = rec_h1 * 2, rec_h2 * 4, rec_ll2 * 4 # denormalization + rec_h1 = rec_h1.view(rec_h1.shape[0], 3, 3, rec_h1.shape[2], rec_h1.shape[3]) # -> (N, C, 3, H, W) + rec_h2 = rec_h2.view(rec_h2.shape[0], 3, 3, rec_h2.shape[2], rec_h2.shape[3]) # -> (N, C, 3, H, W) + rec_img = self.idwt((rec_ll2, [rec_h1, rec_h2])) + return rec_img + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + coeff = self.encode_wav(input) + posterior = self.encode(coeff) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + rec_img = self.decode_wav(dec) + return 
rec_img, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + opt_ae, opt_disc = self.optimizers() + + # train encoder+decoder+logvar + # adjust global step to match LR scheduler step: two optimizers -> two steps per iteration + ae_loss, log_dict_ae = self.loss( + inputs, reconstructions, posterior, 0, self.global_step // 2, + last_layer=self.get_last_layer(), split="train" + ) + + opt_ae.zero_grad() + self.manual_backward(ae_loss) + opt_ae.step() + + # train the discriminator + # adjust global step to match LR scheduler step: two optimizers -> two steps per iteration + disc_loss, log_dict_disc = self.loss( + inputs, reconstructions, posterior, 1, self.global_step // 2, + last_layer=self.get_last_layer(), split="train" + ) + + opt_disc.zero_grad() + self.manual_backward(disc_loss) + opt_disc.step() + + # log + for key, value in log_dict_ae.items(): + self.log(key, value, sync_dist=True) + for key, value in log_dict_disc.items(): + self.log(key, value, sync_dist=True) + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + # VAE + aeloss, log_dict_ae = self.loss( + inputs, reconstructions, posterior, 0, self.global_step // 2, + last_layer=self.get_last_layer(), split="val", + ) + + # Discriminator + discloss, log_dict_disc = self.loss( + inputs, reconstructions, posterior, 1, self.global_step // 2, + last_layer=self.get_last_layer(), split="val", + ) + + # log + for key, value in log_dict_ae.items(): + self.log(key, value, sync_dist=True) + for key, value in log_dict_disc.items(): + self.log(key, value, sync_dist=True) + + # only plot one batch + if batch_idx == 0: + self.imgs = inputs + ll2, (h1, h2) = self.dwt(inputs) + h1, h2, ll2 = h1 / 2, h2 / 4, ll2 / 4 # normalization + self.lh1 = h1[:, 0] + self.hl1 = h1[:, 1] + self.hh1 = h1[:, 2] + self.ll2 = ll2 + self.lh2 = h2[:, 0] + self.hl2 = h2[:, 1] + self.hh2 = h2[:, 2] + self.rec_imgs = reconstructions + rec_ll2, (rec_h1, rec_h2) = self.dwt(reconstructions) + rec_h1, rec_h2, rec_ll2 = rec_h1 / 2, rec_h2 / 4, rec_ll2 / 4 # normalization + self.rec_lh1 = rec_h1[:, 0] + self.rec_hl1 = rec_h1[:, 1] + self.rec_hh1 = rec_h1[:, 2] + self.rec_ll2 = rec_ll2 + self.rec_lh2 = rec_h2[:, 0] + self.rec_hl2 = rec_h2[:, 1] + self.rec_hh2 = rec_h2[:, 2] + self.posterior = posterior + + def on_validation_epoch_end(self): + # reference images + if self.current_epoch == 0: + imgs = self.get_image_grid(self.imgs) + lh1 = self.get_image_grid(self.lh1) + hl1 = self.get_image_grid(self.hl1) + hh1 = self.get_image_grid(self.hh1) + ll2 = self.get_image_grid(self.ll2) + lh2 = self.get_image_grid(self.lh2) + hl2 = self.get_image_grid(self.hl2) + hh2 = self.get_image_grid(self.hh2) + self.logger.experiment.log({'val_image/orig_imgs': wandb.Image(imgs)}) + self.logger.experiment.log({'val_image/orig_lh1': wandb.Image(lh1)}) + self.logger.experiment.log({'val_image/orig_hl1': wandb.Image(hl1)}) + self.logger.experiment.log({'val_image/orig_hh1': wandb.Image(hh1)}) + self.logger.experiment.log({'val_image/orig_ll2': wandb.Image(ll2)}) + self.logger.experiment.log({'val_image/orig_lh2': wandb.Image(lh2)}) + self.logger.experiment.log({'val_image/orig_hl2': wandb.Image(hl2)}) + 
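For reference, the channel bookkeeping behind encode_wav / decode_wav above: a two-level Haar DWT of a 3x256x256 image gives LL2 of shape (3, 64, 64), level-2 high-pass of shape (3, 3, 64, 64) flattened to 9 channels, and level-1 high-pass of shape (3, 3, 128, 128) flattened to 9 channels, which DownsampleWav maps to 36 channels at 64x64; concatenated, that is the 48x64x64 signal declared in autoencoder_kl-wav_16x16x16.yaml (in_channels: 48, resolution: 64). The /2 and /4 factors rescale each level's coefficients roughly back into the input range and are undone in decode_wav. A minimal round-trip sketch of the packing, without the learned 9 -> 36 down/upsampling:

import torch
from pytorch_wavelets import DWTForward, DWTInverse

dwt = DWTForward(J=2, wave='haar', mode='zero')
idwt = DWTInverse(wave='haar', mode='zero')

x = torch.randn(2, 3, 256, 256)
ll2, (h1, h2) = dwt(x)
print(ll2.shape, h1.shape, h2.shape)   # (2,3,64,64) (2,3,3,128,128) (2,3,3,64,64)

# flatten the orientation axis into channels, as in encode_wav
h1_flat = h1.reshape(2, 9, 128, 128)   # level-1 high-pass -> 9 channels at 128x128
h2_flat = h2.reshape(2, 9, 64, 64)     # level-2 high-pass -> 9 channels at 64x64
# ll2 (3 ch) + h2_flat (9 ch) + downsampled h1 (36 ch) give the 48 channels in the config

# undo the flattening and reconstruct; the Haar DWT/IDWT pair is (numerically) invertible
rec = idwt((ll2, [h1_flat.view(2, 3, 3, 128, 128), h2_flat.view(2, 3, 3, 64, 64)]))
print(torch.allclose(rec, x, atol=1e-5))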
self.logger.experiment.log({'val_image/orig_hh2': wandb.Image(hh2)}) + # reconstructed images + if self.current_epoch % 5 == 0: + rec_imgs = self.get_image_grid(self.rec_imgs) + rec_lh1 = self.get_image_grid(self.rec_lh1) + rec_hl1 = self.get_image_grid(self.rec_hl1) + rec_hh1 = self.get_image_grid(self.rec_hh1) + rec_ll2 = self.get_image_grid(self.rec_ll2) + rec_lh2 = self.get_image_grid(self.rec_lh2) + rec_hl2 = self.get_image_grid(self.rec_hl2) + rec_hh2 = self.get_image_grid(self.rec_hh2) + dec = self.decode(torch.randn_like(self.posterior.sample())) + sample_imgs = self.decode_wav(dec) + sample_imgs = self.get_image_grid(sample_imgs) + self.logger.experiment.log({'val_image/recon_imgs': wandb.Image(rec_imgs)}) + self.logger.experiment.log({'val_image/recon_lh1': wandb.Image(rec_lh1)}) + self.logger.experiment.log({'val_image/recon_hl1': wandb.Image(rec_hl1)}) + self.logger.experiment.log({'val_image/recon_hh1': wandb.Image(rec_hh1)}) + self.logger.experiment.log({'val_image/recon_ll2': wandb.Image(rec_ll2)}) + self.logger.experiment.log({'val_image/recon_lh2': wandb.Image(rec_lh2)}) + self.logger.experiment.log({'val_image/recon_hl2': wandb.Image(rec_hl2)}) + self.logger.experiment.log({'val_image/recon_hh2': wandb.Image(rec_hh2)}) + self.logger.experiment.log({'val_image/sample_imgs': wandb.Image(sample_imgs)}) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def get_image_grid(self, imgs: torch.Tensor, nrow: int = 8, padding: int = 0) -> Image.Image: + """ + Get a grid of images suitable for saving or visualization. + """ + imgs = torch.clamp((imgs.detach() + 1) / 2, min=0, max=1) # normalize to [0, 1] + imgs = torchvision.utils.make_grid(imgs, nrow=nrow, padding=padding) + imgs = imgs.permute(1, 2, 0).mul_(255).cpu().numpy() + return Image.fromarray(imgs.astype(np.uint8)) + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class UpsampleWav(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.upconv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) + self.channel_reduce = nn.Conv2d(in_channels, out_channels, kernel_size=1) # Reduce channels 36 → 9 + + def forward(self, x): + x = self.upconv(x) # Upsample spatially (64x64 → 128x128) + x = self.channel_reduce(x) # Reduce channels (36 → 9) + return x + + +class DownsampleWav(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1) # Proper downsampling + self.channel_expand = nn.Conv2d(in_channels, out_channels, kernel_size=1) # Adjust channels + + def forward(self, x): + x = self.conv(x) # Downsample spatially (128x128 → 64x64) + x = self.channel_expand(x) # Adjust channels (9 → 36) + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py index 672c1e32a..70b07980a 100644 --- a/ldm/modules/losses/contperceptual.py +++ b/ldm/modules/losses/contperceptual.py @@ -53,7 +53,7 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: - weighted_nll_loss = weights*nll_loss + weighted_nll_loss = weights * nll_loss weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() @@ -75,20 +75,23 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: assert not self.training - d_weight = torch.tensor(0.0) + d_weight = 0.0 # torch.tensor(0.0) else: - d_weight = torch.tensor(0.0) + d_weight = 0.0 # torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight, + "{}/disc_factor".format(split): disc_factor, + "{}/g_loss".format(split): g_loss.detach().mean(), + } return loss, log if optimizer_idx 
== 1: @@ -103,9 +106,9 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log - + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log \ No newline at end of file diff --git a/main.py b/main.py index e8e18c18f..3bd13c846 100644 --- a/main.py +++ b/main.py @@ -11,14 +11,16 @@ from functools import partial from PIL import Image -from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.loggers.wandb import WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities import rank_zero_info from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config +from ldm.models.autoencoder import AutoencoderKL def get_parser(**parser_kwargs): @@ -38,16 +40,16 @@ def str2bool(v): "--name", type=str, const=True, - default="", + default="stage-1", nargs="?", - help="postfix for logdir", + help="name of the run", ) parser.add_argument( "-r", "--resume", type=str, const=True, - default="", + default=None, nargs="?", help="resume from logdir or checkpoint in logdir", ) @@ -80,7 +82,9 @@ def str2bool(v): parser.add_argument( "-p", "--project", - help="name of new or path to existing project" + type=str, + default="vae", + help="name of the project to which this run will belong" ) parser.add_argument( "-d", @@ -120,6 +124,24 @@ def str2bool(v): default=True, help="scale base-lr by ngpu * batch_size * n_accumulate", ) + parser.add_argument( + "--num_nodes", + type=int, + default=1, + help="number of nodes for distributed training", + ) + parser.add_argument( + "--wandb_id", + type=str, + default=None, + help="run ID resuming logging of a model", + ) + parser.add_argument( + "--max_ep", + type=int, + default=200, + help="number of training epochs", + ) return parser @@ -166,7 +188,7 @@ def __init__(self, batch_size, train=None, validation=None, test=None, predict=N super().__init__() self.batch_size = batch_size self.dataset_configs = dict() - self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.num_workers = 4 # num_workers if num_workers is not None else batch_size * 2 self.use_worker_init_fn = use_worker_init_fn if train is not None: self.dataset_configs["train"] = train @@ -415,48 +437,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): pass -if __name__ == "__main__": - # custom parser to specify config files, train, test and debug mode, - # postfix, resume. - # `--key value` arguments are interpreted as arguments to the trainer. - # `nested.key=value` arguments are interpreted as config parameters. - # configs are merged from left-to-right followed by command line parameters. 
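Context for the disc_factor logic in the LPIPSWithDiscriminator hunk above: adopt_weight (imported from taming) gates the adversarial term until global_step reaches discriminator_iter_start (disc_start: 50001 in the new wav config), and because the manually optimized training_step now passes self.global_step // 2, the warm-up still corresponds to roughly 50k training iterations. A sketch of the gating, assuming taming's usual implementation:

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # return `value` (i.e. disable the term) until `threshold` steps have passed
    return value if global_step < threshold else weight

# illustrative only: with disc_start = 50001 the generator sees no adversarial
# gradient for the first ~50k iterations, then the full weight is applied
disc_factor = adopt_weight(1.0, global_step=40_000, threshold=50_001)   # -> 0.0
disc_factor = adopt_weight(1.0, global_step=60_000, threshold=50_001)   # -> 1.0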
- - # model: - # base_learning_rate: float - # target: path to lightning module - # params: - # key: value - # data: - # target: main.DataModuleFromConfig - # params: - # batch_size: int - # wrap: bool - # train: - # target: path to train dataset - # params: - # key: value - # validation: - # target: path to validation dataset - # params: - # key: value - # test: - # target: path to test dataset - # params: - # key: value - # lightning: (optional, has sane defaults and can be specified on cmdline) - # trainer: - # additional arguments to trainer - # logger: - # logger to instantiate - # modelcheckpoint: - # modelcheckpoint to instantiate - # callbacks: - # callback1: - # target: importpath - # params: - # key: value - +if __name__ == '__main__': now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # add cwd for convenience and to make classes in this file available when @@ -465,277 +446,107 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): sys.path.append(os.getcwd()) parser = get_parser() - parser = Trainer.add_argparse_args(parser) - opt, unknown = parser.parse_known_args() - if opt.name and opt.resume: - raise ValueError( - "-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint" + + # distributed training parameters + gpus = torch.cuda.device_count() + num_nodes = opt.num_nodes + rank = int(os.getenv('NODE_RANK')) if os.getenv('NODE_RANK') is not None else 0 + print(f'[INFO] Number of GPUs available: {gpus}') + print(f'[INFO] Number of nodes: {num_nodes}') + + pl.seed_everything(opt.seed, workers=True) + + # logging, checkpointing and resuming + save_ckpt_dir = f'{opt.logdir}/{opt.name}' + resume = opt.resume is not None + + if rank == 0: # prevents from logging multiple times + logger = WandbLogger( + project=opt.project, name=opt.name, id=opt.wandb_id, + offline=False, resume='must' if resume else None, ) - if opt.resume: - if not os.path.exists(opt.resume): - raise ValueError("Cannot find {}".format(opt.resume)) - if os.path.isfile(opt.resume): - paths = opt.resume.split("/") - # idx = len(paths)-paths[::-1].index("logs")+1 - # logdir = "/".join(paths[:idx]) - logdir = "/".join(paths[:-2]) - ckpt = opt.resume - else: - assert os.path.isdir(opt.resume), opt.resume - logdir = opt.resume.rstrip("/") - ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") - - opt.resume_from_checkpoint = ckpt - base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) - opt.base = base_configs + opt.base - _tmp = logdir.split("/") - nowname = _tmp[-1] else: - if opt.name: - name = "_" + opt.name - elif opt.base: - cfg_fname = os.path.split(opt.base[0])[-1] - cfg_name = os.path.splitext(cfg_fname)[0] - name = "_" + cfg_name - else: - name = "" - nowname = now + name + opt.postfix - logdir = os.path.join(opt.logdir, nowname) - - ckptdir = os.path.join(logdir, "checkpoints") - cfgdir = os.path.join(logdir, "configs") - seed_everything(opt.seed) - - try: - # init and save configs - configs = [OmegaConf.load(cfg) for cfg in opt.base] - cli = OmegaConf.from_dotlist(unknown) - config = OmegaConf.merge(*configs, cli) - lightning_config = config.pop("lightning", OmegaConf.create()) - # merge trainer cli with config - trainer_config = lightning_config.get("trainer", OmegaConf.create()) - # default to ddp - trainer_config["accelerator"] = "ddp" - for k in nondefault_trainer_args(opt): - trainer_config[k] = getattr(opt, k) - if not "gpus" in trainer_config: - del 
trainer_config["accelerator"] - cpu = True - else: - gpuinfo = trainer_config["gpus"] - print(f"Running on GPUs {gpuinfo}") - cpu = False - trainer_opt = argparse.Namespace(**trainer_config) - lightning_config.trainer = trainer_config - - # model - model = instantiate_from_config(config.model) - - # trainer and callbacks - trainer_kwargs = dict() - - # default logger configs - default_logger_cfgs = { - "wandb": { - "target": "pytorch_lightning.loggers.WandbLogger", - "params": { - "name": nowname, - "save_dir": logdir, - "offline": opt.debug, - "id": nowname, - } - }, - "testtube": { - "target": "pytorch_lightning.loggers.TestTubeLogger", - "params": { - "name": "testtube", - "save_dir": logdir, - } - }, - } - default_logger_cfg = default_logger_cfgs["testtube"] - if "logger" in lightning_config: - logger_cfg = lightning_config.logger - else: - logger_cfg = OmegaConf.create() - logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) - - # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to - # specify which metric is used to determine best models - default_modelckpt_cfg = { - "target": "pytorch_lightning.callbacks.ModelCheckpoint", - "params": { - "dirpath": ckptdir, - "filename": "{epoch:06}", - "verbose": True, - "save_last": True, - } - } - if hasattr(model, "monitor"): - print(f"Monitoring {model.monitor} as checkpoint metric.") - default_modelckpt_cfg["params"]["monitor"] = model.monitor - default_modelckpt_cfg["params"]["save_top_k"] = 3 - - if "modelcheckpoint" in lightning_config: - modelckpt_cfg = lightning_config.modelcheckpoint - else: - modelckpt_cfg = OmegaConf.create() - modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") - if version.parse(pl.__version__) < version.parse('1.4.0'): - trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) - - # add callback which sets up log directory - default_callbacks_cfg = { - "setup_callback": { - "target": "main.SetupCallback", - "params": { - "resume": opt.resume, - "now": now, - "logdir": logdir, - "ckptdir": ckptdir, - "cfgdir": cfgdir, - "config": config, - "lightning_config": lightning_config, - } - }, - "image_logger": { - "target": "main.ImageLogger", - "params": { - "batch_frequency": 750, - "max_images": 4, - "clamp": True - } - }, - "learning_rate_logger": { - "target": "main.LearningRateMonitor", - "params": { - "logging_interval": "step", - # "log_momentum": True - } - }, - "cuda_callback": { - "target": "main.CUDACallback" - }, - } - if version.parse(pl.__version__) >= version.parse('1.4.0'): - default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) - - if "callbacks" in lightning_config: - callbacks_cfg = lightning_config.callbacks - else: - callbacks_cfg = OmegaConf.create() - - if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: - print( - 'Caution: Saving checkpoints every n train steps without deleting. 
This might require some free space.') - default_metrics_over_trainsteps_ckpt_dict = { - 'metrics_over_trainsteps_checkpoint': - {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', - 'params': { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), - "filename": "{epoch:06}-{step:09}", - "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } - } - } - default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) - - callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) - if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): - callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint - elif 'ignore_keys_callback' in callbacks_cfg: - del callbacks_cfg['ignore_keys_callback'] - - trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] - - trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) - trainer.logdir = logdir ### - - # data - data = instantiate_from_config(config.data) - # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html - # calling these ourselves should not be necessary but it is. - # lightning still takes care of proper multiprocessing though - data.prepare_data() - data.setup() - print("#### Data #####") - for k in data.datasets: - print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") - - # configure learning rate - bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate - if not cpu: - ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) - else: - ngpu = 1 - if 'accumulate_grad_batches' in lightning_config.trainer: - accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches - else: - accumulate_grad_batches = 1 - print(f"accumulate_grad_batches = {accumulate_grad_batches}") - lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches - if opt.scale_lr: - model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr - print( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( - model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) - else: - model.learning_rate = base_lr - print("++++ NOT USING LR SCALING ++++") - print(f"Setting learning rate to {model.learning_rate:.2e}") - - - # allow checkpointing via USR1 - def melk(*args, **kwargs): - # run all checkpoint hooks - if trainer.global_rank == 0: - print("Summoning checkpoint.") - ckpt_path = os.path.join(ckptdir, "last.ckpt") - trainer.save_checkpoint(ckpt_path) - - - def divein(*args, **kwargs): - if trainer.global_rank == 0: - import pudb; - pudb.set_trace() - - - import signal + logger = WandbLogger( + project=opt.project, name=opt.name, + offline=True, + ) + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + # model + if resume: + model = AutoencoderKL.load_from_checkpoint( + opt.resume, + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) + print(f'[INFO] Loaded from {opt.resume}') + else: + model = AutoencoderKL( + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) - signal.signal(signal.SIGUSR1, melk) - signal.signal(signal.SIGUSR2, divein) + # data + 
data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # inspect the first batch + train_loader = data.train_dataloader() + batch = next(iter(train_loader)) + x = batch["image"] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + print(f'[INFO] Batch images shape: {x.shape}, max: {x.max()}, min: {x.min()}') + x = torchvision.utils.make_grid((x + 1) / 2, nrow=8, padding=0) + x = x.permute(1, 2, 0).mul_(255).cpu().numpy() + x = Image.fromarray(x.astype(np.uint8)) + x.save('ref.png') + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + accumulate_bs = bs * num_nodes * gpus + accumulate_grad_batches = config.lightning.trainer.accumulate_grad_batches + lr = accumulate_grad_batches * accumulate_bs * base_lr + model.learning_rate = lr + + print(f'[INFO] Accumulate_grad_batches = {accumulate_grad_batches}') + print(f'[INFO] Cumulative batch size: {accumulate_bs}') + print(f'[INFO] Final learning rate: {lr:.2e}') + + # callbacks + checkpoint_callback = ModelCheckpoint( + dirpath=save_ckpt_dir, + filename='{epoch:03d}', + save_last=True, + save_top_k=1, + monitor=config.model.params.monitor, + mode='min', + every_n_epochs=1, + ) + callbacks = [LearningRateMonitor(logging_interval='step'), checkpoint_callback] + + # PyTorch Lightning Trainer + trainer = pl.Trainer( + accelerator='gpu', num_nodes=num_nodes, devices=gpus, + precision='32', deterministic=True, + callbacks=callbacks, logger=logger, + strategy=DDPStrategy(find_unused_parameters=True), + accumulate_grad_batches=accumulate_grad_batches, + max_epochs=opt.max_ep, + ) - # run - if opt.train: - try: - trainer.fit(model, data) - except Exception: - melk() - raise - if not opt.no_test and not trainer.interrupted: - trainer.test(model, data) - except Exception: - if opt.debug and trainer.global_rank == 0: - try: - import pudb as debugger - except ImportError: - import pdb as debugger - debugger.post_mortem() - raise - finally: - # move newly created debug project to debug_runs - if opt.debug and not opt.resume and trainer.global_rank == 0: - dst, name = os.path.split(logdir) - dst = os.path.join(dst, "debug_runs", name) - os.makedirs(os.path.split(dst)[0], exist_ok=True) - os.rename(logdir, dst) - if trainer.global_rank == 0: - print(trainer.profiler.summary()) + # train the model + trainer.fit(model, data, ckpt_path=opt.resume) + \ No newline at end of file diff --git a/main_wav.py b/main_wav.py new file mode 100644 index 000000000..5f589eca1 --- /dev/null +++ b/main_wav.py @@ -0,0 +1,552 @@ +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.loggers.wandb import WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from 
pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config +from ldm.models.autoencoder_wav import AutoencoderKL + + +def get_parser(**parser_kwargs): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="stage-1", + nargs="?", + help="name of the run", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default=None, + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "-p", + "--project", + type=str, + default="vae", + help="name of the project to which this run will belong" + ) + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--num_nodes", + type=int, + default=1, + help="number of nodes for distributed training", + ) + parser.add_argument( + "--wandb_id", + type=str, + default=None, + help="run ID resuming logging of a model", + ) + parser.add_argument( + "--max_ep", + type=int, + default=200, + help="number of training epochs", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = 
np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, + wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = 4 # num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict( + (k, instantiate_from_config(self.dataset_configs[k])) + for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["train"], batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoader(self.datasets["test"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["predict"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn) + + +class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if 
trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_pretrain_routine_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.TestTubeLogger: self._testtube, + } + self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + pl_module.logger.experiment.add_image( + tag, grid, + global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, + global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( + k, + global_step, + current_epoch, + batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, 
split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module, outputs): + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == '__main__': + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + + # distributed training parameters + gpus = torch.cuda.device_count() + num_nodes = opt.num_nodes + rank = int(os.getenv('NODE_RANK')) if os.getenv('NODE_RANK') is not None else 0 + print(f'[INFO] Number of GPUs available: {gpus}') + print(f'[INFO] Number of nodes: {num_nodes}') + + pl.seed_everything(opt.seed, workers=True) + + # logging, checkpointing and resuming + save_ckpt_dir = f'{opt.logdir}/{opt.name}' + resume = opt.resume is not None + + if rank == 0: # prevents from logging multiple times + logger = WandbLogger( + project=opt.project, name=opt.name, id=opt.wandb_id, + offline=False, resume='must' if resume else None, + ) + else: + logger = WandbLogger( + project=opt.project, name=opt.name, + offline=True, + ) + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = 
OmegaConf.merge(*configs, cli) + + # model + if resume: + model = AutoencoderKL.load_from_checkpoint( + opt.resume, + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) + print(f'[INFO] Loaded from {opt.resume}') + else: + model = AutoencoderKL( + ddconfig=config.model.params.ddconfig, + lossconfig=config.model.params.lossconfig, + embed_dim=config.model.params.embed_dim, + ) + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # inspect the first batch + train_loader = data.train_dataloader() + batch = next(iter(train_loader)) + x = batch["image"] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + print(f'[INFO] Batch images shape: {x.shape}, max: {x.max()}, min: {x.min()}') + x = torchvision.utils.make_grid((x + 1) / 2, nrow=8, padding=0) + x = x.permute(1, 2, 0).mul_(255).cpu().numpy() + x = Image.fromarray(x.astype(np.uint8)) + x.save('ref.png') + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + accumulate_bs = bs * num_nodes * gpus + accumulate_grad_batches = config.lightning.trainer.accumulate_grad_batches + lr = accumulate_grad_batches * accumulate_bs * base_lr + model.learning_rate = lr + + print(f'[INFO] Accumulate_grad_batches = {accumulate_grad_batches}') + print(f'[INFO] Cumulative batch size: {accumulate_bs}') + print(f'[INFO] Final learning rate: {lr:.2e}') + + # callbacks + checkpoint_callback = ModelCheckpoint( + dirpath=save_ckpt_dir, + filename='{epoch:03d}', + save_last=True, + save_top_k=1, + monitor=config.model.params.monitor, + mode='min', + every_n_epochs=1, + ) + callbacks = [LearningRateMonitor(logging_interval='step'), checkpoint_callback] + + # PyTorch Lightning Trainer + trainer = pl.Trainer( + accelerator='gpu', num_nodes=num_nodes, devices=gpus, + precision='32', deterministic=True, + callbacks=callbacks, logger=logger, + strategy=DDPStrategy(find_unused_parameters=True), + accumulate_grad_batches=accumulate_grad_batches, + max_epochs=opt.max_ep, + ) + + # train the model + trainer.fit(model, data, ckpt_path=opt.resume) + \ No newline at end of file diff --git a/run_vae.text b/run_vae.text new file mode 100644 index 000000000..8a144eb7d --- /dev/null +++ b/run_vae.text @@ -0,0 +1 @@ +python3 main.py --base configs/autoencoder/autoencoder_kl_16x16x16.yaml diff --git a/slurm_scripts/run_analysis.sh b/slurm_scripts/run_analysis.sh new file mode 100644 index 000000000..3943d855f --- /dev/null +++ b/slurm_scripts/run_analysis.sh @@ -0,0 +1,25 @@ +#!/bin/bash +#SBATCH -p gpu17 +#SBATCH --gres gpu:2 +#SBATCH --nodes 2 +#SBATCH --ntasks-per-node 2 +#SBATCH -t 0-02:00:0 + +# activate Conda environment +source ~/.bashrc +conda activate ldm-new + +cd .. 
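+# run from the repository root so that analysis.py and its relative config/checkpoint paths resolve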
+ +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=29500 + +# print start time +echo $(date) +echo $SLURM_NTASKS +echo $MASTER_ADDR +echo $MASTER_PORT + +# start command +srun python3 analysis.py +srun python3 analysis_wav.py diff --git a/slurm_scripts/run_train_vae.sh b/slurm_scripts/run_train_vae.sh new file mode 100644 index 000000000..34e1211ce --- /dev/null +++ b/slurm_scripts/run_train_vae.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH -p gpu17 +#SBATCH --gres gpu:2 +#SBATCH --nodes 8 +#SBATCH --ntasks-per-node 2 +#SBATCH -t 0-08:00:0 +#SBATCH -o train_vae_output_%j.log +#SBATCH -e train_vae_error_%j.log + +# activate Conda environment +source ~/.bashrc +conda activate ldm-new + +cd .. + +cmd="python3 main.py --base configs/autoencoder/autoencoder_kl_16x16x16.yaml \ + --num_nodes $SLURM_JOB_NUM_NODES --max_ep 120 \ + --name ch1-1-2-2-4_baseline \ + --resume logs/ch1-1-2-2-4_baseline/last.ckpt --wandb_id zrnpbzxc" + +# print start time and command to log +echo $(date) +echo $cmd + +# start command +srun $cmd diff --git a/slurm_scripts/run_train_vae_wav.sh b/slurm_scripts/run_train_vae_wav.sh new file mode 100644 index 000000000..fd744ef4c --- /dev/null +++ b/slurm_scripts/run_train_vae_wav.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH -p gpu17 +#SBATCH --gres gpu:2 +#SBATCH --nodes 8 +#SBATCH --ntasks-per-node 2 +#SBATCH -t 0-08:00:0 +#SBATCH -o train_vae_wav_output_%j.log +#SBATCH -e train_vae_wav_error_%j.log + +# activate Conda environment +source ~/.bashrc +conda activate ldm-new + +cd .. + +cmd="python3 main_wav.py --base configs/autoencoder/autoencoder_kl-wav_16x16x16.yaml \ + --num_nodes $SLURM_JOB_NUM_NODES --max_ep 150 \ + --name ch1-2-4_baseline \ + --resume logs/ch1-2-4_baseline/last.ckpt --wandb_id zkrbwrm6" + +# print start time and command to log +echo $(date) +echo $cmd + +# start command +srun $cmd
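+
+# note: submit this script from slurm_scripts/ (e.g. `sbatch run_train_vae_wav.sh`); drop --resume/--wandb_id to start a fresh run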