|
| 1 | +import torch |
| 2 | +from torch import nn |
| 3 | +from torch.nn import Module, ModuleList |
| 4 | + |
| 5 | +from einops import rearrange, repeat, reduce, pack, unpack |
| 6 | +from einops.layers.torch import Rearrange |
| 7 | + |
| 8 | +# helpers |
| 9 | + |
| 10 | +def pair(t): |
| 11 | + return t if isinstance(t, tuple) else (t, t) |
| 12 | + |
| 13 | +def divisible_by(num, den): |
| 14 | + return (num % den) == 0 |
| 15 | + |
| 16 | +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): |
| 17 | + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") |
| 18 | + assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb" |
| 19 | + |
| 20 | + omega = torch.arange(dim // 4) / (dim // 4 - 1) |
| 21 | + omega = temperature ** -omega |
| 22 | + |
| 23 | + y = y.flatten()[:, None] * omega[None, :] |
| 24 | + x = x.flatten()[:, None] * omega[None, :] |
| 25 | + pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) |
| 26 | + |
| 27 | + return pos_emb.type(dtype) |
| 28 | + |
| 29 | +# classes |
| 30 | + |
| 31 | +def FeedForward(dim, mult = 4.): |
| 32 | + hidden_dim = int(dim * mult) |
| 33 | + return nn.Sequential( |
| 34 | + nn.LayerNorm(dim), |
| 35 | + nn.Linear(dim, hidden_dim), |
| 36 | + nn.GELU(), |
| 37 | + nn.Linear(hidden_dim, dim), |
| 38 | + ) |
| 39 | + |
| 40 | +class Attention(Module): |
| 41 | + def __init__(self, dim, heads = 8, dim_head = 64): |
| 42 | + super().__init__() |
| 43 | + inner_dim = dim_head * heads |
| 44 | + self.heads = heads |
| 45 | + self.scale = dim_head ** -0.5 |
| 46 | + self.norm = nn.LayerNorm(dim) |
| 47 | + |
| 48 | + self.attend = nn.Softmax(dim = -1) |
| 49 | + |
| 50 | + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) |
| 51 | + self.to_out = nn.Linear(inner_dim, dim, bias = False) |
| 52 | + |
| 53 | + def forward(self, x): |
| 54 | + x = self.norm(x) |
| 55 | + |
| 56 | + qkv = self.to_qkv(x).chunk(3, dim = -1) |
| 57 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) |
| 58 | + |
| 59 | + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
| 60 | + |
| 61 | + attn = self.attend(dots) |
| 62 | + |
| 63 | + out = torch.matmul(attn, v) |
| 64 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 65 | + return self.to_out(out) |
| 66 | + |
| 67 | +class JumboViT(Module): |
| 68 | + def __init__( |
| 69 | + self, |
| 70 | + *, |
| 71 | + image_size, |
| 72 | + patch_size, |
| 73 | + num_classes, |
| 74 | + dim, |
| 75 | + depth, |
| 76 | + heads, |
| 77 | + mlp_dim, |
| 78 | + num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example |
| 79 | + jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on |
| 80 | + jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward |
| 81 | + channels = 3, |
| 82 | + dim_head = 64 |
| 83 | + ): |
| 84 | + super().__init__() |
| 85 | + image_height, image_width = pair(image_size) |
| 86 | + patch_height, patch_width = pair(patch_size) |
| 87 | + |
| 88 | + assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.' |
| 89 | + |
| 90 | + patch_dim = channels * patch_height * patch_width |
| 91 | + |
| 92 | + self.to_patch_embedding = nn.Sequential( |
| 93 | + Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), |
| 94 | + nn.LayerNorm(patch_dim), |
| 95 | + nn.Linear(patch_dim, dim), |
| 96 | + nn.LayerNorm(dim), |
| 97 | + ) |
| 98 | + |
| 99 | + self.pos_embedding = posemb_sincos_2d( |
| 100 | + h = image_height // patch_height, |
| 101 | + w = image_width // patch_width, |
| 102 | + dim = dim, |
| 103 | + ) |
| 104 | + |
| 105 | + jumbo_cls_dim = dim * jumbo_cls_k |
| 106 | + |
| 107 | + self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim)) |
| 108 | + |
| 109 | + jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k) |
| 110 | + self.jumbo_cls_to_tokens = jumbo_cls_to_tokens |
| 111 | + |
| 112 | + self.norm = nn.LayerNorm(dim) |
| 113 | + self.layers = ModuleList([]) |
| 114 | + |
| 115 | + # attention and feedforwards |
| 116 | + |
| 117 | + self.jumbo_ff = nn.Sequential( |
| 118 | + Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k), |
| 119 | + FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient |
| 120 | + jumbo_cls_to_tokens |
| 121 | + ) |
| 122 | + |
| 123 | + for _ in range(depth): |
| 124 | + self.layers.append(ModuleList([ |
| 125 | + Attention(dim, heads = heads, dim_head = dim_head), |
| 126 | + FeedForward(dim, mlp_dim), |
| 127 | + ])) |
| 128 | + |
| 129 | + self.to_latent = nn.Identity() |
| 130 | + |
| 131 | + self.linear_head = nn.Linear(dim, num_classes) |
| 132 | + |
| 133 | + def forward(self, img): |
| 134 | + |
| 135 | + batch, device = img.shape[0], img.device |
| 136 | + |
| 137 | + x = self.to_patch_embedding(img) |
| 138 | + |
| 139 | + # pos embedding |
| 140 | + |
| 141 | + pos_emb = self.pos_embedding.to(device, dtype = x.dtype) |
| 142 | + |
| 143 | + x = x + pos_emb |
| 144 | + |
| 145 | + # add cls tokens |
| 146 | + |
| 147 | + cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch) |
| 148 | + |
| 149 | + jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens) |
| 150 | + |
| 151 | + x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d') |
| 152 | + |
| 153 | + # attention and feedforwards |
| 154 | + |
| 155 | + for layer, (attn, ff) in enumerate(self.layers, start = 1): |
| 156 | + is_last = layer == len(self.layers) |
| 157 | + |
| 158 | + x = attn(x) + x |
| 159 | + |
| 160 | + # jumbo feedforward |
| 161 | + |
| 162 | + jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d') |
| 163 | + |
| 164 | + x = ff(x) + x |
| 165 | + jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens |
| 166 | + |
| 167 | + if is_last: |
| 168 | + continue |
| 169 | + |
| 170 | + x, _ = pack([jumbo_cls_tokens, x], 'b * d') |
| 171 | + |
| 172 | + pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean') |
| 173 | + |
| 174 | + # normalization and project to logits |
| 175 | + |
| 176 | + embed = self.norm(pooled) |
| 177 | + |
| 178 | + embed = self.to_latent(embed) |
| 179 | + logits = self.linear_head(embed) |
| 180 | + return logits |
| 181 | + |
| 182 | +# copy pasteable file |
| 183 | + |
| 184 | +if __name__ == '__main__': |
| 185 | + |
| 186 | + v = JumboViT( |
| 187 | + num_classes = 1000, |
| 188 | + image_size = 64, |
| 189 | + patch_size = 8, |
| 190 | + dim = 16, |
| 191 | + depth = 2, |
| 192 | + heads = 2, |
| 193 | + mlp_dim = 32, |
| 194 | + jumbo_cls_k = 3, |
| 195 | + jumbo_ff_mult = 2, |
| 196 | + ) |
| 197 | + |
| 198 | + images = torch.randn(1, 3, 64, 64) |
| 199 | + |
| 200 | + logits = v(images) |
| 201 | + assert logits.shape == (1, 1000) |
0 commit comments