diff --git a/.gitignore b/.gitignore index 33839e8..0273b1b 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,5 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ diff --git a/probaforms/models/__init__.py b/probaforms/models/__init__.py index ea94bc8..642b040 100644 --- a/probaforms/models/__init__.py +++ b/probaforms/models/__init__.py @@ -1,10 +1,12 @@ from .realnvp import RealNVP from .cvae import CVAE from .wgan import ConditionalWGAN +from .residual import ResidualFlow __all__ = [ 'RealNVP', 'CVAE', - 'ConditionalWGAN' + 'ConditionalWGAN', + 'ResidualFlow', ] diff --git a/probaforms/models/residual/__init__.py b/probaforms/models/residual/__init__.py new file mode 100644 index 0000000..e873f2e --- /dev/null +++ b/probaforms/models/residual/__init__.py @@ -0,0 +1,8 @@ +''' +"Residual Flows for Invertible Generative Modeling" (arxiv.org/abs/1906.02735) +Realization of (un-)conditional Residual Flow for tabular data +Code based on github.com/tatsy/normalizing-flows-pytorch +Conditioning idea: "Learning Likelihoods with Conditional Normalizing Flows" arxiv.org/abs/1912.00042) +''' + +from .model import ResidualFlow diff --git a/probaforms/models/residual/gradients.py b/probaforms/models/residual/gradients.py new file mode 100644 index 0000000..462e561 --- /dev/null +++ b/probaforms/models/residual/gradients.py @@ -0,0 +1,274 @@ +import torch +from torch import randn as rand_normal +import numpy as np + + +# ==================== logdet estimators for residual flow ======================== + +def logdet_Jg_exact(g, x): + """ + Exact logdet determinant computation (naive forehead approach) + Args: + g: outputs g(x) + x: inputs to g function (optimized network) + Returns: log(I + Jg(x)), where Jg(x) is the Jacobian defined as dg(x) / dx + """ + + var_dim = g.shape[1] + + Jg = [ + torch.autograd.grad(g[:, i].sum(), x, create_graph=True, retain_graph=True)[0] + for i in range(x.size(1)) + ] + + Jg = torch.stack(Jg, dim=1) + ident = torch.eye(x.size(1)).type_as(x).to(x.device) + return torch.logdet(ident + Jg) + + +def logdet_Jg_cutoff(g, x, n_samples=1, n_power_series=8): + """ + Biased logdet estimator with FIXED (!) number of trace's series terms, see paper, eq. (7) + Skilling-Hutchinson trace estimator is used to estimate the trace of Jacobian matrices + + + Unfortunately, this estimator requires each term to be stored in memory because ∂/∂θ needs to be + applied to each term. The total memory cost is then O(n · m) where n is the number of computed + terms and m is the number of residual blocks in the entire network. This is extremely memory-hungry + during training, and a large random sample of n can occasionally result in running out of memory + + Args: + g: outputs g(x) + x: inputs to g function (optimized network) + n_samples: number of v samples + n_power_series: fixed number of computed terms, param n in paper + Returns: log determinant approximation using FIXED (!) length cutoff for infinite series + which can be used with residual block f(x) = x + g(x) + """ + + var_dim = g.shape[1] + + # sample v ~ N(0, 1) + v = rand_normal([g.size(0), n_samples, g.size(1)]) + v = v.type_as(x).to(x.device) + + # v^T Jg -- vector-Jacobian product + def w_t_J_fn(w): + new_w = torch.autograd.grad(g, x, grad_outputs=w, retain_graph=True, create_graph=True)[0] + new_w = new_w[:, :var_dim].reshape(new_w.shape[0], -1) # x = [y, cond], derivatives only w.r.t. 
y + return new_w + + sum_diag = 0.0 + w = v.clone() + for k in range(1, n_power_series + 1): + w = [w_t_J_fn(w[:, i, :]) for i in range(n_samples)] + w = torch.stack(w, dim=1) + + # v^T Jg^k v term + inner = torch.einsum('bnd,bnd->bn', w, v) + sum_diag += (-1) ** (k + 1) * (inner / k) + + # mathematical expectation + return sum_diag.sum(dim=1) / n_samples + + +def logdet_Jg_unbias(g, x, n_samples=1, p=0.5, n_exact=1, is_training=True): + """ + Unbiased logdet estimator with UNFIXED (!) number of trace's series terms, see paper, eq. (6), also see eq. (8) + Number of terms is sampled by geometric distribution + Skilling-Hutchinson trace estimator is used to estimate the trace of Jacobian matrices + + As the power series in (8) does not need to be differentiated through, using this reduces the memory + requirement by a factor of n. This is especially useful when using the unbiased estimator as the + memory will be constant regardless of the number of terms we draw from p(N) + + Args: + g: outputs g(x) + x: inputs to g function (optimized network) + n_samples: number of v samples + p: geometric distribution parameter + n_exact: number of terms to be exactly computed + is_training: True if training phase else False + Returns: log determinant approximation using unbiased series length sampling (UNFIXED LEN) + which can be used with residual block f(x) = x + g(x) + """ + + ''' + In conditional case inputs x = [y, cond] of shape (var_dim + cond_dim) + Outputs g(x) shape is always (var_dim) + ''' + + var_dim = g.shape[1] + + def geom_cdf(k): + # P[N >= k] = 1 - f_geom(k), Geom(p) probability + return (1.0 - p) ** max(0, k - n_exact) + + res = 0.0 + for j in range(n_samples): + n_power_series = n_exact + np.random.geometric(p) + v = torch.randn_like(g) # N(0, 1) by paper + w = v + + sum_vj = 0.0 + for k in range(1, n_power_series + 1): + # v^T Jg -- vector-Jacobian product + w = torch.autograd.grad(g, x, w, create_graph=is_training, retain_graph=True)[0] + w = w[:, :var_dim].reshape(w.shape[0], -1) # x = [y, cond], derivatives only w.r.t. y + P_N_ge_k = geom_cdf(k - 1) # P[N >= k] + tr = torch.sum(w * v, dim=1) # v^T Jg v + sum_vj = sum_vj + (-1) ** (k + 1) * (tr / (k * P_N_ge_k)) + res += sum_vj + return res / n_samples + + +def logdet_Jg_neumann(g, x, n_samples=1, p=0.5, n_exact=1): + """ + Unbiased Neumann logdet estimator see paper with russian roulette applied, see paper, eq. (8) and app. 
C + Provides Neumann gradient series with russian roulette and trace estimator applied to obtain the theorem (8) + Args: + g: outputs g(x) + x: inputs to g function (optimized network) + n_samples: number of v samples + p: geometric distribution parameter + n_exact: number of terms to be exactly computed + Returns: log determinant approximation using unbiased series length sampling + + NOTE: this method using neumann series does not return exact "log_df_dz" + but the one that can be only used in gradient wrt parameters + see: https://github.com/rtqichen/residual-flows/blob/f9dd4cd0592d1aa897f418e25cae169e77e4d692/lib/layers/iresblock.py#L249 + and: https://github.com/tatsy/normalizing-flows-pytorch/blob/f5238fa8ce62a130679a1cf4474e195926b4842f/flows/iresblock.py#L84 + """ + + ''' + In conditional case inputs x = [y, cond] of shape (var_dim + cond_dim) + Outputs g(x) shape is always (var_dim) + ''' + + var_dim = g.shape[1] + + def geom_cdf(k): + # P[N >= k] = 1 - f_geom(k), Geom(p) probability + return (1.0 - p) ** max(0, k - n_exact) + + res = 0.0 + for j in range(n_samples): + n_power_series = n_exact + np.random.geometric(p) + + v = torch.randn_like(g) + w = v + + sum_vj = v + with torch.no_grad(): + # v^T Jg sum + for k in range(1, n_power_series + 1): + # v^T Jg -- vector-Jacobian product + w = torch.autograd.grad(g, x, w, retain_graph=True)[0] + w = w[:, :var_dim].view(w.shape[0], -1) # x = [y, cond], derivatives only w.r.t. y + P_N_ge_k = geom_cdf(k - 1) # P[N >= k] + sum_vj = sum_vj + ((-1) ** k / P_N_ge_k) * w + + # Jg v + sum_vj = torch.autograd.grad(g, x, sum_vj, create_graph=True)[0] + sum_vj = sum_vj[:, :var_dim].view(sum_vj.shape[0], -1) # аналогично + res += torch.sum(sum_vj * v, dim=1) + return res / n_samples + + +class MemorySavedLogDetEstimator(torch.autograd.Function): + """ + Memory saving logdet estimator, see paper, 3.2 and app. C + Provides custom memory-saving backprop + """ + @staticmethod + def forward(ctx, logdet_fn, x, net_g_fn, training, *g_params): + """ + Args: + ctx: context object (see https://pytorch.org/docs/stable/autograd.html#function) + logdet_fn: logdet estimator function for loss calculation + x: inputs to g(x) + net_g_fn: optimized function (network) + training: True if training phase, else False + *g_params: parameters of g + + Returns: + g(x): outputs g for inputs x + logdet: estimated logdet + """ + + ctx.training = training + with torch.enable_grad(): + x = x.detach().requires_grad_(True) + g = net_g_fn(x) + ctx.x = x # shape (var_dim + cond_dim) if cond else (var_dim) + ctx.g = g # shape (var_dim) in any case + + # Backward-in-forward: early computation of gradient + # Pass params x and theta, return grads w.r.t. 
x and theta + # https://pytorch.org/docs/stable/generated/torch.autograd.grad.html + theta = list(g_params) + if ctx.training: + # logdet for neumann series + logdetJg = logdet_Jg_neumann(g, x).sum() + dlogdetJg_dx, *dlogdetJg_dtheta = torch.autograd.grad(logdetJg, [x] + theta, + retain_graph=True, + allow_unused=True) + ctx.save_for_backward(dlogdetJg_dx, *theta, *dlogdetJg_dtheta) + + # logdet for loss calculation + logdet = logdet_fn(g, x) + return safe_detach(g), safe_detach(logdet) + + @staticmethod + def backward(ctx, dL_dg, dL_dlogdetJg): + """ + NOTE: Be careful that chain rule for partial differentiation is as follows + df(y, z) df dy df dz + -------- = -- * -- + -- * -- + dx dy dx dz dx + """ + + training = ctx.training + if not training: + raise ValueError('Provide training=True if using backward.') + + # chain rule for partial differentiation (1st term) + with torch.enable_grad(): + g, x = ctx.g, ctx.x + dlogdetJg_dx, *saved_tensors = ctx.saved_tensors + n_params = len(saved_tensors) // 2 + theta = saved_tensors[:n_params] + dlogdetJg_dtheta = saved_tensors[n_params:] # 2nd multiplier of (9) + + dL_dx_1st, *dL_dtheta_1st = torch.autograd.grad(g, [x] + theta, + grad_outputs=dL_dg, + allow_unused=True) + + # chain rule for partial differentiation (2nd term) + # NOTE: dL_dlogdetJg consists of same values for all dimensions (see forward). + dL_dlogdetJg_scalar = dL_dlogdetJg[0].detach() # 1st multiplier of (9) + with torch.no_grad(): + dL_dx_2nd = dlogdetJg_dx * dL_dlogdetJg_scalar # see paper eq. (9) + dL_dtheta_2nd = tuple( + [g * dL_dlogdetJg_scalar if g is not None else None for g in dlogdetJg_dtheta]) + + with torch.no_grad(): + dL_dx = dL_dx_1st + dL_dx_2nd + dL_dtheta = tuple([ + g1 + g2 if g2 is not None else g1 for g1, g2 in zip(dL_dtheta_1st, dL_dtheta_2nd) + ]) + + return (None, dL_dx, None, None) + dL_dtheta + + +def memory_saved_logdet_wrapper(logdet_fn, x, net_g_fn, training): + # x = [y] or [y, cond] + g_params = list(net_g_fn.parameters()) + return MemorySavedLogDetEstimator.apply(logdet_fn, x, net_g_fn, training, *g_params) + + +def safe_detach(x): + """ + detach operation which keeps reguires_grad + """ + return x.detach().requires_grad_(x.requires_grad) diff --git a/probaforms/models/residual/model.py b/probaforms/models/residual/model.py new file mode 100644 index 0000000..cef7acd --- /dev/null +++ b/probaforms/models/residual/model.py @@ -0,0 +1,273 @@ +import os +from tqdm import tqdm +from typing import Union + +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset +from torch.utils.data import DataLoader +from torch.distributions.multivariate_normal import MultivariateNormal + +from .modules import ActNorm, InvertibleResLinear + + +class ResidualFlowModel(nn.Module): + ''' + Residual Flow model class + Pass concat [X, y] if conditioning, return only y + ''' + def __init__(self, var_dim, cond_dim=None, n_layers=6, hid_dim=32, n_block_layers=2, + spnorm_coeff=0.97, logdet='unbias', n_backward_iters=100): + """ + Args: + var_dim: target data size + cond_dim: conditional data size (None if not used) + n_layers: number of residual blocks in model + hid_dim: residual block hidden size + n_block_layers: number of layers in each residual block + spnorm_coeff: spectral normalization coeff (Lipschitz), must be < 1 + logdet: logdet estimation strategy + n_backward_iters: number of iterations to sample the object + """ + super().__init__() + self.var_dim = var_dim + self.cond_dim = cond_dim + self.in_dim = var_dim + cond_dim if cond_dim is not 
None else var_dim + self.out_dim = self.var_dim + + self.n_layers = n_layers + self.device = 'cpu' + self.net = None + + assert spnorm_coeff < 1 + self.actnorm_in_dim = self.in_dim if self.cond_dim is None else self.var_dim + self.hid_dim = hid_dim + self.net = nn.ModuleList() + for i in range(self.n_layers): + self.net.append(ActNorm(self.actnorm_in_dim)) + self.net.append( + InvertibleResLinear(self.in_dim, self.out_dim, base_filters=self.hid_dim, + coeff=spnorm_coeff, n_layers=n_block_layers, + logdet_estimator=logdet, n_backward_iters=n_backward_iters) + ) + + def forward_process(self, z, cond=None): + log_df_dz = torch.zeros(z.size(0)).type_as(z).to(z.device) + for i, layer in enumerate(self.net): + if cond is not None and i % 2 == 1: # if layer is InvertibleResLinear + z = torch.cat([z, cond], dim=1) + z, log_df_dz = layer(z, log_df_dz) + return z, log_df_dz + + def backward_process(self, z, cond=None): + log_df_dz = torch.zeros(z.size(0)).type_as(z).to(z.device) + for i, layer in enumerate(self.net[::-1]): + if cond is not None and i % 2 == 0: # if layer is InvertibleResLinear + z = torch.cat([z, cond], dim=1) + z, log_df_dz = layer.backward(z, log_df_dz) + return z, log_df_dz + + def to(self, device): + super().to(device) + self.device = device + self.net = self.net.to(device) + return self + + +# =========================== Wrappers ========================== + + +class BaseFlowWrapper(object): + def __init__(self, var_dim, cond_dim=None, n_layers=6, hid_dim=32, n_block_layers=2, + spnorm_coeff=0.97, logdet='unbias', n_backward_iters=100, + optimizer=None, batch_size=64, n_epochs=100, checkpoint_dir=None, device='cpu', + scheduler=None, **scheduler_kwargs): + """ + Args: + var_dim: target data size + cond_dim: conditional data size (None if not used) + n_layers: number of residual blocks in model + hid_dim: residual block hidden size + n_block_layers: number of layers in each residual block + spnorm_coeff: spectral normalization coeff (Lipschitz), must be < 1 + logdet: logdet estimation strategy + n_backward_iters: number of iterations to sample the object + """ + self.flow = ResidualFlowModel(var_dim, cond_dim, n_layers, + hid_dim, n_block_layers, spnorm_coeff, logdet, n_backward_iters).to(device) + if optimizer is not None: + self.optim = optimizer + else: + self.optim = torch.optim.Adam(self.flow.parameters(), lr=1e-2, weight_decay=1e-4) + self.batch_size = batch_size + self.n_epochs = n_epochs + self.scheduler = scheduler + + self.checkpoint_dir = checkpoint_dir + if self.checkpoint_dir is not None: + try: + os.makedirs(self.checkpoint_dir) + print(f'Created directory {self.checkpoint_dir}') + except: + print(f'Directory {self.checkpoint_dir} already exists or can not be created') + pass + + self.device = self.flow.device + + self.min_epoch_loss = torch.inf + self.last_epoch_loss = None + + self.var_dim = self.flow.var_dim + self.cond_dim = self.flow.cond_dim + self.in_dim = self.flow.in_dim + self.out_dim = self.flow.out_dim + + self.mu = torch.zeros(self.out_dim, dtype=torch.float32, device=self.device) + self.var = torch.eye(self.out_dim, dtype=torch.float32, device=self.device) + self.normal = MultivariateNormal(self.mu, self.var) + + def fit(self, Y: torch.Tensor, X_cond: torch.Tensor = None): + """ + Fits flow + Args: + Y: input objects tensor of shape (B, var_dim) + X_cond: condition tensor of shape (B, cond_dim) + Returns: epochs losses list + """ + raise NotImplemented + + def sample(self, input: Union[torch.tensor, int], batch_size=None): + """ + Samples objects from 
condition of the X_cond's shape + Args: + input: int N in unconditional case, torch.Tensor X_cond else + batch_size: None if no batchification used, else int -- batch_size + Returns: new objects + """ + raise NotImplemented + + def loss(self, z, logdet): + """ + Computes loss (likelihood), see slide 20: + https://github.com/HSE-LAMBDA/DeepGenerativeModels/blob/spring-2021/lectures/8-NF.pdf + Args: + z: predicted data + logdet: computed logdet Jacobian + Returns: mean negative log-likehood loss log(p(z)) = log(p(g(z)) + logdet Jg(z) + """ + return -(self.normal.log_prob(z) + logdet).mean() + + def checkpoint(self): + """Save model at the best epochs (with minimal loss)""" + if self.checkpoint_dir is None: + return + + if self.last_epoch_loss > self.min_epoch_loss: + return + + self.min_epoch_loss = self.last_epoch_loss + torch.save(self.flow.state_dict(), os.path.join(self.checkpoint_dir, f'flow.pt')) + torch.save(self.optim.state_dict(), os.path.join(self.checkpoint_dir, f'optim.pt')) + if self.scheduler is not None: + torch.save(self.scheduler.state_dict(), os.path.join(self.checkpoint_dir, f'sched.pt')) + + def load_from_checkpoint(self, strict=False): + """Load model from checkpoint""" + if self.checkpoint_dir is None: + return + + self.flow.load_state_dict(torch.load(os.path.join(self.checkpoint_dir, f'flow.pt')), strict=strict) + self.optim.load_state_dict(torch.load(os.path.join(self.checkpoint_dir, f'optim.pt'))) + if self.scheduler is not None: + self.scheduler.load_state_dict(torch.load(os.path.join(self.checkpoint_dir, f'sched.pt'))) + + +class ResidualFlow(BaseFlowWrapper): + def __init__(self, var_dim, cond_dim=None, n_layers=6, hid_dim=32, n_block_layers=2, + spnorm_coeff=0.97, logdet='unbias', n_backward_iters=100, + optimizer=None, batch_size=64, n_epochs=10, checkpoint_dir=None, device='cpu', + scheduler=None, **scheduler_kwargs): + super().__init__(var_dim, cond_dim, n_layers, hid_dim, n_block_layers, spnorm_coeff, logdet, n_backward_iters, + optimizer, batch_size, n_epochs, checkpoint_dir, device, scheduler, **scheduler_kwargs) + + def fit(self, Y: torch.Tensor, X_cond: torch.Tensor = None): + if X_cond is not None: + td = TensorDataset(Y, X_cond) + else: + td = TensorDataset(Y) + + batches = DataLoader(td, batch_size=self.batch_size, shuffle=True) + + losses = [] + self.flow.train() + for _ in tqdm(range(self.n_epochs), desc=f"Epoch"): + epoch_loss = 0.0 + for data in batches: + self.optim.zero_grad() + + y = data[0].to(self.device) + if X_cond is not None: + x_cond = data[1].to(self.device) + else: + x_cond = None + + z, logdet = self.flow.forward_process(y, x_cond) + + loss = self.loss(z, logdet) + loss.backward() + self.optim.step() + + epoch_loss += loss.item() * y.shape[0] / Y.shape[0] + + if self.scheduler is not None: + self.scheduler.step() + + losses.append(epoch_loss) + self.last_epoch_loss = epoch_loss + self.checkpoint() + + self.load_from_checkpoint() + return losses + + def sample(self, input: Union[torch.tensor, int], batch_size=None): + N = None; X_cond = None + if isinstance(input, int): + N = input + assert self.flow.cond_dim is None + elif isinstance(input, torch.Tensor): + X_cond = input + assert self.flow.cond_dim == X_cond.shape[1] + else: + raise ValueError('Undefined input type') + + if X_cond is not None: + td = TensorDataset(X_cond) + bs = X_cond.shape[1] if batch_size is None else batch_size + batches = DataLoader(td, batch_size=bs, shuffle=False) + else: + if batch_size is not None: + bs = batch_size + n_batches = N // bs + remains = N - bs * 
n_batches + batches = [bs for _ in range(n_batches)] + if remains > 0: + batches += [remains] + else: + batches = [N] + + Ys = [] + self.flow.eval() + with torch.no_grad(): + for data in tqdm(batches, 'batch'): + if isinstance(data, int): + x_cond = None + size = data + else: + x_cond = data[0].to(self.device) + size = x_cond.size(0) + + z = self.normal.sample((size,)).to(self.device) + y, _ = self.flow.backward_process(z, x_cond) + Ys.append(y.cpu()) + Y = torch.cat(Ys, dim=0) + return Y diff --git a/probaforms/models/residual/modules.py b/probaforms/models/residual/modules.py new file mode 100644 index 0000000..d520003 --- /dev/null +++ b/probaforms/models/residual/modules.py @@ -0,0 +1,230 @@ +import torch +from torch import nn +import numpy as np + +from .gradients import memory_saved_logdet_wrapper, safe_detach +from .gradients import logdet_Jg_exact, logdet_Jg_cutoff, logdet_Jg_unbias + + +def l2normalize(v, eps=1e-12): + return v / (v.norm() + eps) + + +class SpectralNorm(nn.Module): + """ + Modified spectral normalization [Miyato et al. 2018] for invertible residual networks + Most of this implementation is borrowed from the following link: + https://github.com/christiancosgrove/pytorch-spectral-normalization-gan + See paper, app. D, eq. (16), and slide 25 of: + https://github.com/HSE-LAMBDA/DeepGenerativeModels/blob/spring-2021/lectures/9-NF2.pdf + """ + + def __init__(self, module, coeff=0.97, eps=1.0e-5, name='weight', power_iterations=1): + super(SpectralNorm, self).__init__() + self.module = module + self.coeff = coeff + self.eps = eps + self.name = name + self.power_iterations = power_iterations + + if not self._made_params(): + self._make_params() + + def _update_u_v(self): + u = getattr(self.module, self.name + '_u') + v = getattr(self.module, self.name + '_v') + w = getattr(self.module, self.name + '_bar') + + height = w.data.shape[0] + for _ in range(self.power_iterations): + v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) + u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) + + sigma = u.dot(w.view(height, -1).mv(v)) + scale = self.coeff / (sigma + self.eps) + + delattr(self.module, self.name) + if scale < 1.0: + setattr(self.module, self.name, w * scale.expand_as(w)) + else: + setattr(self.module, self.name, w) + + def _made_params(self): + try: + _ = getattr(self.module, self.name + '_u') + _ = getattr(self.module, self.name + '_v') + _ = getattr(self.module, self.name + '_bar') + return True + except AttributeError: + return False + + def _make_params(self): + w = getattr(self.module, self.name) + + height = w.data.shape[0] + width = w.view(height, -1).data.shape[1] + + u = w.data.new(height).normal_(0, 1) + v = w.data.new(width).normal_(0, 1) + u.data = l2normalize(u.data) + v.data = l2normalize(v.data) + w_bar = nn.Parameter(w.data) + + self.module.register_buffer(self.name + '_u', u) + self.module.register_buffer(self.name + '_v', v) + self.module.register_parameter(self.name + '_bar', w_bar) + + def forward(self, *args): + self._update_u_v() + return self.module.forward(*args) + + +class ActNorm(nn.Module): + def __init__(self, var_dim, eps=1e-5): + super(ActNorm, self).__init__() + self.var_dim = var_dim + self.eps = eps + + self.register_parameter('log_scale', nn.Parameter(torch.zeros(self.var_dim))) + self.register_parameter('bias', nn.Parameter(torch.zeros(self.var_dim))) + self.initialized = False + + def forward(self, z, log_df_dz): + if not self.initialized: + z_reshape = z.view(z.size(0), self.var_dim, -1) + log_std = 
torch.log(torch.std(z_reshape, dim=[0, 2]) + self.eps) + mean = torch.mean(z_reshape, dim=[0, 2]) + self.log_scale.data.copy_(log_std.view(self.var_dim)) + self.bias.data.copy_(mean.view(self.var_dim)) + self.initialized = True + + z = (z - self.bias) / torch.exp(self.log_scale) + + num_pixels = np.prod(z.size()) // (z.size(0) * z.size(1)) + log_df_dz -= torch.sum(self.log_scale) * num_pixels + return z, log_df_dz + + def backward(self, y, log_df_dz): + y = y * torch.exp(self.log_scale) + self.bias + num_pixels = np.prod(y.size()) // (y.size(0) * y.size(1)) + log_df_dz += torch.sum(self.log_scale) * num_pixels + return y, log_df_dz + + +class LipSwish(nn.Module): + def __init__(self): + super(LipSwish, self).__init__() + beta = nn.Parameter(torch.ones([1], dtype=torch.float32)) + self.register_parameter('beta', beta) + + def forward(self, x, cond=None): + return x * torch.sigmoid(self.beta * x) / 1.1 + + +class InvertibleResBlockBase(nn.Module): + """ invertible residual block""" + def __init__(self, coeff=0.97, ftol=1e-4, logdet_estimator='unbias', n_backward_iters=100): + super(InvertibleResBlockBase, self).__init__() + + self.coeff = coeff + self.ftol = ftol + self.estimator = logdet_estimator + self.proc_g_fn = memory_saved_logdet_wrapper + self.logdet_fn = self._get_logdet_estimator() + self.n_iters = n_backward_iters + + self.g_fn = ... + self.var_dim = ... + + def _get_logdet_estimator(self): + if self.training: + # force use unbiased log-det estimator + logdet_fn = lambda g, z: logdet_Jg_unbias(g, z, 1, is_training=self.training) + else: + if self.estimator == 'exact': + logdet_fn = logdet_Jg_exact + elif self.estimator == 'fixed': + logdet_fn = lambda g, z: logdet_Jg_cutoff(g, z, n_samples=5, n_power_series=10) + elif self.estimator == 'unbias': + logdet_fn = lambda g, z: logdet_Jg_unbias( + g, z, n_samples=5, n_exact=10, is_training=self.training) + else: + raise Exception('Unknown logdet estimator: %s' % self.estimator) + + return logdet_fn + + def forward(self, x, log_df_dz): + # x = [y] or [y, cond] + g, logdet = self.proc_g_fn(self.logdet_fn, x, self.g_fn, self.training) + # residual z = F(x) = y + g(y, cond), g is a network + z = x[:, :self.var_dim] + g + log_df_dz += logdet + return z, log_df_dz + + def backward(self, z, log_df_dz): + x = safe_detach(z.clone()) + cond = x[:, self.var_dim:].clone() + + with torch.enable_grad(): + x.requires_grad_(True) + cond.requires_grad_(True) + for k in range(self.n_iters): + x = safe_detach(x) + prev_x_var = safe_detach(x[:, :self.var_dim]) + # fixed point iteration x = y - g(x) + new_x_var = z[:, :self.var_dim] - self.g_fn(x) + if torch.all(torch.abs(new_x_var - prev_x_var) < self.ftol): + break + + x = safe_detach(torch.cat([new_x_var, cond], dim=1).requires_grad_(True)) + + del prev_x_var + logdet = self.logdet_fn(self.g_fn(x), x) + return new_x_var, log_df_dz - logdet + + def simple_backward(self, z, log_df_dz=0.0): + new_x_var = z[:, :self.var_dim] - self.g_fn(z) + return new_x_var, log_df_dz + + +class ResBackbone(nn.Module): + def __init__(self, in_features, + out_features, + base_filters=32, + n_layers=2, + coeff=0.97): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + hidden_dims = [in_features] + [base_filters] * n_layers + [out_features] + self.layers = nn.ModuleList() + for i, (in_dims, out_dims) in enumerate(zip(hidden_dims[:-1], hidden_dims[1:])): + module = nn.Linear(in_dims, out_dims) + self.layers.append(SpectralNorm(module, coeff=coeff)) + if i != len(hidden_dims) - 2: + 
self.layers.append(LipSwish()) + + def forward(self, x): + for i in range(len(self.layers)): + x = self.layers[i](x) + return x + + +class InvertibleResLinear(InvertibleResBlockBase): + def __init__(self, + in_features, + out_features, + base_filters=32, + n_layers=2, + coeff=0.97, + ftol=1.0e-4, + logdet_estimator='unbias', + n_backward_iters=100): + ''' + Pass concat [X, y] if conditioning, return only y + See class BaseFlow in model.py + ''' + super(InvertibleResLinear, self).__init__(coeff, ftol, logdet_estimator, n_backward_iters) + self.g_fn = ResBackbone(in_features, out_features, base_filters, n_layers, coeff) + self.var_dim = self.g_fn.out_features diff --git a/tests/test_models.py b/tests/test_models.py index 3e8a316..a6b0316 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,12 +1,26 @@ import numpy as np import pytest +import torch +import itertools from probaforms.models.interfaces import GenModel +from probaforms.models import ResidualFlow + def subclasses(cls): return set(cls.__subclasses__()).union(s for c in cls.__subclasses__() for s in subclasses(c)) +def gen_data(bias=0, N=500): + X = np.linspace(bias, bias + 5, N).reshape(-1, 1) + mu = np.exp(-X + bias) + eps = np.random.normal(0, 1, X.shape) + sigma = 0.05 * (X - bias + 0.5) + X = torch.from_numpy(X).to(torch.float32) + y = torch.from_numpy(mu + eps * sigma).to(torch.float32) + return X, y + + @pytest.mark.parametrize("model", subclasses(GenModel)) def test_with_conditions(model): n = 100 @@ -26,3 +40,36 @@ def test_without_conditions(model): gen.fit(X, C=None) X_gen = gen.sample(C=n) assert X_gen.shape == X.shape + + +logdets = ['exact', 'fixed', 'unbias'] +is_conds = [True, False] +devices = ['cpu', 'cuda'] +@pytest.mark.parametrize("logdet,device,is_cond", list(itertools.product(logdets, devices, is_conds))) +def test_resflow(logdet, device, is_cond): + n = 100 + X = torch.from_numpy(np.random.normal(size=(n, 5))).to(torch.float32) + y = torch.from_numpy(np.random.normal(size=(n, 3))).to(torch.float32) + len_y = y.shape[1]; len_X = X.shape[1] + + flow_args_dict = { + 'var_dim': len_y if is_cond else len_y + len_X, + 'cond_dim': len_X if is_cond else None, + 'hid_dim': 16, + 'n_block_layers': 3, + 'n_layers': 3, + 'spnorm_coeff': 0.95, + 'n_backward_iters': 100, + 'logdet': logdet, + 'device': device, + } + + wrapper = ResidualFlow(**flow_args_dict, n_epochs=50, batch_size=100) + + if is_cond: + _ = wrapper.fit(y, X) + _ = wrapper.sample(X).cpu() + else: + data = torch.cat([X, y], dim=1) + _ = wrapper.fit(data) + _ = wrapper.sample(500, batch_size=100).cpu()
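
Usage sketch: a minimal end-to-end example of the new ResidualFlow wrapper in the conditional setting, mirroring the API exercised in tests/test_models.py above. The toy data, shapes, and hyperparameter values below are illustrative assumptions only; the class and method names come from probaforms/models/residual/model.py.

    import torch
    from probaforms.models import ResidualFlow

    # Toy conditional data (hypothetical): a 1-D target y generated from a 1-D condition X.
    # Any float32 tensors of shapes (N, cond_dim) and (N, var_dim) would do.
    N = 512
    X = torch.linspace(0, 5, N).reshape(-1, 1).to(torch.float32)
    y = torch.exp(-X) + 0.05 * (X + 0.5) * torch.randn_like(X)

    # var_dim is the target dimension, cond_dim the condition dimension
    # (pass cond_dim=None and only Y to fit() for the unconditional case).
    wrapper = ResidualFlow(
        var_dim=1, cond_dim=1,
        n_layers=3, hid_dim=16, n_block_layers=2,
        spnorm_coeff=0.95, logdet='unbias',
        n_epochs=10, batch_size=64, device='cpu',
    )

    losses = wrapper.fit(y, X)                 # list of per-epoch losses (mean negative log-likelihood)
    y_gen = wrapper.sample(X, batch_size=128)  # one generated y row per condition row
    assert y_gen.shape == y.shape

Conditioning follows the scheme noted in the module docstrings: each invertible residual block receives the concatenation [y, cond] but returns only the y part, so the base Gaussian stays var_dim-dimensional and sampling inverts each block by fixed-point iteration (controlled by n_backward_iters).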