import torch.nn as nn

from einops import rearrange
+from einops.layers.torch import Reduce

def _make_divisible(v, divisor, min_value=None):
-
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
    return new_v


-def Conv_BN_ReLU(inp, oup, kernel, stride=1):
+def conv_bn_relu(inp, oup, kernel, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
-        project_out = not (heads == 1 and dim_head == dim)
-
        self.heads = heads
        self.scale = dim_head ** -0.5

@@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
-        ) if project_out else nn.Identity()
+        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))
+
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
@@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
        )

    def forward(self, x):
+        out = self.conv(x)
+
        if self.identity:
-            return x + self.conv(x)
-        else:
-            return self.conv(x)
+            out = out + x
+        return out

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

-        self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
+        self.conv1 = conv_bn_relu(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
-        self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
+        self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()
@@ -165,8 +165,7 @@ def forward(self, x):
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
-                      pw=self.pw)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
@@ -176,54 +175,65 @@ def forward(self, x):


class MobileViT(nn.Module):
-    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+    def __init__(
+        self,
+        image_size,
+        dims,
+        channels,
+        num_classes,
+        expansion=4,
+        kernel_size=3,
+        patch_size=(2, 2),
+        depths=(2, 4, 3)
+    ):
        super().__init__()
+        assert len(dims) == 3, 'dims must be a tuple of 3'
+        assert len(depths) == 3, 'depths must be a tuple of 3'
+
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

-        L = [2, 4, 3]
-
-        self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
-
-        self.mv2 = nn.ModuleList([])
-        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
-        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
-        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
-        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
-
-        self.mvit = nn.ModuleList([])
-        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
-        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
-        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
-
-        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
-
-        self.pool = nn.AvgPool2d(ih // 32, 1)
-        self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+        init_dim, *_, last_dim = channels
+
+        self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
+
+        self.stem = nn.ModuleList([])
+        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
+        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+
+        self.trunk = nn.ModuleList([])
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[3], channels[4], 2, expansion),
+            MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[5], channels[6], 2, expansion),
+            MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[7], channels[8], 2, expansion),
+            MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
+        ]))
+
+        self.to_logits = nn.Sequential(
+            conv_1x1_bn(channels[-2], last_dim),
+            Reduce('b c h w -> b c', 'mean'),
+            nn.Linear(channels[-1], num_classes, bias=False)
+        )

    def forward(self, x):
        x = self.conv1(x)
-        x = self.mv2[0](x)
-
-        x = self.mv2[1](x)
-        x = self.mv2[2](x)
-        x = self.mv2[3](x)

-        x = self.mv2[4](x)
-        x = self.mvit[0](x)
+        for conv in self.stem:
+            x = conv(x)

-        x = self.mv2[5](x)
-        x = self.mvit[1](x)
-
-        x = self.mv2[6](x)
-        x = self.mvit[2](x)
-        x = self.conv2(x)
-
-        x = self.pool(x).view(-1, x.shape[1])
-        x = self.fc(x)
-        return x
+        for conv, attn in self.trunk:
+            x = conv(x)
+            x = attn(x)

+        return self.to_logits(x)
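For reference, a minimal usage sketch of the refactored model, assuming the class definition above is in scope. The dims/channels/expansion values below loosely follow the MobileViT-XXS configuration and are illustrative assumptions only, not part of this commit.

# Usage sketch (not part of the commit): instantiate the refactored MobileViT
# with the new depths argument and run a dummy forward pass.
import torch

model = MobileViT(
    image_size=(256, 256),
    dims=[64, 80, 96],                                      # illustrative, XXS-like transformer dims
    channels=[16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320], # illustrative, XXS-like channel plan
    num_classes=1000,
    expansion=2,                                            # assumed smaller expansion for the XXS variant
    depths=(2, 4, 3)
)

img = torch.randn(1, 3, 256, 256)   # (batch, rgb, height, width)
logits = model(img)                 # to_logits head pools spatially, so shape is (1, 1000)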