Update timm universal (support transformer-style model) #1004
qubvel merged 14 commits into qubvel-org:main
Conversation
Output Stride Not Matching

The following models were removed from the list as their output strides do not match the expected values:
test code

```python
import torch
import segmentation_models_pytorch as smp

model_list = [
    "inception_resnet_v2",
    "inception_v3",
    "inception_v4",
    "legacy_xception",
    "nasnetalarge",
    "pnasnet5large",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()
        f = model(x)
        print(name, [f_.detach().numpy().shape[2:] for f_ in f])
```

output

```
inception_resnet_v2 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v3 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v4 [(256, 256), (125, 125), (62, 62), (29, 29), (14, 14), (6, 6)]
legacy_xception [(256, 256), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
nasnetalarge [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]
pnasnet5large [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]
```

Renamed / Deprecated Models

The following models remain functional but are deprecated in timm:
test code

```python
import torch
import segmentation_models_pytorch as smp

model_list = [
    "mnasnet_a1",
    "mnasnet_b1",
    "efficientnet_b2a",
    "efficientnet_b3a",
    "seresnext26tn_32x4d",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()
```

output

```
UserWarning: Mapping deprecated model name mnasnet_a1 to current semnasnet_100.
UserWarning: Mapping deprecated model name mnasnet_b1 to current mnasnet_100.
UserWarning: Mapping deprecated model name efficientnet_b2a to current efficientnet_b2.
UserWarning: Mapping deprecated model name efficientnet_b3a to current efficientnet_b3.
UserWarning: Mapping deprecated model name seresnext26tn_32x4d to current seresnext26t_32x4d.
```
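These warnings come from timm's deprecation mapping; passing the current names directly avoids them. A minimal sketch using one of the mappings shown above:

```python
import segmentation_models_pytorch as smp

# use the current name instead of the deprecated alias "tu-mnasnet_a1"
model = smp.encoders.get_encoder("tu-semnasnet_100", weights=None)
```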
Add New Traditional-Style Models

new support models

cspdarknet53
darknet17
darknet21
darknet53
darknetaa53
sedarknet21
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn

Add New Transformer-Style Models

Channel-First Models

new support models

Channel-Last Models

These models are clearly transformer-style models, but their format is channel-last.

new support models
Hi @brianhou0208! Thanks for working on this challenging feature! My main concerns are:

Let me know what you think!
Regarding the NHWC format, I got the answer from Ross; the following should work:

getattr(model, "output_fmt", None) == "NHWC"

That attribute is only set if a model has NHWC format, which is why we have to use getattr. Also, there are some models that come with features in NHWC format.
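For illustration, a minimal sketch of applying this check and converting such features to channel-first; the permute step is an assumption about how one would handle it, not necessarily the PR's exact code:

```python
import torch
import timm

# sequencer2d_s is one of the models that reports NHWC (see the table below)
model = timm.create_model("sequencer2d_s", pretrained=False, features_only=True).eval()
is_channel_last = getattr(model, "output_fmt", None) == "NHWC"

x = torch.rand(1, 3, 224, 224)
features = model(x)
if is_channel_last:
    # NHWC -> NCHW so downstream decoders see channel-first feature maps
    features = [f.permute(0, 3, 1, 2) for f in features]
```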
Hi @qubvel,

Without using timm API test & result

test code

```python
import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["dla34", 224],
    ["cspdarknet53", 224],
    ["efficientnet_x_b3", 224],
    ["efficientvit_m0", 224],
    ["inception_resnet_v2", 299],
    ["inception_v3", 299],
    ["inception_v4", 299],
    ["mambaout_tiny", 224],
    ["tresnet_m", 224],
    ["vit_tiny_patch16_224", 224],
]

if __name__ == "__main__":
    for model_name, img_size in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        y = model(x)
        print(f"timm-{model_name}-(C, H, W) = {(3, img_size, img_size)}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print(f" Feature channels: {model.feature_info.channels()}")
        print(f" Feature reduction: {model.feature_info.reduction()}")
```

output

```
timm-dla34-(C, H, W) = (3, 224, 224)
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
Feature channels: [32, 64, 128, 256, 512]
Feature reduction: [2, 4, 8, 16, 32]
timm-cspdarknet53-(C, H, W) = (3, 224, 224)
Feature shape: [(32, 224, 224), (64, 112, 112), (128, 56, 56), (256, 28, 28), (512, 14, 14), (1024, 7, 7)]
Feature channels: [32, 64, 128, 256, 512, 1024]
Feature reduction: [1, 2, 4, 8, 16, 32]
timm-efficientnet_x_b3-(C, H, W) = (3, 224, 224)
Feature shape: [(96, 56, 56), (32, 56, 56), (48, 28, 28), (136, 14, 14), (384, 7, 7)]
Feature channels: [96, 32, 48, 136, 384]
Feature reduction: [2, 2, 4, 8, 16]
timm-efficientvit_m0-(C, H, W) = (3, 224, 224)
Feature shape: [(64, 14, 14), (128, 7, 7), (192, 4, 4)]
Feature channels: [64, 128, 192]
Feature reduction: [16, 32, 64]
timm-inception_resnet_v2-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (192, 71, 71), (320, 35, 35), (1088, 17, 17), (1536, 8, 8)]
Feature channels: [64, 192, 320, 1088, 1536]
Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v3-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (192, 71, 71), (288, 35, 35), (768, 17, 17), (2048, 8, 8)]
Feature channels: [64, 192, 288, 768, 2048]
Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v4-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (160, 73, 73), (384, 35, 35), (1024, 17, 17), (1536, 8, 8)]
Feature channels: [64, 160, 384, 1024, 1536]
Feature reduction: [2, 4, 8, 16, 32]
timm-tresnet_m-(C, H, W) = (3, 224, 224)
Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
Feature channels: [64, 128, 1024, 2048]
Feature reduction: [4, 8, 16, 32]
timm-vit_tiny_patch16_224-(C, H, W) = (3, 224, 224)
Feature shape: [(192, 14, 14), (192, 14, 14), (192, 14, 14)]
Feature channels: [192, 192, 192]
Feature reduction: [16, 16, 16]
```

However, you might still encounter cases like:
out_indices test & result

test

```python
import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["resnet18", 224, (0, 1, 2, 3, 4)],
    ["dla34", 224, (1, 2, 3, 4, 5)],
    ["mambaout_tiny", 224, (0, 1, 2, 3)],
    ["tresnet_m", 224, (1, 2, 3, 4)],
]

if __name__ == "__main__":
    for model_name, img_size, out_indices in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True, out_indices=out_indices).eval()
        y = model(x)
        print(f"timm-{model_name}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        model = smp.encoders.get_encoder(f"tu-{model_name}").eval()
        y = model(x)[1:]
        print(f"smp-{model_name}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print()
```

output

```
timm-resnet18
Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-resnet18
Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
timm-dla34
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-dla34
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
timm-mambaout_tiny
Feature shape: [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 576)]
smp-mambaout_tiny
Feature shape: [(0, 112, 112), (96, 56, 56), (192, 28, 28), (384, 14, 14), (576, 7, 7)]
timm-tresnet_m
Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
smp-tresnet_m
Feature shape: [(0, 112, 112), (64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
```
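Note the leading (0, 112, 112) entries for mambaout_tiny and tresnet_m above: SMP appears to pad the missing stride-2 stage with a zero-channel placeholder so the usual five-stage layout is preserved. A minimal sketch of that idea, assuming a hypothetical prepend_dummy_feature helper:

```python
import torch

def prepend_dummy_feature(features, x):
    # Transformer-style models start at reduction 4, so insert a
    # zero-channel tensor at 1/2 resolution as a stage placeholder.
    b, _, h, w = x.shape
    dummy = torch.empty(b, 0, h // 2, w // 2, dtype=x.dtype, device=x.device)
    return [dummy] + list(features)
```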
I spent some time reviewing all the models in timm.

Timm Support Backbone

Unsupported feature extraction: 34

coat_lite_medium
coat_lite_medium_384
coat_lite_mini
coat_lite_small
coat_lite_tiny
coat_mini
coat_small
coat_tiny
convit_base
convit_small
convit_tiny
convmixer_768_32
convmixer_1024_20_ks9_p14
convmixer_1536_20
crossvit_9_240
crossvit_9_dagger_240
crossvit_15_240
crossvit_15_dagger_240
crossvit_15_dagger_408
crossvit_18_240
crossvit_18_dagger_240
crossvit_18_dagger_408
crossvit_base_240
crossvit_small_240
crossvit_tiny_240
gcvit_base
gcvit_small
gcvit_tiny
gcvit_xtiny
gcvit_xxtiny
tnt_b_patch16_224
tnt_s_patch16_224
visformer_small
visformer_tiny

Tests models: 14

test_byobnet
test_convnext
test_convnext2
test_convnext3
test_efficientnet
test_efficientnet_evos
test_efficientnet_gn
test_efficientnet_ln
test_mambaout
test_nfnet
test_resnet
test_vit
test_vit2
test_vit3

SMP Support Backbone

SMP Unsupported Backbone

Unsupported models: 310

beit_base_patch16_224
beit_base_patch16_384
beit_large_patch16_224
beit_large_patch16_384
beit_large_patch16_512
beitv2_base_patch16_224
beitv2_large_patch16_224
cait_m36_384
cait_m48_448
cait_s24_224
cait_s24_384
cait_s36_384
cait_xs24_384
cait_xxs24_224
cait_xxs24_384
cait_xxs36_224
cait_xxs36_384
deit3_base_patch16_224
deit3_base_patch16_384
deit3_huge_patch14_224
deit3_large_patch16_224
deit3_large_patch16_384
deit3_medium_patch16_224
deit3_small_patch16_224
deit3_small_patch16_384
deit_base_distilled_patch16_224
deit_base_distilled_patch16_384
deit_base_patch16_224
deit_base_patch16_384
deit_small_distilled_patch16_224
deit_small_patch16_224
deit_tiny_distilled_patch16_224
deit_tiny_patch16_224
efficientnet_h_b5
efficientnet_x_b3
efficientnet_x_b5
efficientvit_m0
efficientvit_m1
efficientvit_m2
efficientvit_m3
efficientvit_m4
efficientvit_m5
eva02_base_patch14_224
eva02_base_patch14_448
eva02_base_patch16_clip_224
eva02_enormous_patch14_clip_224
eva02_large_patch14_224
eva02_large_patch14_448
eva02_large_patch14_clip_224
eva02_large_patch14_clip_336
eva02_small_patch14_224
eva02_small_patch14_336
eva02_tiny_patch14_224
eva02_tiny_patch14_336
eva_giant_patch14_224
eva_giant_patch14_336
eva_giant_patch14_560
eva_giant_patch14_clip_224
eva_large_patch14_196
eva_large_patch14_336
flexivit_base
flexivit_large
flexivit_small
gmixer_12_224
gmixer_24_224
gmlp_b16_224
gmlp_s16_224
gmlp_ti16_224
inception_resnet_v2
inception_v3
inception_v4
legacy_xception
levit_128
levit_128s
levit_192
levit_256
levit_256d
levit_384
levit_384_s8
levit_512
levit_512_s8
levit_512d
levit_conv_128
levit_conv_128s
levit_conv_192
levit_conv_256
levit_conv_256d
levit_conv_384
levit_conv_384_s8
levit_conv_512
levit_conv_512_s8
levit_conv_512d
mixer_b16_224
mixer_b32_224
mixer_l16_224
mixer_l32_224
mixer_s16_224
mixer_s32_224
nasnetalarge
pit_b_224
pit_b_distilled_224
pit_s_224
pit_s_distilled_224
pit_ti_224
pit_ti_distilled_224
pit_xs_224
pit_xs_distilled_224
pnasnet5large
resmlp_12_224
resmlp_24_224
resmlp_36_224
resmlp_big_24_224
samvit_base_patch16
samvit_base_patch16_224
samvit_huge_patch16
samvit_large_patch16
sequencer2d_l
sequencer2d_m
sequencer2d_s
vit_base_mci_224
vit_base_patch8_224
vit_base_patch14_dinov2
vit_base_patch14_reg4_dinov2
vit_base_patch16_18x2_224
vit_base_patch16_224
vit_base_patch16_224_miil
vit_base_patch16_384
vit_base_patch16_clip_224
vit_base_patch16_clip_384
vit_base_patch16_clip_quickgelu_224
vit_base_patch16_gap_224
vit_base_patch16_plus_240
vit_base_patch16_plus_clip_240
vit_base_patch16_reg4_gap_256
vit_base_patch16_rope_reg1_gap_256
vit_base_patch16_rpn_224
vit_base_patch16_siglip_224
vit_base_patch16_siglip_256
vit_base_patch16_siglip_384
vit_base_patch16_siglip_512
vit_base_patch16_siglip_gap_224
vit_base_patch16_siglip_gap_256
vit_base_patch16_siglip_gap_384
vit_base_patch16_siglip_gap_512
vit_base_patch16_xp_224
vit_base_patch32_224
vit_base_patch32_384
vit_base_patch32_clip_224
vit_base_patch32_clip_256
vit_base_patch32_clip_384
vit_base_patch32_clip_448
vit_base_patch32_clip_quickgelu_224
vit_base_patch32_plus_256
vit_base_r26_s32_224
vit_base_r50_s16_224
vit_base_r50_s16_384
vit_base_resnet26d_224
vit_base_resnet50d_224
vit_betwixt_patch16_gap_256
vit_betwixt_patch16_reg1_gap_256
vit_betwixt_patch16_reg4_gap_256
vit_betwixt_patch16_reg4_gap_384
vit_betwixt_patch16_rope_reg4_gap_256
vit_betwixt_patch32_clip_224
vit_giant_patch14_224
vit_giant_patch14_clip_224
vit_giant_patch14_dinov2
vit_giant_patch14_reg4_dinov2
vit_giant_patch16_gap_224
vit_gigantic_patch14_224
vit_gigantic_patch14_clip_224
vit_gigantic_patch14_clip_quickgelu_224
vit_huge_patch14_224
vit_huge_patch14_clip_224
vit_huge_patch14_clip_336
vit_huge_patch14_clip_378
vit_huge_patch14_clip_quickgelu_224
vit_huge_patch14_clip_quickgelu_378
vit_huge_patch14_gap_224
vit_huge_patch14_xp_224
vit_huge_patch16_gap_448
vit_intern300m_patch14_448
vit_large_patch14_224
vit_large_patch14_clip_224
vit_large_patch14_clip_336
vit_large_patch14_clip_quickgelu_224
vit_large_patch14_clip_quickgelu_336
vit_large_patch14_dinov2
vit_large_patch14_reg4_dinov2
vit_large_patch14_xp_224
vit_large_patch16_224
vit_large_patch16_384
vit_large_patch16_siglip_256
vit_large_patch16_siglip_384
vit_large_patch16_siglip_gap_256
vit_large_patch16_siglip_gap_384
vit_large_patch32_224
vit_large_patch32_384
vit_large_r50_s32_224
vit_large_r50_s32_384
vit_little_patch16_reg1_gap_256
vit_little_patch16_reg4_gap_256
vit_medium_patch16_clip_224
vit_medium_patch16_gap_240
vit_medium_patch16_gap_256
vit_medium_patch16_gap_384
vit_medium_patch16_reg1_gap_256
vit_medium_patch16_reg4_gap_256
vit_medium_patch16_rope_reg1_gap_256
vit_medium_patch32_clip_224
vit_mediumd_patch16_reg4_gap_256
vit_mediumd_patch16_reg4_gap_384
vit_mediumd_patch16_rope_reg1_gap_256
vit_pwee_patch16_reg1_gap_256
vit_relpos_base_patch16_224
vit_relpos_base_patch16_cls_224
vit_relpos_base_patch16_clsgap_224
vit_relpos_base_patch16_plus_240
vit_relpos_base_patch16_rpn_224
vit_relpos_base_patch32_plus_rpn_256
vit_relpos_medium_patch16_224
vit_relpos_medium_patch16_cls_224
vit_relpos_medium_patch16_rpn_224
vit_relpos_small_patch16_224
vit_relpos_small_patch16_rpn_224
vit_small_patch8_224
vit_small_patch14_dinov2
vit_small_patch14_reg4_dinov2
vit_small_patch16_18x2_224
vit_small_patch16_36x1_224
vit_small_patch16_224
vit_small_patch16_384
vit_small_patch32_224
vit_small_patch32_384
vit_small_r26_s32_224
vit_small_r26_s32_384
vit_small_resnet26d_224
vit_small_resnet50d_s16_224
vit_so150m_patch16_reg4_gap_256
vit_so150m_patch16_reg4_map_256
vit_so400m_patch14_siglip_224
vit_so400m_patch14_siglip_378
vit_so400m_patch14_siglip_384
vit_so400m_patch14_siglip_gap_224
vit_so400m_patch14_siglip_gap_378
vit_so400m_patch14_siglip_gap_384
vit_so400m_patch14_siglip_gap_448
vit_so400m_patch14_siglip_gap_896
vit_so400m_patch16_siglip_256
vit_so400m_patch16_siglip_gap_256
vit_srelpos_medium_patch16_224
vit_srelpos_small_patch16_224
vit_tiny_patch16_224
vit_tiny_patch16_384
vit_tiny_r_s16_p8_224
vit_tiny_r_s16_p8_384
vit_wee_patch16_reg1_gap_256
vit_xsmall_patch16_clip_224
vitamin_base_224
vitamin_large2_224
vitamin_large2_256
vitamin_large2_336
vitamin_large2_384
vitamin_large_224
vitamin_large_256
vitamin_large_336
vitamin_large_384
vitamin_small_224
vitamin_xlarge_256
vitamin_xlarge_336
vitamin_xlarge_384
volo_d1_224
volo_d1_384
volo_d2_224
volo_d2_384
volo_d3_224
volo_d3_448
volo_d4_224
volo_d4_448
volo_d5_224
volo_d5_448
volo_d5_512
xcit_large_24_p8_224
xcit_large_24_p8_384
xcit_large_24_p16_224
xcit_large_24_p16_384
xcit_medium_24_p8_224
xcit_medium_24_p8_384
xcit_medium_24_p16_224
xcit_medium_24_p16_384
xcit_nano_12_p8_224
xcit_nano_12_p8_384
xcit_nano_12_p16_224
xcit_nano_12_p16_384
xcit_small_12_p8_224
xcit_small_12_p8_384
xcit_small_12_p16_224
xcit_small_12_p16_384
xcit_small_24_p8_224
xcit_small_24_p8_384
xcit_small_24_p16_224
xcit_small_24_p16_384
xcit_tiny_12_p8_224
xcit_tiny_12_p8_384
xcit_tiny_12_p16_224
xcit_tiny_12_p16_384
xcit_tiny_24_p8_224
xcit_tiny_24_p8_384
xcit_tiny_24_p16_224
xcit_tiny_24_p16_384

Check for unsupported models in SMP

```python
import torch
import timm

# model_list: the list of unsupported models above
if __name__ == "__main__":
    for model_name in model_list:
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        is_channel_last = getattr(model, "output_fmt", None) == "NHWC"
        print(f"{model_name} {is_channel_last}")
        print(model.feature_info.reduction())
```

With downsample feature: 46

```
efficientnet_h_b5 False
[2, 2, 4, 8, 16]
efficientnet_x_b3 False
[2, 2, 4, 8, 16]
efficientnet_x_b5 False
[2, 2, 4, 8, 16]
efficientvit_m0 False
[16, 32, 64]
efficientvit_m1 False
[16, 32, 64]
efficientvit_m2 False
[16, 32, 64]
efficientvit_m3 False
[16, 32, 64]
efficientvit_m4 False
[16, 32, 64]
efficientvit_m5 False
[16, 32, 64]
inception_resnet_v2 False
[2, 4, 8, 16, 32]
inception_v3 False
[2, 4, 8, 16, 32]
inception_v4 False
[2, 4, 8, 16, 32]
legacy_xception False
[2, 4, 8, 16, 32]
levit_128 False
[16, 32, 64]
levit_128s False
[16, 32, 64]
levit_192 False
[16, 32, 64]
levit_256 False
[16, 32, 64]
levit_256d False
[16, 32, 64]
levit_384 False
[16, 32, 64]
levit_384_s8 False
[8, 16, 32]
levit_512 False
[16, 32, 64]
levit_512_s8 False
[8, 16, 32]
levit_512d False
[16, 32, 64]
levit_conv_128 False
[16, 32, 64]
levit_conv_128s False
[16, 32, 64]
levit_conv_192 False
[16, 32, 64]
levit_conv_256 False
[16, 32, 64]
levit_conv_256d False
[16, 32, 64]
levit_conv_384 False
[16, 32, 64]
levit_conv_384_s8 False
[8, 16, 32]
levit_conv_512 False
[16, 32, 64]
levit_conv_512_s8 False
[8, 16, 32]
levit_conv_512d False
[16, 32, 64]
nasnetalarge False
[2, 4, 8, 16, 32]
pit_b_224 False
[6, 12, 24]
pit_b_distilled_224 False
[6, 12, 24]
pit_s_224 False
[7, 14, 28]
pit_s_distilled_224 False
[7, 14, 28]
pit_ti_224 False
[7, 14, 28]
pit_ti_distilled_224 False
[7, 14, 28]
pit_xs_224 False
[7, 14, 28]
pit_xs_distilled_224 False
[7, 14, 28]
pnasnet5large False
[2, 4, 8, 16, 32]
sequencer2d_l True
[7, 14, 14]
sequencer2d_m True
[7, 14, 14]
sequencer2d_s True
[7, 14, 14]
```

Without downsample feature: 264

```
beit_base_patch16_224 False
[16, 16, 16]
beit_base_patch16_384 False
[16, 16, 16]
beit_large_patch16_224 False
[16, 16, 16]
beit_large_patch16_384 False
[16, 16, 16]
beit_large_patch16_512 False
[16, 16, 16]
beitv2_base_patch16_224 False
[16, 16, 16]
beitv2_large_patch16_224 False
[16, 16, 16]
cait_m36_384 False
[16, 16, 16]
cait_m48_448 False
[16, 16, 16]
cait_s24_224 False
[16, 16, 16]
cait_s24_384 False
[16, 16, 16]
cait_s36_384 False
[16, 16, 16]
cait_xs24_384 False
[16, 16, 16]
cait_xxs24_224 False
[16, 16, 16]
cait_xxs24_384 False
[16, 16, 16]
cait_xxs36_224 False
[16, 16, 16]
cait_xxs36_384 False
[16, 16, 16]
deit3_base_patch16_224 False
[16, 16, 16]
deit3_base_patch16_384 False
[16, 16, 16]
deit3_huge_patch14_224 False
[14, 14, 14]
deit3_large_patch16_224 False
[16, 16, 16]
deit3_large_patch16_384 False
[16, 16, 16]
deit3_medium_patch16_224 False
[16, 16, 16]
deit3_small_patch16_224 False
[16, 16, 16]
deit3_small_patch16_384 False
[16, 16, 16]
deit_base_distilled_patch16_224 False
[16, 16, 16]
deit_base_distilled_patch16_384 False
[16, 16, 16]
deit_base_patch16_224 False
[16, 16, 16]
deit_base_patch16_384 False
[16, 16, 16]
deit_small_distilled_patch16_224 False
[16, 16, 16]
deit_small_patch16_224 False
[16, 16, 16]
deit_tiny_distilled_patch16_224 False
[16, 16, 16]
deit_tiny_patch16_224 False
[16, 16, 16]
eva02_base_patch14_224 False
[14, 14, 14]
eva02_base_patch14_448 False
[14, 14, 14]
eva02_base_patch16_clip_224 False
[16, 16, 16]
eva02_enormous_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_224 False
[14, 14, 14]
eva02_large_patch14_448 False
[14, 14, 14]
eva02_large_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_clip_336 False
[14, 14, 14]
eva02_small_patch14_224 False
[14, 14, 14]
eva02_small_patch14_336 False
[14, 14, 14]
eva02_tiny_patch14_224 False
[14, 14, 14]
eva02_tiny_patch14_336 False
[14, 14, 14]
eva_giant_patch14_224 False
[14, 14, 14]
eva_giant_patch14_336 False
[14, 14, 14]
eva_giant_patch14_560 False
[14, 14, 14]
eva_giant_patch14_clip_224 False
[14, 14, 14]
eva_large_patch14_196 False
[14, 14, 14]
eva_large_patch14_336 False
[14, 14, 14]
flexivit_base False
[16, 16, 16]
flexivit_large False
[16, 16, 16]
flexivit_small False
[16, 16, 16]
gmixer_12_224 False
[16, 16, 16]
gmixer_24_224 False
[16, 16, 16]
gmlp_b16_224 False
[16, 16, 16]
gmlp_s16_224 False
[16, 16, 16]
gmlp_ti16_224 False
[16, 16, 16]
mixer_b16_224 False
[16, 16, 16]
mixer_b32_224 False
[32, 32, 32]
mixer_l16_224 False
[16, 16, 16]
mixer_l32_224 False
[32, 32, 32]
mixer_s16_224 False
[16, 16, 16]
mixer_s32_224 False
[32, 32, 32]
resmlp_12_224 False
[16, 16, 16]
resmlp_24_224 False
[16, 16, 16]
resmlp_36_224 False
[16, 16, 16]
resmlp_big_24_224 False
[8, 8, 8]
samvit_base_patch16 False
[16, 16, 16]
samvit_base_patch16_224 False
[16, 16, 16]
samvit_huge_patch16 False
[16, 16, 16]
samvit_large_patch16 False
[16, 16, 16]
vit_base_mci_224 False
[16, 16, 16]
vit_base_patch8_224 False
[8, 8, 8]
vit_base_patch14_dinov2 False
[14, 14, 14]
vit_base_patch14_reg4_dinov2 False
[14, 14, 14]
vit_base_patch16_18x2_224 False
[16, 16, 16]
vit_base_patch16_224 False
[16, 16, 16]
vit_base_patch16_224_miil False
[16, 16, 16]
vit_base_patch16_384 False
[16, 16, 16]
vit_base_patch16_clip_224 False
[16, 16, 16]
vit_base_patch16_clip_384 False
[16, 16, 16]
vit_base_patch16_clip_quickgelu_224 False
[16, 16, 16]
vit_base_patch16_gap_224 False
[16, 16, 16]
vit_base_patch16_plus_240 False
[16, 16, 16]
vit_base_patch16_plus_clip_240 False
[16, 16, 16]
vit_base_patch16_reg4_gap_256 False
[16, 16, 16]
vit_base_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_base_patch16_rpn_224 False
[16, 16, 16]
vit_base_patch16_siglip_224 False
[16, 16, 16]
vit_base_patch16_siglip_256 False
[16, 16, 16]
vit_base_patch16_siglip_384 False
[16, 16, 16]
vit_base_patch16_siglip_512 False
[16, 16, 16]
vit_base_patch16_siglip_gap_224 False
[16, 16, 16]
vit_base_patch16_siglip_gap_256 False
[16, 16, 16]
vit_base_patch16_siglip_gap_384 False
[16, 16, 16]
vit_base_patch16_siglip_gap_512 False
[16, 16, 16]
vit_base_patch16_xp_224 False
[16, 16, 16]
vit_base_patch32_224 False
[32, 32, 32]
vit_base_patch32_384 False
[32, 32, 32]
vit_base_patch32_clip_224 False
[32, 32, 32]
vit_base_patch32_clip_256 False
[32, 32, 32]
vit_base_patch32_clip_384 False
[32, 32, 32]
vit_base_patch32_clip_448 False
[32, 32, 32]
vit_base_patch32_clip_quickgelu_224 False
[32, 32, 32]
vit_base_patch32_plus_256 False
[32, 32, 32]
vit_base_r26_s32_224 False
[32, 32, 32]
vit_base_r50_s16_224 False
[16, 16, 16]
vit_base_r50_s16_384 False
[16, 16, 16]
vit_base_resnet26d_224 False
[32, 32, 32]
vit_base_resnet50d_224 False
[32, 32, 32]
vit_betwixt_patch16_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg1_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_384 False
[16, 16, 16]
vit_betwixt_patch16_rope_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch32_clip_224 False
[32, 32, 32]
vit_giant_patch14_224 False
[14, 14, 14]
vit_giant_patch14_clip_224 False
[14, 14, 14]
vit_giant_patch14_dinov2 False
[14, 14, 14]
vit_giant_patch14_reg4_dinov2 False
[14, 14, 14]
vit_giant_patch16_gap_224 False
[16, 16, 16]
vit_gigantic_patch14_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_224 False
[14, 14, 14]
vit_huge_patch14_clip_224 False
[14, 14, 14]
vit_huge_patch14_clip_336 False
[14, 14, 14]
vit_huge_patch14_clip_378 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_378 False
[14, 14, 14]
vit_huge_patch14_gap_224 False
[14, 14, 14]
vit_huge_patch14_xp_224 False
[14, 14, 14]
vit_huge_patch16_gap_448 False
[16, 16, 16]
vit_intern300m_patch14_448 False
[14, 14, 14]
vit_large_patch14_224 False
[14, 14, 14]
vit_large_patch14_clip_224 False
[14, 14, 14]
vit_large_patch14_clip_336 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_336 False
[14, 14, 14]
vit_large_patch14_dinov2 False
[14, 14, 14]
vit_large_patch14_reg4_dinov2 False
[14, 14, 14]
vit_large_patch14_xp_224 False
[14, 14, 14]
vit_large_patch16_224 False
[16, 16, 16]
vit_large_patch16_384 False
[16, 16, 16]
vit_large_patch16_siglip_256 False
[16, 16, 16]
vit_large_patch16_siglip_384 False
[16, 16, 16]
vit_large_patch16_siglip_gap_256 False
[16, 16, 16]
vit_large_patch16_siglip_gap_384 False
[16, 16, 16]
vit_large_patch32_224 False
[32, 32, 32]
vit_large_patch32_384 False
[32, 32, 32]
vit_large_r50_s32_224 False
[32, 32, 32]
vit_large_r50_s32_384 False
[32, 32, 32]
vit_little_patch16_reg1_gap_256 False
[16, 16, 16]
vit_little_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_clip_224 False
[16, 16, 16]
vit_medium_patch16_gap_240 False
[16, 16, 16]
vit_medium_patch16_gap_256 False
[16, 16, 16]
vit_medium_patch16_gap_384 False
[16, 16, 16]
vit_medium_patch16_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch32_clip_224 False
[32, 32, 32]
vit_mediumd_patch16_reg4_gap_256 False
[16, 16, 16]
vit_mediumd_patch16_reg4_gap_384 False
[16, 16, 16]
vit_mediumd_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_pwee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_relpos_base_patch16_224 False
[16, 16, 16]
vit_relpos_base_patch16_cls_224 False
[16, 16, 16]
vit_relpos_base_patch16_clsgap_224 False
[16, 16, 16]
vit_relpos_base_patch16_plus_240 False
[16, 16, 16]
vit_relpos_base_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_base_patch32_plus_rpn_256 False
[32, 32, 32]
vit_relpos_medium_patch16_224 False
[16, 16, 16]
vit_relpos_medium_patch16_cls_224 False
[16, 16, 16]
vit_relpos_medium_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_small_patch16_224 False
[16, 16, 16]
vit_relpos_small_patch16_rpn_224 False
[16, 16, 16]
vit_small_patch8_224 False
[8, 8, 8]
vit_small_patch14_dinov2 False
[14, 14, 14]
vit_small_patch14_reg4_dinov2 False
[14, 14, 14]
vit_small_patch16_18x2_224 False
[16, 16, 16]
vit_small_patch16_36x1_224 False
[16, 16, 16]
vit_small_patch16_224 False
[16, 16, 16]
vit_small_patch16_384 False
[16, 16, 16]
vit_small_patch32_224 False
[32, 32, 32]
vit_small_patch32_384 False
[32, 32, 32]
vit_small_r26_s32_224 False
[32, 32, 32]
vit_small_r26_s32_384 False
[32, 32, 32]
vit_small_resnet26d_224 False
[32, 32, 32]
vit_small_resnet50d_s16_224 False
[16, 16, 16]
vit_so150m_patch16_reg4_gap_256 False
[16, 16, 16]
vit_so150m_patch16_reg4_map_256 False
[16, 16, 16]
vit_so400m_patch14_siglip_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_448 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_896 False
[14, 14, 14]
vit_so400m_patch16_siglip_256 False
[16, 16, 16]
vit_so400m_patch16_siglip_gap_256 False
[16, 16, 16]
vit_srelpos_medium_patch16_224 False
[16, 16, 16]
vit_srelpos_small_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_384 False
[16, 16, 16]
vit_tiny_r_s16_p8_224 False
[32, 32, 32]
vit_tiny_r_s16_p8_384 False
[32, 32, 32]
vit_wee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_xsmall_patch16_clip_224 False
[16, 16, 16]
vitamin_base_224 False
[16, 16, 16]
vitamin_large2_224 False
[16, 16, 16]
vitamin_large2_256 False
[16, 16, 16]
vitamin_large2_336 False
[16, 16, 16]
vitamin_large2_384 False
[16, 16, 16]
vitamin_large_224 False
[16, 16, 16]
vitamin_large_256 False
[16, 16, 16]
vitamin_large_336 False
[16, 16, 16]
vitamin_large_384 False
[16, 16, 16]
vitamin_small_224 False
[16, 16, 16]
vitamin_xlarge_256 False
[16, 16, 16]
vitamin_xlarge_336 False
[16, 16, 16]
vitamin_xlarge_384 False
[16, 16, 16]
volo_d1_224 False
[16, 16, 16]
volo_d1_384 False
[16, 16, 16]
volo_d2_224 False
[16, 16, 16]
volo_d2_384 False
[16, 16, 16]
volo_d3_224 False
[16, 16, 16]
volo_d3_448 False
[16, 16, 16]
volo_d4_224 False
[16, 16, 16]
volo_d4_448 False
[16, 16, 16]
volo_d5_224 False
[16, 16, 16]
volo_d5_448 False
[16, 16, 16]
volo_d5_512 False
[16, 16, 16]
xcit_large_24_p8_224 False
[8, 8, 8]
xcit_large_24_p8_384 False
[8, 8, 8]
xcit_large_24_p16_224 False
[16, 16, 16]
xcit_large_24_p16_384 False
[16, 16, 16]
xcit_medium_24_p8_224 False
[8, 8, 8]
xcit_medium_24_p8_384 False
[8, 8, 8]
xcit_medium_24_p16_224 False
[16, 16, 16]
xcit_medium_24_p16_384 False
[16, 16, 16]
xcit_nano_12_p8_224 False
[8, 8, 8]
xcit_nano_12_p8_384 False
[8, 8, 8]
xcit_nano_12_p16_224 False
[16, 16, 16]
xcit_nano_12_p16_384 False
[16, 16, 16]
xcit_small_12_p8_224 False
[8, 8, 8]
xcit_small_12_p8_384 False
[8, 8, 8]
xcit_small_12_p16_224 False
[16, 16, 16]
xcit_small_12_p16_384 False
[16, 16, 16]
xcit_small_24_p8_224 False
[8, 8, 8]
xcit_small_24_p8_384 False
[8, 8, 8]
xcit_small_24_p16_224 False
[16, 16, 16]
xcit_small_24_p16_384 False
[16, 16, 16]
xcit_tiny_12_p8_224 False
[8, 8, 8]
xcit_tiny_12_p8_384 False
[8, 8, 8]
xcit_tiny_12_p16_224 False
[16, 16, 16]
xcit_tiny_12_p16_384 False
[16, 16, 16]
xcit_tiny_24_p8_224 False
[8, 8, 8]
xcit_tiny_24_p8_384 False
[8, 8, 8]
xcit_tiny_24_p16_224 False
[16, 16, 16]
xcit_tiny_24_p16_384 False
[16, 16, 16]
```
qubvel
left a comment
Hi @brianhou0208! Thanks for continuing to work on this 🚀 It already looks really great. I just have one question:
Why do we need to load a temporary model? I would try to avoid it if possible.
I believe that a temporary model is necessary because we need to determine feature_info.reduction() to classify the model as traditional, transformer, or VGG style. This affects the range of out_indices to be used (see the sketch after the list below):
common_kwargs["out_indices"] = tuple(range(depth))

- If depth == 5, out_indices is
  - traditional-style: (0, 1, 2, 3, 4)
  - transformer-style: (0, 1, 2, 3)
  - vgg-style: (0, 1, 2, 3, 4, 5)
- If depth == 3, out_indices is
  - traditional-style: (0, 1, 2)
  - transformer-style: (0, 1)
  - vgg-style: (0, 1, 2, 3)
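For illustration, a sketch of how the style could be inferred from the temporary model's feature_info.reduction(); the helper name and exact rules here are assumptions, not the PR's verbatim code:

```python
def guess_model_style(reductions):
    # vgg-style: six stages starting at reduction 1, e.g. [1, 2, 4, 8, 16, 32]
    if len(reductions) == 6 and reductions[0] == 1:
        return "vgg"
    # transformer-style: first stage already downsampled by 4, e.g. [4, 8, 16, 32]
    if reductions[0] == 4:
        return "transformer"
    # traditional-style: five stages starting at reduction 2, e.g. [2, 4, 8, 16, 32]
    return "traditional"
```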
Is there any other way to determine feature_info.reduction() in advance?
Can we slice features in forward instead of providing out_indices? Otherwise, I would recommend using pretrained=False for the tmp model, and maybe initialize it on the meta device to avoid double memory consumption.
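A hypothetical illustration of that slicing alternative (model name and depth are placeholders):

```python
import torch
import timm

depth = 5
# build with all feature stages instead of configuring out_indices
model = timm.create_model("resnet18", features_only=True).eval()
x = torch.rand(1, 3, 224, 224)
features = model(x)[:depth]  # slice at forward time
```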
In timm.create_model(), the default is pretrained=False.
I think initializing the tmp model on torch.device("meta") is good:

```python
self.model = timm.create_model(name, pretrained=False, features_only=True).to("meta")
```

What do you think?
Explicit pretrained=False would be nice; for meta it should be something like this:

```python
with torch.device("meta"):
    tmp_model = timm.create_model(name, pretrained=False, features_only=True)
```

- without self.
- let's name it with tmp_
Let's leave it as is for now; it can be optimized later if needed.
If we don't use additional variable names, it shouldn't take up extra memory?
I renamed temp_model to self.model, although the variable names will be a little confusing.
By "as is" I mean:

```python
# Load a temporary model to analyze its feature hierarchy
try:
    with torch.device("meta"):
        tmp_model = timm.create_model(name, features_only=True)
except Exception:
    tmp_model = timm.create_model(name, features_only=True)
```

Sorry for the confusion.
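(On the meta device no real storage is allocated, so this analysis pass is essentially free; the except branch presumably serves as a fallback for models whose initialization fails on the meta device.)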
If we don't use additional variable names, it shouldn't take up extra memory?

Don't think so, we still allocate twice:
- we have the tmp model initialized and linked to self.model
- we initialize the required model
- we unlink the tmp model from the self.model var name and link the required one

Two models exist at a time.
I think you are right, thanks for your explanation.

1. rename temporary model
2. create temporary model on meta device to speed up
Hi @qubvel, thank you for your comments; they have made this PR more complete. However, I think the remaining points do not affect this PR. It's ready to be merged.
qubvel
left a comment
Thanks, can you please add one transformer-like and one vgg-like encoder to the tests? Then we are good to merge.
segmentation_models.pytorch/tests/test_models.py
Lines 7 to 17 in 34ee31d
@qubvel It's ready to merge, please check.
Since version
qubvel
left a comment
Thanks for delivering this super important feature!
Yeah, I will do a release 👍
Hi @qubvel,

This PR improves TimmUniversalEncoder to better support transformer-style models, updates the documentation, and ensures compatibility with timm==1.0.12.

Key Updates
- Updated TimmUniversalEncoder to seamlessly handle both traditional (CNN-based) and transformer-style models.
- Ensured compatibility with timm==1.0.12.

Details of Changes
- TimmUniversalEncoder: uses feature_info.reduction() to determine whether a model is traditional or transformer-style, and adjusts out_indices for models like tresnet and dla to ensure accurate feature extraction.

Documentation

Testing
- Tested with timm==1.0.12 for all models with features_only enabled.