
Commit cb6d749

add a 3d version of cct, addressing #238 (0.38.1)
1 parent 6ec8fda commit cb6d749

File tree: 4 files changed (+505, -95 lines)


README.md (+29)
@@ -1023,6 +1023,35 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
 preds = v(video) # (4, 1000)
 ```
 
+3D version of <a href="https://github.com/lucidrains/vit-pytorch#cct">CCT</a>
+
+```python
+import torch
+from vit_pytorch.cct_3d import CCT
+
+cct = CCT(
+    img_size = 224,
+    num_frames = 8,
+    embedding_dim = 384,
+    n_conv_layers = 2,
+    frame_kernel_size = 3,
+    kernel_size = 7,
+    stride = 2,
+    padding = 3,
+    pooling_kernel_size = 3,
+    pooling_stride = 2,
+    pooling_padding = 1,
+    num_layers = 14,
+    num_heads = 6,
+    mlp_ratio = 3.,
+    num_classes = 1000,
+    positional_embedding = 'learnable'
+)
+
+video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
+pred = cct(video)
+```
+
 ## ViViT
 
 <img src="./images/vivit.png" width="350px"></img>
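Assuming the 3D variant keeps the same classification head as the 2D CCT, the example above should return a logits tensor of shape (batch, num_classes); a quick continuation of that snippet:

```python
# hypothetical follow-up to the README example above: one clip in, 1000 class logits out
assert pred.shape == (1, 1000)
```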

setup.py (+1, -1)
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.37.1',
+  version = '0.38.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/cct.py (+99, -94)
@@ -1,9 +1,17 @@
 import torch
-import torch.nn as nn
+from torch import nn, einsum
 import torch.nn.functional as F
 
+from einops import rearrange, repeat
+
 # helpers
 
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
 def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
          kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
-    stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
-    padding = padding if padding is not None else max(1, (kernel_size // 2))
+    stride = default(stride, max(1, (kernel_size // 2) - 1))
+    padding = default(padding, max(1, (kernel_size // 2)))
+
     return CCT(num_layers=num_layers,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
                padding=padding,
                *args, **kwargs)
 
+# positional
+
+def sinusoidal_embedding(n_channels, dim):
+    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+                            for p in range(n_channels)])
+    pe[:, 0::2] = torch.sin(pe[:, 0::2])
+    pe[:, 1::2] = torch.cos(pe[:, 1::2])
+    return rearrange(pe, '... -> 1 ...')
+
 # modules
 
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
         super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // self.num_heads
+        self.heads = num_heads
+        head_dim = dim // self.heads
         self.scale = head_dim ** -0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
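The sinusoidal embedding is now a module-level helper rather than a staticmethod on the classifier. A minimal usage sketch, assuming the function is importable from vit_pytorch.cct exactly as defined in the hunk above:

```python
import torch
from vit_pytorch.cct import sinusoidal_embedding  # module-level helper added in this commit

pe = sinusoidal_embedding(196, 384)   # 196 positions, 384-dim embedding
print(pe.shape)                       # expected: torch.Size([1, 196, 384]), a fixed (non-learned) table
```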
@@ -77,17 +95,20 @@ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        qkv = self.qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+        q = q * self.scale
+
+        attn = einsum('b h i d, b h j d -> b h i j', q, k)
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
+        x = einsum('b h i j, b h j d -> b h i d', attn, v)
+        x = rearrange(x, 'b h n d -> b n (h d)')
+
+        return self.proj_drop(self.proj(x))
 
 
 class TransformerEncoderLayer(nn.Module):
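The rewritten Attention.forward should be numerically equivalent to the old reshape/permute/matmul formulation; a standalone sketch of that equivalence (not part of the commit, shapes chosen arbitrarily):

```python
import torch
from torch import einsum
from einops import rearrange

b, h, n, d = 2, 6, 64, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
scale = d ** -0.5

# new style: scale q first, contract over the per-head feature dimension with einsum
attn = einsum('b h i d, b h j d -> b h i j', q * scale, k).softmax(dim = -1)
out_new = rearrange(einsum('b h i j, b h j d -> b h i d', attn, v), 'b h n d -> b n (h d)')

# old style: batched matmul with an explicit transpose, scaling the scores afterwards
attn_old = ((q @ k.transpose(-2, -1)) * scale).softmax(dim = -1)
out_old = (attn_old @ v).transpose(1, 2).reshape(b, n, h * d)

assert torch.allclose(out_new, out_old, atol = 1e-5)
```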
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
     """
     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                  attention_dropout=0.1, drop_path_rate=0.1):
-        super(TransformerEncoderLayer, self).__init__()
+        super().__init__()
+
         self.pre_norm = nn.LayerNorm(d_model)
         self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                    attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
         self.linear2 = nn.Linear(dim_feedforward, d_model)
         self.dropout2 = nn.Dropout(dropout)
 
-        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+        self.drop_path = DropPath(drop_path_rate)
 
         self.activation = F.gelu
 
-    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+    def forward(self, src, *args, **kwargs):
         src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
         src = self.norm1(src)
         src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
         src = src + self.drop_path(self.dropout2(src2))
         return src
 
-
-def drop_path(x, drop_prob: float = 0., training: bool = False):
-    """
-    Obtained from: github.com:rwightman/pytorch-image-models
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-    'survival rate' as the argument.
-    """
-    if drop_prob == 0. or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-    random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
-    return output
-
-
 class DropPath(nn.Module):
-    """
-    Obtained from: github.com:rwightman/pytorch-image-models
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-    """
     def __init__(self, drop_prob=None):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
+        super().__init__()
+        self.drop_prob = float(drop_prob)
 
     def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training)
+        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
+
+        if drop_prob <= 0. or not self.training:
+            return x
+
+        keep_prob = 1 - self.drop_prob
+        shape = (batch, *((1,) * (x.ndim - 1)))
 
+        keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
+        output = x.div(keep_prob) * keep_mask.float()
+        return output
 
 class Tokenizer(nn.Module):
     def __init__(self,
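The drop_path helper is now folded into DropPath itself, which implements per-sample stochastic depth: during training each sample in the batch is either zeroed out or rescaled by 1 / keep_prob, and at evaluation time the layer is the identity. A small sketch of that behavior, assuming DropPath is imported from vit_pytorch.cct as defined above:

```python
import torch
from vit_pytorch.cct import DropPath

drop = DropPath(0.5)
x = torch.ones(4, 3, 8)          # (batch, tokens, dim), all ones to make the effect visible

drop.eval()
assert torch.equal(drop(x), x)   # identity at inference time

drop.train()
out = drop(x)                    # per sample: either all zeros or all 2.0 (i.e. 1 / keep_prob)
print(out[:, 0, 0])
```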
@@ -164,34 +170,35 @@ def __init__(self,
                  activation=None,
                  max_pool=True,
                  conv_bias=False):
-        super(Tokenizer, self).__init__()
+        super().__init__()
 
         n_filter_list = [n_input_channels] + \
                         [in_planes for _ in range(n_conv_layers - 1)] + \
                         [n_output_channels]
 
+        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
+
         self.conv_layers = nn.Sequential(
             *[nn.Sequential(
-                nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
+                nn.Conv2d(chan_in, chan_out,
                           kernel_size=(kernel_size, kernel_size),
                           stride=(stride, stride),
                           padding=(padding, padding), bias=conv_bias),
-                nn.Identity() if activation is None else activation(),
+                nn.Identity() if not exists(activation) else activation(),
                 nn.MaxPool2d(kernel_size=pooling_kernel_size,
                              stride=pooling_stride,
                              padding=pooling_padding) if max_pool else nn.Identity()
             )
-                for i in range(n_conv_layers)
+                for chan_in, chan_out in n_filter_list_pairs
             ])
 
-        self.flattener = nn.Flatten(2, 3)
         self.apply(self.init_weight)
 
     def sequence_length(self, n_channels=3, height=224, width=224):
         return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
 
     def forward(self, x):
-        return self.flattener(self.conv_layers(x)).transpose(-2, -1)
+        return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
 
     @staticmethod
     def init_weight(m):
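The einops rearrange in Tokenizer.forward replaces the removed nn.Flatten(2, 3) plus transpose, producing the same (batch, sequence, channels) layout. A standalone sketch of that equivalence (the feature-map shape here is arbitrary):

```python
import torch
from einops import rearrange

feat = torch.randn(2, 384, 14, 14)   # (batch, channels, height, width), e.g. output of the conv stack

old = torch.nn.Flatten(2, 3)(feat).transpose(-2, -1)   # previous flattener + transpose
new = rearrange(feat, 'b c h w -> b (h w) c')          # new einops formulation

assert torch.equal(old, new)   # both are (2, 196, 384)
```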
@@ -214,106 +221,104 @@ def __init__(self,
                  sequence_length=None,
                  *args, **kwargs):
         super().__init__()
-        positional_embedding = positional_embedding if \
-            positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
+        assert positional_embedding in {'sine', 'learnable', 'none'}
+
         dim_feedforward = int(embedding_dim * mlp_ratio)
         self.embedding_dim = embedding_dim
         self.sequence_length = sequence_length
         self.seq_pool = seq_pool
 
-        assert sequence_length is not None or positional_embedding == 'none', \
+        assert exists(sequence_length) or positional_embedding == 'none', \
             f"Positional embedding is set to {positional_embedding} and" \
             f" the sequence length was not specified."
 
         if not seq_pool:
             sequence_length += 1
-            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
-                                          requires_grad=True)
+            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
         else:
             self.attention_pool = nn.Linear(self.embedding_dim, 1)
 
-        if positional_embedding != 'none':
-            if positional_embedding == 'learnable':
-                self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
-                                                   requires_grad=True)
-                nn.init.trunc_normal_(self.positional_emb, std=0.2)
-            else:
-                self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
-                                                   requires_grad=False)
-        else:
+        if positional_embedding == 'none':
             self.positional_emb = None
+        elif positional_embedding == 'learnable':
+            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
+                                               requires_grad=True)
+            nn.init.trunc_normal_(self.positional_emb, std=0.2)
+        else:
+            self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
+                                               requires_grad=False)
 
         self.dropout = nn.Dropout(p=dropout_rate)
+
         dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
+
         self.blocks = nn.ModuleList([
             TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                     dim_feedforward=dim_feedforward, dropout=dropout_rate,
-                                    attention_dropout=attention_dropout, drop_path_rate=dpr[i])
-            for i in range(num_layers)])
+                                    attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
+            for layer_dpr in dpr])
+
         self.norm = nn.LayerNorm(embedding_dim)
 
         self.fc = nn.Linear(embedding_dim, num_classes)
         self.apply(self.init_weight)
 
     def forward(self, x):
-        if self.positional_emb is None and x.size(1) < self.sequence_length:
+        b = x.shape[0]
+
+        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
             x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
 
         if not self.seq_pool:
-            cls_token = self.class_emb.expand(x.shape[0], -1, -1)
+            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
             x = torch.cat((cls_token, x), dim=1)
 
-        if self.positional_emb is not None:
+        if exists(self.positional_emb):
             x += self.positional_emb
 
         x = self.dropout(x)
 
         for blk in self.blocks:
             x = blk(x)
+
         x = self.norm(x)
 
         if self.seq_pool:
-            x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
+            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
+            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
         else:
             x = x[:, 0]
 
-        x = self.fc(x)
-        return x
+        return self.fc(x)
 
     @staticmethod
     def init_weight(m):
         if isinstance(m, nn.Linear):
             nn.init.trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if isinstance(m, nn.Linear) and exists(m.bias):
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
-    @staticmethod
-    def sinusoidal_embedding(n_channels, dim):
-        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
-                                for p in range(n_channels)])
-        pe[:, 0::2] = torch.sin(pe[:, 0::2])
-        pe[:, 1::2] = torch.cos(pe[:, 1::2])
-        return pe.unsqueeze(0)
-
-
 # CCT Main model
+
 class CCT(nn.Module):
-    def __init__(self,
-                 img_size=224,
-                 embedding_dim=768,
-                 n_input_channels=3,
-                 n_conv_layers=1,
-                 kernel_size=7,
-                 stride=2,
-                 padding=3,
-                 pooling_kernel_size=3,
-                 pooling_stride=2,
-                 pooling_padding=1,
-                 *args, **kwargs):
-        super(CCT, self).__init__()
+    def __init__(
+        self,
+        img_size=224,
+        embedding_dim=768,
+        n_input_channels=3,
+        n_conv_layers=1,
+        kernel_size=7,
+        stride=2,
+        padding=3,
+        pooling_kernel_size=3,
+        pooling_stride=2,
+        pooling_padding=1,
+        *args, **kwargs
+    ):
+        super().__init__()
         img_height, img_width = pair(img_size)
 
         self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
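For the seq_pool branch of the classifier's forward pass, the new rearrange plus einsum compute the same attention-weighted pooling as the old matmul/squeeze line: a learned linear layer scores each token, the scores are softmaxed over the sequence, and the tokens are averaged with those weights. A standalone sketch of just that pooling step (dimensions are arbitrary):

```python
import torch
from torch import nn, einsum
from einops import rearrange

x = torch.randn(2, 196, 384)          # (batch, tokens, embedding_dim) coming out of the encoder
attention_pool = nn.Linear(384, 1)    # scores one logit per token

attn_weights = rearrange(attention_pool(x), 'b n 1 -> b n')
pooled = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)

print(pooled.shape)   # torch.Size([2, 384]), one pooled embedding per sample, fed into the final fc
```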
