From dae16db719ea1bc77575c4eeff555699571ada28 Mon Sep 17 00:00:00 2001 From: Pavlo Molchanov Date: Fri, 3 May 2024 16:09:49 -0700 Subject: [PATCH] Update eradio_model.py updated radio file, cleaned a bit --- radio/eradio_model.py | 653 +++++++++--------------------------------- 1 file changed, 140 insertions(+), 513 deletions(-) diff --git a/radio/eradio_model.py b/radio/eradio_model.py index a3ca6b9..c19dc6f 100644 --- a/radio/eradio_model.py +++ b/radio/eradio_model.py @@ -13,10 +13,9 @@ # based on FasterViT, Swin Transformer, YOLOv8 -# FasterViT: +# FasterViT original model: # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023). -import timm import torch import torch.nn as nn from timm.models.registry import register_model @@ -367,7 +366,7 @@ def __init__(self, window_size, dim_in, dim_out, use_swiglu=True, subsample_ratio=1, dim_ratio=1, conv_base=False, do_windowing=True, multi_query=False, use_shift=0, - cpb_mlp_hidden=512, conv_groups_ratio=0): + cpb_mlp_hidden=512, conv_groups_ratio=0, pretrained_window_size=16): ''' Global Resolution Attention Block , see README for details Attention with subsampling to get a bigger receptive field for attention @@ -404,8 +403,6 @@ def __init__(self, window_size, dim_in, dim_out, if subsample_ratio == 1: # conv_groups_ratio=0 self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False) - # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False) - # self.pre_conv_act = nn.ReLU6() #for simplicity: self.pre_conv_act = nn.Identity() if conv_groups_ratio == -1: @@ -421,7 +418,8 @@ def __init__(self, window_size, dim_in, dim_out, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, resolution=window_size, seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query, - shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden) + shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden, + pretrained_window_size=pretrained_window_size) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -533,7 +531,8 @@ def __init__(self, window_size, sr_ratio, layer_scale=1e-5, norm_layer=nn.LayerNorm, drop_path = 0, qkv_bias=False, qk_scale=1.0, use_swiglu=True, multi_query=False, conv_base=False, - use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None: + use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0, + pretrained_window_size=16) -> None: """ Args: input_resolution: input image resolution @@ -562,7 +561,8 @@ def __init__(self, window_size, sr_ratio, layer_scale=layer_scale, drop_path=drop_path, use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio, do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base, - use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio), + use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio, + pretrained_window_size=pretrained_window_size), ) def forward(self, x): @@ -731,7 +731,7 @@ class WindowAttention(nn.Module): # use a MLP trick to deal with various input image resolutions, then fold it to improve speed def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0, - seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512): + seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512, pretrained_window_size=16): # taken from EdgeViT and tweaked with attention bias. super().__init__() if not dim_out: dim_out = dim @@ -752,7 +752,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0 self.proj = nn.Linear(dim, dim_out, bias=False) # attention positional bias self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution], - pretrained_window_size=[resolution, resolution], + pretrained_window_size=[pretrained_window_size, pretrained_window_size], num_heads=num_heads, seq_length=seq_length, cpb_mlp_hidden=cpb_mlp_hidden) @@ -775,7 +775,8 @@ def forward(self, x, attn_mask = None): attn = (q @ k.transpose(-2, -1)) * self.scale - attn = self.pos_emb_funct(attn) + if 1: + attn = self.pos_emb_funct(attn) #add window shift if attn_mask is not None: @@ -790,9 +791,9 @@ def forward(self, x, attn_mask = None): -class FasterViTLayer(nn.Module): +class ERADIOLayer(nn.Module): """ - fastervitlayer + ERADIOlayer """ def __init__(self, @@ -820,6 +821,7 @@ def __init__(self, cpb_mlp_hidden=512, conv_groups_ratio=0, verbose: bool = True, + pretrained_window_size=16, ): """ @@ -840,6 +842,7 @@ def __init__(self, layer_scale: layer scaling coefficient. use_shift: SWIN like window shifting for half the window size for every alternating layer (considering multi-resolution) conv_groups_ratio: group ratio for conv when no subsampling in multi-res attention + pretrained_window_size: window size used for model training, used for positional embedding """ super().__init__() @@ -895,6 +898,7 @@ def __init__(self, cpb_mlp_hidden=cpb_mlp_hidden, use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True , conv_groups_ratio=conv_groups_ratio, + pretrained_window_size=pretrained_window_size, )) self.blocks = nn.Sequential(*self.blocks) @@ -960,7 +964,7 @@ def forward(self, x): class HiResNeck(nn.Module): """ The block is used to output dense features from all stages - Otherwise, by default, only the last stage features are returned with FasterViTv2 + Otherwise, by default, only the last stage features are returned as in FasterViT """ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled): @@ -984,12 +988,12 @@ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsa if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output: feature_projection = nn.Sequential() if False: - feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse + # fast but not as good + feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output, full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio)) else: # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio - # print("upsample ratio", upsample_ratio, level_n_features_output, level_n_features_output) feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest')) feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output)) feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) @@ -1017,9 +1021,9 @@ def forward(self, x, il_level=-1, full_features=None): full_features = full_features + feature_projection return full_features -class FasterViT(nn.Module): +class ERADIO(nn.Module): """ - FasterViT + ERADIO """ def __init__(self, @@ -1055,6 +1059,7 @@ def __init__(self, cpb_mlp_hidden=512, conv_groups_ratio=0, verbose: bool = False, + pretrained_window_size = 16, **kwargs): """ Args: @@ -1083,6 +1088,7 @@ def __init__(self, if 0 then normal conv, if 1 then channels are independent, if -1 then no conv at all + pretrained_window_size: window size used during pretraining of the model, used for positional embedding """ super().__init__() @@ -1090,7 +1096,7 @@ def __init__(self, num_features = int(dim * 2 ** (len(depths) - 1)) self.num_classes = num_classes self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down) - # set return_full_features true if we want to return full features from all stages + # set return_full_features true if we want to return features from multiple resolution stages self.return_full_features = return_full_features self.use_neck = use_neck @@ -1104,7 +1110,7 @@ def __init__(self, for i in range(len(depths)): conv = True if (i == 0 or i == 1) else False - level = FasterViTLayer(dim=int(dim * 2 ** i), + level = ERADIOLayer(dim=int(dim * 2 ** i), depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], @@ -1126,7 +1132,9 @@ def __init__(self, cpb_mlp_hidden=cpb_mlp_hidden, use_shift=use_shift, conv_groups_ratio=conv_groups_ratio, - verbose=verbose) + verbose=verbose, + pretrained_window_size=pretrained_window_size, + ) self.levels.append(level) @@ -1141,6 +1149,8 @@ def __init__(self, self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) + + self.unsafe_mode = False # naively eradio supports input resolution that is divisible by 32. If you want to use other resolutions, set this to True - not recommended as accuracy will be affected. def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -1162,9 +1172,6 @@ def no_weight_decay_keywords(self): return {'rpb'} def forward_features(self, x): - _, _, H, W = x.shape - if H % 32 != 0 or W % 32 != 0: - raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}") x = self.patch_embed(x) full_features = None for il, level in enumerate(self.levels): @@ -1182,6 +1189,11 @@ def forward_features(self, x): return x, full_features def forward(self, x): + # do resolution check + if not self.unsafe_mode: + if x.shape[-1] % 32 != 0 or x.shape[-2] % 32 != 0: + raise ValueError("Input resolution must be divisible by 32. Model works for all inputs, but quality will degradate if /32 is not met. If you want to use other resolutions, and sure about it, set model.unsafe_mode = True.") + x, full_features = self.forward_features(x) x = self.avgpool(x) @@ -1208,9 +1220,9 @@ def switch_to_deploy(self): def change_window_size(self, new_window_size): """ - FasterViT employs windowed attention, which may be sensitive to the choice of this parameter, + ERADIO employs windowed attention, which may be sensitive to the choice of this parameter, especially in cases of uneven partitioning of the feature maps. - FasterViT allows for the adjustment of the window size after training, + ERADIO allows for the adjustment of the window size after training, making it adaptable to different input image resolutions. The recommended values for window size based on input resolution are as follows: @@ -1241,11 +1253,10 @@ def change_window_size(self, new_window_size): def set_optimal_window_size(self, image_dim, max_window_size = 16): """ - Using hand picked window size for various resolutions. - - FasterViT employs windowed attention, which may be sensitive to the choice of this parameter, + Picks the best window size for windowed attention based on the input image resolution. + ERADIO employs windowed attention, which may be sensitive to the choice of this parameter, especially in cases of uneven partitioning of the feature maps. - FasterViT allows for the adjustment of the window size after training, + ERADIO allows for the adjustment of the window size after training, making it adaptable to different input image resolutions. The recommended values for window size based on input resolution are as follows: @@ -1281,278 +1292,11 @@ def divisorGenerator(n): all_divisors = np.array(list(divisorGenerator(image_dim//32))) new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size)) - # for image_dim in [128, 224, 256, 384, 512, 768, 1024]: - # all_divisors = np.array(list(divisorGenerator(image_dim//32))) - # new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size)) - # print(f"Setting window size to {new_window_size} for image resolution {image_dim}") - self.change_window_size(new_window_size = new_window_size) -# 83.44200001953125 -@register_model -def fastervit2_small(pretrained=False, **kwargs): #, - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=96, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [1, 2], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -# 82.61 -@register_model -def fastervit2_tiny(pretrained=False, **kwargs): #, - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=80, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -#'top1', 84.31800001220704 -@register_model -def fastervit2_base(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -#84.39999999267579 -@register_model -def fastervit2_base_v1(pretrained=False, **kwargs): - model = FasterViT(depths=[4, 4, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - downsample_shuffle=False, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_base_fullres1(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=1024, - neck_start_stage=2, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_base_fullres2(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=512, - neck_start_stage=1, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_base_fullres3(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=256, - neck_start_stage=1, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_base_fullres4(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=256, - neck_start_stage=2, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_base_fullres5(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=512, - neck_start_stage=2, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -#84.87 -@register_model -def fastervit2_large(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128+64, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.3, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=False, - shuffle_down=False, - cpb_mlp_hidden=64, - conv_base=True, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - -@register_model -def fastervit2_large_fullres(pretrained=False, **kwargs): - model = FasterViT( - depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [7, 7], 7], - dim=192, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=1536, - neck_start_stage=2, - **kwargs, - ) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - - -@register_model -def fastervit2_large_fullres_ws8(pretrained=False, **kwargs): - model = FasterViT( - depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [8, 8], 8], - dim=192, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=1536, - neck_start_stage=2, - **kwargs, - ) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model - - @register_model -def fastervit2_large_fullres_ws16(pretrained=False, **kwargs): - model = FasterViT( +def eradio_large_fullres_ws16(pretrained=False, **kwargs): + model = ERADIO( depths=[3, 3, 5, 5], num_heads=[2, 4, 8, 16], window_size=[None, None, [16, 16], 16], @@ -1568,238 +1312,121 @@ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs): use_neck=True, full_features_head_dim=1536, neck_start_stage=2, + conv_groups_ratio=1, + pretrained_window_size=16, **kwargs, ) if pretrained: model.load_state_dict(torch.load(pretrained)["state_dict"]) return model - @register_model -def fastervit2_large_fullres_ws32(pretrained=False, **kwargs): - model = FasterViT( - depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [32, 32], 32], - dim=192, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=1536, - neck_start_stage=2, - **kwargs, - ) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model +def eradio(pretrained=False, **kwargs): + return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs) -#85.23% top1 -@register_model -def fastervit2_xlarge(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128+128+64, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.4, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=False, - shuffle_down=False, - cpb_mlp_hidden=64, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model -@register_model -def fastervit2_huge(pretrained=False, **kwargs): - model = FasterViT(depths=[3, 3, 5, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=128+128+128+64, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.2, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model +if __name__ == "__main__": + from ptflops import get_model_complexity_info + import argparse -# 81.61 -@register_model -def fastervit2_xtiny(pretrained=False, **kwargs): #, - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=64, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.1, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model + parser = argparse.ArgumentParser() + parser.add_argument("--bs", type=int, default=32) + args = parser.parse_args() -# 80.19 -@register_model -def fastervit2_xxtiny(pretrained=False, **kwargs): #, - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=48, - in_dim=64, - mlp_ratio=4, - drop_path_rate=0.05, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model + channel_last = True + compute_latency = False + bs = args.bs + resolution = 224 + + resolution = 512 + resolution = 224 + # resolution = 432 # will not work as it is not divisible by 32 -@register_model -# 77.0 -def fastervit2_xxxtiny(pretrained=False, **kwargs): #, - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=32, - in_dim=32, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model + device = "cuda" + device = "cpu" + model = eradio_large_fullres_ws16() + model.return_full_features = False # enable if need dense features at the output + # model.return_full_features = False + model.set_optimal_window_size(resolution) + """ + ERADIO employs windowed attention, which may be sensitive to the choice of this parameter, + especially in cases of uneven partitioning of the feature maps. + ERADIO allows for the adjustment of the window size after training, + making it adaptable to different input image resolutions. + The recommended values for window size based on input resolution are as follows: + + Input Resolution | Window Size + 224 | 7 + 256 | 8 + 386 | 12 + 512 | 16 + Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be + img_res/16/2 + for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size. + Manual way to change resolution -> model.change_window_size(resolution) + """ + + model.eval() + input_data = torch.randn((bs, 3, resolution, resolution), device=device) + + if channel_last: + input_data = input_data.to(memory_format=torch.channels_last) + model = model.to(memory_format=torch.channels_last) + + if device == "cuda": + model.cuda() + + if 1: + x = model(input_data) + if isinstance(x, list) or isinstance(x, tuple): + full_features = x[1] + x = x[0] + print("full_features shape", full_features.shape) + + output = model(input_data) + model.switch_to_deploy() + + macs, params = get_model_complexity_info(model, tuple([3, resolution, resolution]), + as_strings=False, print_per_layer_stat=False, verbose=False) + + print(f"Model stats: macs: {macs}, and params: {params}") + + # warm up + with torch.cuda.amp.autocast(): + for ii in range(10): + output = model(input_data) + + # speed + import time + import numpy as np + + starter, ender = ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) -@register_model -def fastervit2_xxxtiny_fullres(pretrained=False, **kwargs): - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[8, 8, [7, 7], 7], - dim=32, - in_dim=32, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - use_neck=True, - full_features_head_dim=128, - neck_start_stage=1, - conv_groups_ratio = 1, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model -@register_model -def eradio_xxxtiny(pretrained=False, **kwargs): # , - model = FasterViT( - depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [16, 16], 16], - dim=32, - in_dim=32, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - yolo_arch=True, - shuffle_down=False, - conv_base=True, - use_neck=True, - full_features_head_dim=256, - neck_start_stage=2, - **kwargs, - ) - if pretrained: - model.load_state_dict(torch.load(pretrained)) - return model + timer = [] + start_time = time.time() + runs=100 + # with torch.cuda.amp.autocast(dtype=torch.bfloat16): + if device == "cuda": -@register_model -def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs): - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [12, 12], 12], - dim=32, - in_dim=32, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - use_neck=True, - full_features_head_dim=256, - neck_start_stage=2, - conv_groups_ratio = 1, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model + with torch.cuda.amp.autocast(True): + for ii in range(runs): + starter.record() + start_time_loc = time.time() -@register_model -def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs): - model = FasterViT(depths=[1, 3, 4, 5], - num_heads=[2, 4, 8, 16], - window_size=[None, None, [16, 16], 16], - dim=32, - in_dim=32, - mlp_ratio=4, - drop_path_rate=0.0, - sr_ratio=[1, 1, [2, 1], 1], - use_swiglu=False, - downsample_shuffle=False, - yolo_arch=True, - shuffle_down=False, - cpb_mlp_hidden=64, - use_neck=True, - full_features_head_dim=256, - neck_start_stage=1, - conv_groups_ratio = 1, - **kwargs) - if pretrained: - model.load_state_dict(torch.load(pretrained)["state_dict"]) - return model + output = model(input_data) + ender.record() + torch.cuda.synchronize() + # timer.append(time.time()-start_time_loc) + timer.append(starter.elapsed_time(ender)/1000.0) + end_time = time.time() + print(f"Throughput {bs * 1.0 / ((end_time - start_time) / runs)}") + print(f"Throughput Med {int(bs * 1.0 / ((np.median(timer))))}") -@register_model -def eradio(pretrained=False, **kwargs): - return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)