Skip to content

Commit b983bbe

Browse files
committed
release MobileViT, from @murufeng
1 parent 86a7302 commit b983bbe

File tree

3 files changed

+77
-56
lines changed

3 files changed

+77
-56
lines changed

README.md

+14-3
Original file line numberDiff line numberDiff line change
@@ -554,17 +554,17 @@ pred = nest(img) # (1, 1000)
554554

555555
<img src="./images/mbvit.png" width="400px"></img>
556556

557-
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
557+
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different
558558
perspective for the global processing of information with transformers.
559559

560560
You can use it with the following code (ex. mobilevit_xs)
561561

562-
```
562+
```python
563563
import torch
564564
from vit_pytorch.mobile_vit import MobileViT
565565

566566
mbvit_xs = MobileViT(
567-
image_size=(256, 256),
567+
image_size = (256, 256),
568568
dims = [96, 120, 144],
569569
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
570570
num_classes = 1000
@@ -1190,6 +1190,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
11901190
}
11911191
```
11921192

1193+
```bibtex
1194+
@misc{mehta2021mobilevit,
1195+
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
1196+
author = {Sachin Mehta and Mohammad Rastegari},
1197+
year = {2021},
1198+
eprint = {2110.02178},
1199+
archivePrefix = {arXiv},
1200+
primaryClass = {cs.CV}
1201+
}
1202+
```
1203+
11931204
```bibtex
11941205
@misc{vaswani2017attention,
11951206
title = {Attention Is All You Need},

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.24.3',
6+
version = '0.25.0',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',

vit_pytorch/mobile_vit.py

+62-52
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import torch.nn as nn
1010

1111
from einops import rearrange
12+
from einops.layers.torch import Reduce
1213

1314
def _make_divisible(v, divisor, min_value=None):
14-
1515
if min_value is None:
1616
min_value = divisor
1717
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
2020
return new_v
2121

2222

23-
def Conv_BN_ReLU(inp, oup, kernel, stride=1):
23+
def conv_bn_relu(inp, oup, kernel, stride=1):
2424
return nn.Sequential(
2525
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
2626
nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
6363
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
6464
super().__init__()
6565
inner_dim = dim_head * heads
66-
project_out = not (heads == 1 and dim_head == dim)
67-
6866
self.heads = heads
6967
self.scale = dim_head ** -0.5
7068

@@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
7472
self.to_out = nn.Sequential(
7573
nn.Linear(inner_dim, dim),
7674
nn.Dropout(dropout)
77-
) if project_out else nn.Identity()
75+
)
7876

7977
def forward(self, x):
8078
qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
9694
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
9795
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
9896
]))
97+
9998
def forward(self, x):
10099
for attn, ff in self.layers:
101100
x = attn(x) + x
@@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
136135
)
137136

138137
def forward(self, x):
138+
out = self.conv(x)
139+
139140
if self.identity:
140-
return x + self.conv(x)
141-
else:
142-
return self.conv(x)
141+
out = out + x
142+
return out
143143

144144
class MobileViTBlock(nn.Module):
145145
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
146146
super().__init__()
147147
self.ph, self.pw = patch_size
148148

149-
self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
149+
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
150150
self.conv2 = conv_1x1_bn(channel, dim)
151151

152152
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
153153

154154
self.conv3 = conv_1x1_bn(dim, channel)
155-
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
155+
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
156156

157157
def forward(self, x):
158158
y = x.clone()
@@ -165,8 +165,7 @@ def forward(self, x):
165165
_, _, h, w = x.shape
166166
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
167167
x = self.transformer(x)
168-
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
169-
pw=self.pw)
168+
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
170169

171170
# Fusion
172171
x = self.conv3(x)
@@ -176,54 +175,65 @@ def forward(self, x):
176175

177176

178177
class MobileViT(nn.Module):
179-
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
178+
def __init__(
179+
self,
180+
image_size,
181+
dims,
182+
channels,
183+
num_classes,
184+
expansion = 4,
185+
kernel_size = 3,
186+
patch_size = (2, 2),
187+
depths = (2, 4, 3)
188+
):
180189
super().__init__()
190+
assert len(dims) == 3, 'dims must be a tuple of 3'
191+
assert len(depths) == 3, 'depths must be a tuple of 3'
192+
181193
ih, iw = image_size
182194
ph, pw = patch_size
183195
assert ih % ph == 0 and iw % pw == 0
184196

185-
L = [2, 4, 3]
186-
187-
self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
188-
189-
self.mv2 = nn.ModuleList([])
190-
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
191-
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
192-
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
193-
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
194-
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
195-
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
196-
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
197-
198-
self.mvit = nn.ModuleList([])
199-
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
200-
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
201-
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
202-
203-
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
204-
205-
self.pool = nn.AvgPool2d(ih // 32, 1)
206-
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
197+
init_dim, *_, last_dim = channels
198+
199+
self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
200+
201+
self.stem = nn.ModuleList([])
202+
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
203+
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
204+
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
205+
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
206+
207+
self.trunk = nn.ModuleList([])
208+
self.trunk.append(nn.ModuleList([
209+
MV2Block(channels[3], channels[4], 2, expansion),
210+
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
211+
]))
212+
213+
self.trunk.append(nn.ModuleList([
214+
MV2Block(channels[5], channels[6], 2, expansion),
215+
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
216+
]))
217+
218+
self.trunk.append(nn.ModuleList([
219+
MV2Block(channels[7], channels[8], 2, expansion),
220+
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
221+
]))
222+
223+
self.to_logits = nn.Sequential(
224+
conv_1x1_bn(channels[-2], last_dim),
225+
Reduce('b c h w -> b c', 'mean'),
226+
nn.Linear(channels[-1], num_classes, bias=False)
227+
)
207228

208229
def forward(self, x):
209230
x = self.conv1(x)
210-
x = self.mv2[0](x)
211-
212-
x = self.mv2[1](x)
213-
x = self.mv2[2](x)
214-
x = self.mv2[3](x)
215231

216-
x = self.mv2[4](x)
217-
x = self.mvit[0](x)
232+
for conv in self.stem:
233+
x = conv(x)
218234

219-
x = self.mv2[5](x)
220-
x = self.mvit[1](x)
221-
222-
x = self.mv2[6](x)
223-
x = self.mvit[2](x)
224-
x = self.conv2(x)
225-
226-
x = self.pool(x).view(-1, x.shape[1])
227-
x = self.fc(x)
228-
return x
235+
for conv, attn in self.trunk:
236+
x = conv(x)
237+
x = attn(x)
229238

239+
return self.to_logits(x)

0 commit comments

Comments
 (0)