Commit 014df1e

improvise a max vit with register tokens
1 parent 680d446 commit 014df1e

File tree

3 files changed, +342 -3 lines changed

setup.py (+2 -2)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.5.0',
+  version = '1.5.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@
     'image recognition'
   ],
   install_requires=[
-    'einops>=0.6.1',
+    'einops>=0.7.0',
     'torch>=1.10',
     'torchvision'
   ],

vit_pytorch/max_vit.py (+1 -1)

@@ -173,7 +173,7 @@ def forward(self, x):
 
         # split heads
 
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         # scale
 
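
The max_vit.py change above is cosmetic: it drops a stray space inside the einops pattern (einops tolerates the extra whitespace, so behaviour is unchanged). As a standalone sketch of what that head-splitting rearrange does, using toy shapes not taken from the library:

import torch
from einops import rearrange

t = torch.randn(2, 49, 8 * 32)                     # (batch, window tokens, heads * dim_head)

out = rearrange(t, 'b n (h d) -> b h n d', h = 8)  # peel the heads out of the channel axis

assert out.shape == (2, 8, 49, 32)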
vit_pytorch/max_vit_with_registers.py (+339, new file)

@@ -0,0 +1,339 @@
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pack_one(x, pattern):
    return pack([x], pattern)

def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(dim * mult)
    return Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim),
        nn.Dropout(dropout)
    )

# MBConv

class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

class Dropsample(Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # per-sample stochastic depth: draw one uniform value per batch element
        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)

def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative positional bias

        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        device, h = x.device, self.heads

        x = self.norm(x)

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale

        q = q * self.scale

        # sim

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias

        bias = self.rel_pos_bias(self.rel_pos_indices)
        bias = rearrange(bias, 'i j h -> h i j')

        # register tokens are prepended to the sequence, so pad the bias with zeros for them

        num_registers = sim.shape[-1] - bias.shape[-1]
        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)

        sim = sim + bias

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # combine heads out

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class MaxViT(Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3,
        num_register_tokens = 4
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        # convolutional stem

        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables

        num_stages = len(depth)

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # window size

        self.window_size = window_size

        self.register_tokens = nn.ParameterList([])

        # iterate through stages

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )

                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
                block_ff = FeedForward(dim = layer_dim, dropout = dropout)

                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
                grid_ff = FeedForward(dim = layer_dim, dropout = dropout)

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))

                self.layers.append(ModuleList([
                    conv,
                    ModuleList([block_attn, block_ff]),
                    ModuleList([grid_attn, grid_ff])
                ]))

                self.register_tokens.append(register_tokens)

        # mlp head out

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        b, w = x.shape[0], self.window_size

        x = self.conv_stem(x)

        for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
            x = conv(x)

            # block-like attention

            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare register tokens

            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            # pack the registers in front of each window of tokens

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = block_attn(x) + x
            x = block_ff(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention

            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare register tokens

            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = grid_attn(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = grid_ff(x) + x

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        return self.mlp_head(x)
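
For orientation, a minimal usage sketch for the new module, patterned after the existing MaxViT example in the repository README. The import path follows the file location above; the hyperparameter values (224x224 input, window size 7, four stages) are illustrative choices that satisfy the divisibility the windowing rearranges assume:

import torch
from vit_pytorch.max_vit_with_registers import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,            # dim of the conv stem
    dim = 96,                      # dim of the first stage, doubled at every following stage
    dim_head = 32,                 # dim per attention head
    depth = (2, 2, 5, 2),          # number of MaxViT blocks per stage
    window_size = 7,               # window size for block and grid attention
    num_register_tokens = 4,       # register tokens packed in alongside every window / grid of tokens
    mbconv_expansion_rate = 4,     # expansion rate of MBConv
    mbconv_shrinkage_rate = 0.25,  # shrinkage rate of squeeze-excitation in MBConv
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)  # stem halves to 112, each stage halves again: 56, 28, 14, 7

preds = v(img)                     # (2, 1000)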

0 commit comments