|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +from typing import Dict, List, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from mmengine.structures import BaseDataElement |
| 6 | + |
| 7 | +from mmselfsup.registry import MODELS |
| 8 | +from mmselfsup.structures import SelfSupDataSample |
| 9 | +from .base import BaseModel |
| 10 | + |
| 11 | +@MODELS.register_module() |
| 12 | +class GreenMIM(BaseModel): |
| 13 | + """GreenMIM. |
| 14 | +
|
| 15 | + Implementation of `GreenMIM: Green Hierarchical Vision Transformer for Masked Image Modeling |
| 16 | + <https://arxiv.org/abs/2205.13515>`_. |
| 17 | + """ |
| 18 | + |
| 19 | + def extract_feat(self, |
| 20 | + inputs: List[torch.Tensor], |
| 21 | + data_samples: Optional[List[SelfSupDataSample]] = None, |
| 22 | + **kwarg) -> Tuple[torch.Tensor]: |
| 23 | + """The forward function to extract features from neck. |
| 24 | +
|
| 25 | + Args: |
| 26 | + inputs (List[torch.Tensor]): The input images. |
| 27 | +
|
| 28 | + Returns: |
| 29 | + Tuple[torch.Tensor]: Neck outputs. |
| 30 | + """ |
| 31 | + latent, mask, ids_restore = self.backbone(inputs[0]) |
| 32 | + pred = self.neck(latent, ids_restore) |
| 33 | + self.mask = mask |
| 34 | + return pred |
| 35 | + |
| 36 | + def reconstruct(self, |
| 37 | + features: torch.Tensor, |
| 38 | + data_samples: Optional[List[SelfSupDataSample]] = None, |
| 39 | + **kwargs) -> SelfSupDataSample: |
| 40 | + """The function is for image reconstruction. |
| 41 | +
|
| 42 | + Args: |
| 43 | + features (torch.Tensor): The input images. |
| 44 | + data_samples (List[SelfSupDataSample]): All elements required |
| 45 | + during the forward function. |
| 46 | +
|
| 47 | + Returns: |
| 48 | + SelfSupDataSample: The prediction from model. |
| 49 | + """ |
| 50 | + mean = kwargs['mean'] |
| 51 | + std = kwargs['std'] |
| 52 | + features = features * std + mean |
| 53 | + |
| 54 | + pred = self.head.unpatchify(features) |
| 55 | + pred = torch.einsum('nchw->nhwc', pred).detach().cpu() |
| 56 | + |
| 57 | + mask = self.mask.detach() |
| 58 | + mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 * |
| 59 | + 3) # (N, H*W, p*p*3) |
| 60 | + mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping |
| 61 | + mask = torch.einsum('nchw->nhwc', mask).detach().cpu() |
| 62 | + |
| 63 | + results = SelfSupDataSample() |
| 64 | + results.mask = BaseDataElement(**dict(value=mask)) |
| 65 | + results.pred = BaseDataElement(**dict(value=pred)) |
| 66 | + |
| 67 | + return results |
| 68 | + |
| 69 | + def patchify(self, imgs, patch_size): |
| 70 | + """ |
| 71 | + imgs: (N, 3, H, W) |
| 72 | + x: (N, L, patch_size**2 *3) |
| 73 | + """ |
| 74 | + p = patch_size |
| 75 | + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 |
| 76 | + |
| 77 | + h = w = imgs.shape[2] // p |
| 78 | + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) |
| 79 | + x = torch.einsum('nchpwq->nhwpqc', x) |
| 80 | + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) |
| 81 | + return x |
| 82 | + |
| 83 | + def loss(self, inputs: List[torch.Tensor], |
| 84 | + data_samples: List[SelfSupDataSample], |
| 85 | + **kwargs) -> Dict[str, torch.Tensor]: |
| 86 | + """The forward function in training. |
| 87 | +
|
| 88 | + Args: |
| 89 | + inputs (List[torch.Tensor]): The input images. |
| 90 | + data_samples (List[SelfSupDataSample]): All elements required |
| 91 | + during the forward function. |
| 92 | +
|
| 93 | + Returns: |
| 94 | + Dict[str, torch.Tensor]: A dictionary of loss components. |
| 95 | + """ |
| 96 | + # ids_restore: the same as that in original repo, which is used |
| 97 | + # to recover the original order of tokens in decoder. |
| 98 | + latent, mask, ids_restore = self.backbone(inputs[0]) |
| 99 | + pred = self.neck(latent, ids_restore) |
| 100 | + target = self.patchify(inputs[0], self.backbone.final_patch_size) |
| 101 | + loss = self.head(pred, target, mask) |
| 102 | + losses = dict(loss=loss) |
| 103 | + return losses |
0 commit comments