Why do 2d tokens become 1d tokens and then become 2d tokens? #12

BeautySilly opened this issue Jul 27, 2022 · 1 comment

BeautySilly commented Jul 27, 2022

Hi, I have a question about the van.py file of VAN-Segmentation: why do the 2d tokens become 1d tokens and then become 2d tokens again, at line 105, line 110, and line 223? The details are as follows:
In the Block class:

def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)          # <--------- this: (B, N, C) -> (B, C, H, W)
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                               * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                               * self.mlp(self.norm2(x)))
        x = x.view(B, C, N).permute(0, 2, 1)              # <--------- this: (B, C, H, W) -> (B, N, C)
        return x
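
To make the round trip concrete, here is a minimal, self-contained sketch of the two reshapes (the sizes B=2, C=64, H=7, W=7 are made up purely for illustration):

import torch

B, C, H, W = 2, 64, 7, 7   # hypothetical sizes, for illustration only
N = H * W

tokens_1d = torch.randn(B, N, C)                          # 1d token layout: (B, N, C)
tokens_2d = tokens_1d.permute(0, 2, 1).view(B, C, H, W)   # -> 2d layout: (B, C, H, W)
back_to_1d = tokens_2d.view(B, C, N).permute(0, 2, 1)     # -> back to 1d layout: (B, N, C)

# The round trip only changes the layout; the values themselves are unchanged.
assert torch.equal(tokens_1d, back_to_1d)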

In the VAN class:

def forward(self, x):
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()          # <--------- this: (B, N, C) -> (B, C, H, W)
            outs.append(x)

        return outs

My teacher thinks this has a very profound meaning. I think the author simply set it up this way to stay consistent with the traditional ViT and that there is no special meaning, but my explanation can't convince my teacher, so I am asking for the author's help.
Looking forward to your reply!
Best!

@XuRuihan

Because the LayerNorm is performed on the last dimension (which needs 1d tokens), while the Conv is performed on the last two dimensions (which needs 2d tokens).
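
A minimal sketch of that point, with made-up sizes and plain torch modules (not the exact modules from van.py):

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 7, 7   # hypothetical sizes, for illustration only
N = H * W

# nn.LayerNorm normalizes over the last dimension, so the channel dim must be
# last: it expects the 1d token layout (B, N, C).
layer_norm = nn.LayerNorm(C)
tokens_1d = torch.randn(B, N, C)
print(layer_norm(tokens_1d).shape)   # torch.Size([2, 49, 64])

# nn.Conv2d convolves over the last two (spatial) dimensions, so it expects
# the 2d token layout (B, C, H, W).
conv = nn.Conv2d(C, C, kernel_size=3, padding=1)
tokens_2d = tokens_1d.permute(0, 2, 1).view(B, C, H, W)
print(conv(tokens_2d).shape)         # torch.Size([2, 64, 7, 7])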
