Skip to content

Commit 1de866d

Browse files
committed
add the proposed jumbo vit from Fuller et al. of Carleton University
1 parent 9f49a31 commit 1de866d

File tree

3 files changed

+211
-1
lines changed

3 files changed

+211
-1
lines changed

README.md

+9
Original file line numberDiff line numberDiff line change
@@ -2172,4 +2172,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
21722172
}
21732173
```
21742174

2175+
```bibtex
2176+
@inproceedings{Fuller2025SimplerFV,
2177+
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
2178+
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
2179+
year = {2025},
2180+
url = {https://api.semanticscholar.org/CorpusID:276557720}
2181+
}
2182+
```
2183+
21752184
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.9.2',
9+
version = '1.10.1',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

vit_pytorch/jumbo_vit.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import torch
2+
from torch import nn
3+
from torch.nn import Module, ModuleList
4+
5+
from einops import rearrange, repeat, reduce, pack, unpack
6+
from einops.layers.torch import Rearrange
7+
8+
# helpers
9+
10+
def pair(t):
11+
return t if isinstance(t, tuple) else (t, t)
12+
13+
def divisible_by(num, den):
14+
return (num % den) == 0
15+
16+
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
17+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
18+
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
19+
20+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
21+
omega = temperature ** -omega
22+
23+
y = y.flatten()[:, None] * omega[None, :]
24+
x = x.flatten()[:, None] * omega[None, :]
25+
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
26+
27+
return pos_emb.type(dtype)
28+
29+
# classes
30+
31+
def FeedForward(dim, mult = 4.):
32+
hidden_dim = int(dim * mult)
33+
return nn.Sequential(
34+
nn.LayerNorm(dim),
35+
nn.Linear(dim, hidden_dim),
36+
nn.GELU(),
37+
nn.Linear(hidden_dim, dim),
38+
)
39+
40+
class Attention(Module):
41+
def __init__(self, dim, heads = 8, dim_head = 64):
42+
super().__init__()
43+
inner_dim = dim_head * heads
44+
self.heads = heads
45+
self.scale = dim_head ** -0.5
46+
self.norm = nn.LayerNorm(dim)
47+
48+
self.attend = nn.Softmax(dim = -1)
49+
50+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
51+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
52+
53+
def forward(self, x):
54+
x = self.norm(x)
55+
56+
qkv = self.to_qkv(x).chunk(3, dim = -1)
57+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
58+
59+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
60+
61+
attn = self.attend(dots)
62+
63+
out = torch.matmul(attn, v)
64+
out = rearrange(out, 'b h n d -> b n (h d)')
65+
return self.to_out(out)
66+
67+
class JumboViT(Module):
68+
def __init__(
69+
self,
70+
*,
71+
image_size,
72+
patch_size,
73+
num_classes,
74+
dim,
75+
depth,
76+
heads,
77+
mlp_dim,
78+
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
79+
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
80+
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
81+
channels = 3,
82+
dim_head = 64
83+
):
84+
super().__init__()
85+
image_height, image_width = pair(image_size)
86+
patch_height, patch_width = pair(patch_size)
87+
88+
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
89+
90+
patch_dim = channels * patch_height * patch_width
91+
92+
self.to_patch_embedding = nn.Sequential(
93+
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
94+
nn.LayerNorm(patch_dim),
95+
nn.Linear(patch_dim, dim),
96+
nn.LayerNorm(dim),
97+
)
98+
99+
self.pos_embedding = posemb_sincos_2d(
100+
h = image_height // patch_height,
101+
w = image_width // patch_width,
102+
dim = dim,
103+
)
104+
105+
jumbo_cls_dim = dim * jumbo_cls_k
106+
107+
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
108+
109+
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
110+
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
111+
112+
self.norm = nn.LayerNorm(dim)
113+
self.layers = ModuleList([])
114+
115+
# attention and feedforwards
116+
117+
self.jumbo_ff = nn.Sequential(
118+
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
119+
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
120+
jumbo_cls_to_tokens
121+
)
122+
123+
for _ in range(depth):
124+
self.layers.append(ModuleList([
125+
Attention(dim, heads = heads, dim_head = dim_head),
126+
FeedForward(dim, mlp_dim),
127+
]))
128+
129+
self.to_latent = nn.Identity()
130+
131+
self.linear_head = nn.Linear(dim, num_classes)
132+
133+
def forward(self, img):
134+
135+
batch, device = img.shape[0], img.device
136+
137+
x = self.to_patch_embedding(img)
138+
139+
# pos embedding
140+
141+
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
142+
143+
x = x + pos_emb
144+
145+
# add cls tokens
146+
147+
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
148+
149+
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
150+
151+
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
152+
153+
# attention and feedforwards
154+
155+
for layer, (attn, ff) in enumerate(self.layers, start = 1):
156+
is_last = layer == len(self.layers)
157+
158+
x = attn(x) + x
159+
160+
# jumbo feedforward
161+
162+
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
163+
164+
x = ff(x) + x
165+
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
166+
167+
if is_last:
168+
continue
169+
170+
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
171+
172+
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
173+
174+
# normalization and project to logits
175+
176+
embed = self.norm(pooled)
177+
178+
embed = self.to_latent(embed)
179+
logits = self.linear_head(embed)
180+
return logits
181+
182+
# copy pasteable file
183+
184+
if __name__ == '__main__':
185+
186+
v = JumboViT(
187+
num_classes = 1000,
188+
image_size = 64,
189+
patch_size = 8,
190+
dim = 16,
191+
depth = 2,
192+
heads = 2,
193+
mlp_dim = 32,
194+
jumbo_cls_k = 3,
195+
jumbo_ff_mult = 2,
196+
)
197+
198+
images = torch.randn(1, 3, 64, 64)
199+
200+
logits = v(images)
201+
assert logits.shape == (1, 1000)

0 commit comments

Comments
 (0)