
Commit db05a14

add the proposed jumbo vit from Fuller et al. of Carleton University
1 parent 9f49a31 commit db05a14

File tree

3 files changed: +214 -1 lines changed

README.md (+9)
@@ -2172,4 +2172,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@inproceedings{Fuller2025SimplerFV,
+    title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
+    author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:276557720}
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py (+1 -1)
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.9.2',
+  version = '1.10.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,
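
The version bump is what would ship the new module to users. Assuming the 1.10.1 release is published to PyPI as usual, a minimal usage sketch (the import path and constructor arguments come from this commit; the hyperparameter values below are illustrative only):

```python
# pip install -U vit-pytorch   (assumes the 1.10.1 release is on PyPI)

import torch
from vit_pytorch.jumbo_vit import JumboViT  # module path added in this commit

v = JumboViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 8,
    mlp_dim = 4,      # in this file, FeedForward treats this as a multiplier on dim
    jumbo_cls_k = 6,  # the widening factor for the jumbo CLS token used in the paper
)

logits = v(torch.randn(1, 3, 224, 224))  # (1, 1000)
```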

vit_pytorch/jumbo_vit.py (+204, new file)
@@ -0,0 +1,204 @@
# Simpler Fast Vision Transformers with a Jumbo CLS Token
# https://arxiv.org/abs/2502.15021

import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(num, den):
    return (num % den) == 0

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"

    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)

    return pos_emb.type(dtype)

# classes

def FeedForward(dim, mult = 4.):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class JumboViT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        num_jumbo_cls = 1,  # differing from the paper, allow for multiple jumbo CLS tokens - e.g. the jumbo token could be broken up into 2 tokens at 3x the dim instead of one at 6x
        jumbo_cls_k = 6,    # the jumbo CLS token is this factor times the model dimension - 6 is the value the paper settled on
        jumbo_ff_mult = 2,  # expansion factor of the jumbo CLS token feedforward
        channels = 3,
        dim_head = 64
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        jumbo_cls_dim = dim * jumbo_cls_k

        self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))

        jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
        self.jumbo_cls_to_tokens = jumbo_cls_to_tokens

        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        # attention and feedforwards

        self.jumbo_ff = nn.Sequential(
            Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
            FeedForward(jumbo_cls_dim, jumbo_ff_mult),  # the jumbo feedforward has its own parameters, weight tied across depth for parameter efficiency; FeedForward expects an expansion factor, so jumbo_ff_mult is passed directly
            jumbo_cls_to_tokens
        )

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim),  # note: FeedForward's second argument is a multiplier on dim, so mlp_dim acts as the expansion factor here
            ]))

        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):

        batch, device = img.shape[0], img.device

        x = self.to_patch_embedding(img)

        # pos embedding

        pos_emb = self.pos_embedding.to(device, dtype = x.dtype)

        x = x + pos_emb

        # add cls tokens

        cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)

        jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)  # view each jumbo CLS token as jumbo_cls_k ordinary-width tokens

        x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')

        # attention and feedforwards

        for layer, (attn, ff) in enumerate(self.layers, start = 1):
            is_last = layer == len(self.layers)

            x = attn(x) + x

            # jumbo feedforward

            jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')

            x = ff(x) + x
            jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens

            if is_last:
                continue

            x, _ = pack([jumbo_cls_tokens, x], 'b * d')

        pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')  # mean pool across the jumbo CLS token(s)

        # normalization and project to logits

        embed = self.norm(pooled)

        embed = self.to_latent(embed)
        logits = self.linear_head(embed)
        return logits

# copy pasteable file

if __name__ == '__main__':

    v = JumboViT(
        num_classes = 1000,
        image_size = 64,
        patch_size = 8,
        dim = 16,
        depth = 2,
        heads = 2,
        mlp_dim = 32,
        jumbo_cls_k = 3,
        jumbo_ff_mult = 2,
    )

    images = torch.randn(1, 3, 64, 64)

    logits = v(images)
    assert logits.shape == (1, 1000)
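
The essential move above is that the wide CLS vector is viewed as `jumbo_cls_k` ordinary-width tokens during attention, then re-fused into a single wide vector so a dedicated feedforward can mix across all of them. A shape-only sketch of that bookkeeping, with toy sizes picked purely for illustration (batch 2, 16 patches, dim 8, k = 3):

```python
import torch
from einops import rearrange

batch, num_patches, dim, k = 2, 16, 8, 3

jumbo_cls = torch.zeros(batch, 1, dim * k)      # one jumbo CLS token of width k * dim

# split: one wide token -> k ordinary-width tokens, prepended to the patch tokens
cls_tokens = rearrange(jumbo_cls, 'b n (k d) -> b (n k) d', k = k)
patches = torch.randn(batch, num_patches, dim)
x = torch.cat((cls_tokens, patches), dim = 1)   # attention runs over k + num_patches tokens
assert x.shape == (batch, k + num_patches, dim)

# fuse: the k CLS slots are re-concatenated so one wide ("jumbo") feedforward can mix across them
fused = rearrange(x[:, :k], 'b (n k) d -> b n (k d)', k = k)
assert fused.shape == (batch, 1, dim * k)
```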
