diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py index 5b34a44..1cd3423 100644 --- a/vit_pytorch/vit.py +++ b/vit_pytorch/vit.py @@ -65,7 +65,7 @@ def forward(self, x): class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): super().__init__() - self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ @@ -77,8 +77,7 @@ def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x - - return self.norm(x) + return x class ViT(nn.Module): def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): @@ -90,7 +89,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml num_patches = (image_height // patch_height) * (image_width // patch_width) patch_dim = channels * patch_height * patch_width - assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling).' self.to_patch_embedding = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), @@ -106,9 +105,11 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) self.pool = pool - self.to_latent = nn.Identity() - self.mlp_head = nn.Linear(dim, num_classes) + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) def forward(self, img): x = self.to_patch_embedding(img) @@ -123,5 +124,4 @@ def forward(self, img): x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] - x = self.to_latent(x) return self.mlp_head(x)