Commit 5ae5557: add SimMIM

1 parent c5a4616
File tree

4 files changed: +137 -1 lines changed

README.md (+52 lines)
@@ -19,6 +19,7 @@
- [RegionViT](#regionvit)
- [NesT](#nest)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
@@ -519,6 +520,46 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>

This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection of the masked tokens into pixel space, followed by an L1 loss against the pixel values of the masked patches. The results are competitive with other, more complicated approaches.

You can use this as follows

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5 # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop, many times, with a lot of images, and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')
```
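
A minimal sketch of loading the pretrained weights back for downstream use (an illustration, not part of this commit's README; the fine-tuning setup itself is not shown):

```python
import torch
from vit_pytorch import ViT

# rebuild the ViT with the same hyperparameters used during pretraining
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

v.load_state_dict(torch.load('./trained-vit.pt'))

img = torch.randn(1, 3, 256, 256)
pred = v(img) # (1, 1000)
```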
## Masked Autoencoder

<img src="./images/mae.png" width="400px"/>
@@ -1026,6 +1067,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{xie2021simmim,
    title = {SimMIM: A Simple Framework for Masked Image Modeling},
    author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
    year = {2021},
    eprint = {2111.09886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
    title = {Attention Is All You Need},

images/simmim.png (365 KB)

setup.py (+1, -1)
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.22.0',
+  version = '0.23.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/simmim.py (+84 lines)
@@ -0,0 +1,84 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # simple linear head

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate number of patches needed to be masked, and get positions (indices) to be masked

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss
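
Two indexing tricks in `forward` above are worth spelling out: per-patch random noise plus `topk` selects exactly `num_masked` patch indices per sample, `scatter_` turns those indices into a boolean mask, and a broadcasted `batch_range` gathers the masked positions back out. A standalone illustration with toy shapes (not part of the library):

```python
import torch

# toy sizes for illustration only
batch, num_patches, dim = 2, 8, 4
masking_ratio = 0.5
num_masked = int(masking_ratio * num_patches)

# random noise -> topk picks exactly num_masked random patch indices per sample
noise = torch.rand(batch, num_patches)
masked_indices = noise.topk(k = num_masked, dim = -1).indices          # (batch, num_masked)

# scatter the indices into a boolean mask over all patches
masked_bool_mask = torch.zeros(batch, num_patches).scatter_(-1, masked_indices, 1).bool()
print(masked_bool_mask.sum(dim = -1))                                  # tensor([4, 4]) - exactly num_masked per sample

# broadcasted batch indices gather the masked positions back out
tokens = torch.randn(batch, num_patches, dim)
batch_range = torch.arange(batch)[:, None]                             # (batch, 1)
masked_tokens = tokens[batch_range, masked_indices]                    # (batch, num_masked, dim)
print(masked_tokens.shape)                                             # torch.Size([2, 4, 4])
```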
