import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

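# per the nGPT paper this builds on (https://arxiv.org/abs/2410.01131), weight
# rows, token embeddings and the residual stream are all kept on the unit
# hypersphere, so l2norm is applied throughout - to activations directly, and
# to the weights of every linear layer through the parametrization below
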
# for use with parametrize

class L2Norm(Module):
    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return l2norm(t, dim = self.dim)

class NormLinear(Module):
    def __init__(
        self,
        dim,
        dim_out,
        norm_dim_in = True
    ):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)

        parametrize.register_parametrization(
            self.linear,
            'weight',
            L2Norm(dim = -1 if norm_dim_in else 0)
        )

    @property
    def weight(self):
        return self.linear.weight

    def forward(self, x):
        return self.linear(x)

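# note on the parametrization: accessing `.weight` always yields an
# l2-normalized view (normalized along the input dim by default, along the
# output dim when norm_dim_in = False), while gradients flow to the raw tensor
# stored at `linear.parametrizations.weight.original`
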
# attention and feedforward

class Attention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.to_q = NormLinear(dim, dim_inner)
        self.to_k = NormLinear(dim, dim_inner)
        self.to_v = NormLinear(dim, dim_inner)

        self.dropout = dropout

        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)

    def forward(
        self,
        x
    ):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        q, k, v = map(self.split_heads, (q, k, v))

        # query key rmsnorm

        q, k = map(l2norm, (q, k))
        q, k = (q * self.qk_scale), (k * self.qk_scale)

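        # after l2norm, q @ k^T is a cosine similarity in [-1, 1]; scaling both
        # q and k by the learned s_qk (initialized to dk ** 0.25, so an
        # effective sqrt(dk) on their product) restores the dynamic range the
        # softmax needs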
        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16

        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p = self.dropout if self.training else 0.,
            scale = 1.
        )

        out = self.merge_heads(out)
        return self.to_out(out)

class FeedForward(Module):
    def __init__(
        self,
        dim,
        *,
        dim_inner,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = int(dim_inner * 2 / 3)
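        # the 2/3 factor keeps the parameter count of this gated (swiglu-style)
        # feedforward roughly the same as a non-gated MLP with the requested dim_inner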

        self.dim = dim
        self.dropout = nn.Dropout(dropout)

        self.to_hidden = NormLinear(dim, dim_inner)
        self.to_gate = NormLinear(dim, dim_inner)

        self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
        self.gate_scale = nn.Parameter(torch.ones(dim_inner))

        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)

    def forward(self, x):
        hidden, gate = self.to_hidden(x), self.to_gate(x)

        hidden = hidden * self.hidden_scale
        gate = gate * self.gate_scale * (self.dim ** 0.5)
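        # both branches are cosine-similarity-like, since weights and inputs
        # are normalized; the gate above is additionally rescaled by sqrt(dim)
        # so that silu operates outside the narrow [-1, 1] range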

        hidden = F.silu(gate) * hidden

        hidden = self.dropout(hidden)
        return self.to_out(hidden)

# classes

class nViT(Module):
    """ https://arxiv.org/abs/2410.01131 """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout = 0.,
        channels = 3,
        dim_head = 64,
        residual_lerp_scale_init = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)
        num_patches = patch_height_dim * patch_width_dim

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.abs_pos_emb = nn.Embedding(num_patches, dim)

        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
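        # per-dimension interpolation weights for the residual updates (the
        # paper's "eigen learning rates"); the 1 / depth default keeps each
        # block's initial contribution to the residual stream small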

        # layers

        self.dim = dim
        self.layers = ModuleList([])
        self.residual_lerp_scales = nn.ParameterList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
            ]))

            self.residual_lerp_scales.append(nn.ParameterList([
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
            ]))

        self.logit_scale = nn.Parameter(torch.ones(num_classes))

        self.to_pred = NormLinear(dim, num_classes)

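    # meant to be called after every optimizer step - it copies the normalized
    # weights back into the raw parameters, so the effective weights stay
    # exactly on the unit hypersphere throughout training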
    @torch.no_grad()
    def norm_weights_(self):
        for module in self.modules():
            if not isinstance(module, NormLinear):
                continue

            normed = module.weight
            original = module.linear.parametrizations.weight.original

            original.copy_(normed)

    def forward(self, images):
        device = images.device

        tokens = self.to_patch_embedding(images)

        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))

        tokens = l2norm(tokens + pos_emb)

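        # residual updates follow h <- norm(h + alpha * (norm(block(h)) - h)),
        # i.e. a lerp towards the normalized branch output, re-projected onto
        # the hypersphere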
        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

            attn_out = l2norm(attn(tokens))
            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))

            ff_out = l2norm(ff(tokens))
            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))

        pooled = reduce(tokens, 'b n d -> b d', 'mean')

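        # to_pred has unit-norm weight rows, so raw logits are bounded dot
        # products; the learned per-class logit_scale together with sqrt(dim)
        # restores their dynamic range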
        logits = self.to_pred(pooled)
        logits = logits * self.logit_scale * (self.dim ** 0.5)

        return logits

# quick test

if __name__ == '__main__':

    v = nViT(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
    )

    img = torch.randn(4, 3, 256, 256)
    logits = v(img) # (4, 1000)
    assert logits.shape == (4, 1000)
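
    # minimal training-step sketch (assumed usage - optimizer choice, lr and
    # random labels are placeholders): after each optimizer step, re-project
    # all NormLinear weights onto the unit hypersphere via norm_weights_()

    optim = torch.optim.SGD(v.parameters(), lr = 1e-3)

    labels = torch.randint(0, 1000, (4,))
    loss = F.cross_entropy(v(img), labels)
    loss.backward()

    optim.step()
    optim.zero_grad()

    v.norm_weights_()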