import torch
- import torch.nn as nn
+ from torch import nn, einsum
import torch.nn.functional as F

+ from einops import rearrange, repeat
+
# helpers

+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

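A quick illustration of the new helpers (an illustrative sketch, not part of the commit): default is what lets _cct below collapse the old "x if x is not None else y" pattern, e.g. when computing the conv stride/padding fallbacks.

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

kernel_size = 7
stride = default(None, max(1, (kernel_size // 2) - 1))   # no override given -> 2
padding = default(None, max(1, kernel_size // 2))        # no override given -> 3
assert (stride, padding) == (2, 3)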
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
-     stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
-     padding = padding if padding is not None else max(1, (kernel_size // 2))
+     stride = default(stride, max(1, (kernel_size // 2) - 1))
+     padding = default(padding, max(1, (kernel_size // 2)))
+
    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
               padding=padding,
               *args, **kwargs)

+ # positional
+
+ def sinusoidal_embedding(n_channels, dim):
+     pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+                             for p in range(n_channels)])
+     pe[:, 0::2] = torch.sin(pe[:, 0::2])
+     pe[:, 1::2] = torch.cos(pe[:, 1::2])
+     return rearrange(pe, '... -> 1 ...')
+
# modules

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
-         self.num_heads = num_heads
-         head_dim = dim // self.num_heads
+         self.heads = num_heads
+         head_dim = dim // self.heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
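A quick sanity check of the sinusoidal_embedding helper added above (an illustrative sketch, not part of the commit): the returned table has shape (1, n_channels, dim), with sine values at even feature indices and cosine values at odd ones.

import torch
from einops import rearrange

def sinusoidal_embedding(n_channels, dim):   # restated from the hunk so the snippet runs standalone
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return rearrange(pe, '... -> 1 ...')

pe = sinusoidal_embedding(n_channels=64, dim=128)
print(pe.shape)        # torch.Size([1, 64, 128])
print(pe[0, 0, ::2])   # position 0: sin(0) = 0 at every even feature index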
@@ -77,17 +95,20 @@ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0

    def forward(self, x):
        B, N, C = x.shape
-         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-         q, k, v = qkv[0], qkv[1], qkv[2]

-         attn = (q @ k.transpose(-2, -1)) * self.scale
+         qkv = self.qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+
+         q = q * self.scale
+
+         attn = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

-         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-         x = self.proj(x)
-         x = self.proj_drop(x)
-         return x
+         x = einsum('b h i j, b h j d -> b h i d', attn, v)
+         x = rearrange(x, 'b h n d -> b n (h d)')
+
+         return self.proj_drop(self.proj(x))


class TransformerEncoderLayer(nn.Module):
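The forward rewrite above trades the reshape/permute/matmul attention for einops plus einsum. As an illustrative check (random projections, dropout ignored), the two formulations agree on the same qkv tensor:

import torch
from torch import einsum
from einops import rearrange

B, N, C, h = 2, 5, 64, 8
x = torch.randn(B, N, 3 * C)      # stands in for self.qkv(x)
scale = (C // h) ** -0.5

# old formulation
qkv = x.reshape(B, N, 3, h, C // h).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)
out_old = (attn @ v).transpose(1, 2).reshape(B, N, C)

# new formulation
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), x.chunk(3, dim=-1))
attn = einsum('b h i d, b h j d -> b h i j', q * scale, k).softmax(dim=-1)
out_new = rearrange(einsum('b h i j, b h j d -> b h i d', attn, v), 'b h n d -> b n (h d)')

assert torch.allclose(out_old, out_new, atol=1e-6)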
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
-         super(TransformerEncoderLayer, self).__init__()
+         super().__init__()
+
        self.pre_norm = nn.LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

-         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+         self.drop_path = DropPath(drop_path_rate)

        self.activation = F.gelu

-     def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+     def forward(self, src, *args, **kwargs):
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        return src

-
- def drop_path(x, drop_prob: float = 0., training: bool = False):
-     """
-     Obtained from: github.com:rwightman/pytorch-image-models
-     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-     'survival rate' as the argument.
-     """
-     if drop_prob == 0. or not training:
-         return x
-     keep_prob = 1 - drop_prob
-     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-     random_tensor.floor_()  # binarize
-     output = x.div(keep_prob) * random_tensor
-     return output
-
-
class DropPath(nn.Module):
-     """
-     Obtained from: github.com:rwightman/pytorch-image-models
-     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-     """
    def __init__(self, drop_prob=None):
-         super(DropPath, self).__init__()
-         self.drop_prob = drop_prob
+         super().__init__()
+         self.drop_prob = float(drop_prob)

    def forward(self, x):
-         return drop_path(x, self.drop_prob, self.training)
+         batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
+
+         if drop_prob <= 0. or not self.training:
+             return x
+
+         keep_prob = 1 - self.drop_prob
+         shape = (batch, *((1,) * (x.ndim - 1)))

+         keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
+         output = x.div(keep_prob) * keep_mask.float()
+         return output

class Tokenizer(nn.Module):
    def __init__(self,
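The DropPath rewrite above folds the standalone drop_path function into the module itself. Behaviour sketch (assumes the DropPath class from this hunk is in scope): in eval mode the layer is the identity, while in train mode it zeroes whole samples with probability drop_prob and rescales the survivors by 1 / keep_prob so the expectation is preserved.

import torch

drop = DropPath(0.25)
x = torch.ones(10000, 4, 8)

drop.eval()
assert torch.equal(drop(x), x)      # identity at inference

drop.train()
y = drop(x)
kept = y[:, 0, 0] != 0
print(kept.float().mean())          # roughly 0.75 of the samples survive
print(y[kept][0, 0, 0])             # survivors scaled to 1 / 0.75 ≈ 1.3333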
@@ -164,34 +170,35 @@ def __init__(self,
                 activation=None,
                 max_pool=True,
                 conv_bias=False):
-         super(Tokenizer, self).__init__()
+         super().__init__()

        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

+         n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
+
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
-                 nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
+                 nn.Conv2d(chan_in, chan_out,
                          kernel_size=(kernel_size, kernel_size),
                          stride=(stride, stride),
                          padding=(padding, padding), bias=conv_bias),
-                 nn.Identity() if activation is None else activation(),
+                 nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool2d(kernel_size=pooling_kernel_size,
                             stride=pooling_stride,
                             padding=pooling_padding) if max_pool else nn.Identity()
            )
-             for i in range(n_conv_layers)
+             for chan_in, chan_out in n_filter_list_pairs
            ])

-         self.flattener = nn.Flatten(2, 3)
        self.apply(self.init_weight)

    def sequence_length(self, n_channels=3, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    def forward(self, x):
-         return self.flattener(self.conv_layers(x)).transpose(-2, -1)
+         return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

    @staticmethod
    def init_weight(m):
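The Tokenizer change above replaces nn.Flatten(2, 3) plus a transpose with a single rearrange. A small equivalence check on a random feature map (illustrative only):

import torch
from einops import rearrange

feat = torch.randn(2, 64, 14, 14)                        # b c h w, as produced by conv_layers
old = torch.nn.Flatten(2, 3)(feat).transpose(-2, -1)     # flatten spatial dims, then move channels last
new = rearrange(feat, 'b c h w -> b (h w) c')
assert torch.equal(old, new)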
@@ -214,106 +221,104 @@ def __init__(self,
                 sequence_length=None,
                 *args, **kwargs):
        super().__init__()
-         positional_embedding = positional_embedding if \
-             positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
+         assert positional_embedding in {'sine', 'learnable', 'none'}
+
        dim_feedforward = int(embedding_dim * mlp_ratio)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.seq_pool = seq_pool

-         assert sequence_length is not None or positional_embedding == 'none', \
+         assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        if not seq_pool:
            sequence_length += 1
-             self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
-                                           requires_grad=True)
+             self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)

-         if positional_embedding != 'none':
-             if positional_embedding == 'learnable':
-                 self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
-                                                    requires_grad=True)
-                 nn.init.trunc_normal_(self.positional_emb, std=0.2)
-             else:
-                 self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
-                                                    requires_grad=False)
-         else:
+         if positional_embedding == 'none':
            self.positional_emb = None
+         elif positional_embedding == 'learnable':
+             self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
+                                                requires_grad=True)
+             nn.init.trunc_normal_(self.positional_emb, std=0.2)
+         else:
+             self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
+                                                requires_grad=False)

        self.dropout = nn.Dropout(p=dropout_rate)
+
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
+
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
-                                     attention_dropout=attention_dropout, drop_path_rate=dpr[i])
-             for i in range(num_layers)])
+                                     attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
+             for layer_dpr in dpr])
+
        self.norm = nn.LayerNorm(embedding_dim)

        self.fc = nn.Linear(embedding_dim, num_classes)
        self.apply(self.init_weight)

    def forward(self, x):
-         if self.positional_emb is None and x.size(1) < self.sequence_length:
+         b = x.shape[0]
+
+         if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)

        if not self.seq_pool:
-             cls_token = self.class_emb.expand(x.shape[0], -1, -1)
+             cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b=b)
            x = torch.cat((cls_token, x), dim=1)

-         if self.positional_emb is not None:
+         if exists(self.positional_emb):
            x += self.positional_emb

        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)
+
        x = self.norm(x)

        if self.seq_pool:
-             x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
+             attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
+             x = einsum('b n, b n d -> b d', attn_weights.softmax(dim=1), x)
        else:
            x = x[:, 0]

-         x = self.fc(x)
-         return x
+         return self.fc(x)

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
-             if isinstance(m, nn.Linear) and m.bias is not None:
+             if isinstance(m, nn.Linear) and exists(m.bias):
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

-     @staticmethod
-     def sinusoidal_embedding(n_channels, dim):
-         pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
-                                 for p in range(n_channels)])
-         pe[:, 0::2] = torch.sin(pe[:, 0::2])
-         pe[:, 1::2] = torch.cos(pe[:, 1::2])
-         return pe.unsqueeze(0)
-
-
# CCT Main model
+
class CCT(nn.Module):
-     def __init__(self,
-                  img_size=224,
-                  embedding_dim=768,
-                  n_input_channels=3,
-                  n_conv_layers=1,
-                  kernel_size=7,
-                  stride=2,
-                  padding=3,
-                  pooling_kernel_size=3,
-                  pooling_stride=2,
-                  pooling_padding=1,
-                  *args, **kwargs):
-         super(CCT, self).__init__()
+     def __init__(
+         self,
+         img_size=224,
+         embedding_dim=768,
+         n_input_channels=3,
+         n_conv_layers=1,
+         kernel_size=7,
+         stride=2,
+         padding=3,
+         pooling_kernel_size=3,
+         pooling_stride=2,
+         pooling_padding=1,
+         *args, **kwargs
+     ):
+         super().__init__()
        img_height, img_width = pair(img_size)

        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
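For the sequence-pooling change in the hunk above (matmul/softmax/squeeze replaced by rearrange plus einsum), a sketch of the equivalence on random activations; attention_pool stands in for self.attention_pool:

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

b, n, d = 2, 50, 128
x = torch.randn(b, n, d)
attention_pool = nn.Linear(d, 1)

# old: (b, 1, n) @ (b, n, d) -> (b, 1, d) -> (b, d)
old = torch.matmul(F.softmax(attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)

# new: softmax the per-token scores over the sequence, then take a weighted sum
weights = rearrange(attention_pool(x), 'b n 1 -> b n')
new = einsum('b n, b n d -> b d', weights.softmax(dim=1), x)

assert torch.allclose(old, new, atol=1e-6)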