Description
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for XCLIP:
size mismatch for visual.transformer.resblocks.0.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for visual.transformer.resblocks.1.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for visual.transformer.resblocks.2.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).
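The mismatch suggests that in the released checkpoint, interactive_block.1 holds a 3x3 conv weight of shape [256, 256, 3, 3], while in the current code that index is a BatchNorm2d whose weight is a [256] vector. One way to confirm is to print the interactive_block parameter shapes stored in the checkpoint; the file name below is only a placeholder for whichever weights file is being loaded:

    import torch

    ckpt = torch.load("path/to/checkpoint.pth", map_location="cpu")
    state_dict = ckpt.get("model", ckpt)  # some releases nest the weights under a "model" key

    # Compare these shapes against the current ILA definition shown below
    for name, tensor in state_dict.items():
        if "interactive_block" in name:
            print(name, tuple(tensor.shape))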
In model/align, the ILA class is currently defined as:
class ILA(nn.Module):
    def __init__(self, T=8, d_model=768, patch_size=16, input_resolution=224, is_training=True):
        super().__init__()
        self.T = T
        self.W = input_resolution // patch_size
        self.d_model = d_model
        self.is_training = is_training
        # Here interactive_block.1 is a BatchNorm2d(256), whose weight has shape
        # [256] -- the "current model" shape in the error above, not the
        # [256, 256, 3, 3] conv weight stored in the checkpoint.
        self.interactive_block = nn.Sequential(
            nn.Conv2d(self.d_model * 2, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
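As a stop-gap while the checkpoint and the code disagree, one can drop the mismatched entries before loading so the remaining weights are still used. This is only a minimal sketch (the helper name is mine, not from the repo); the skipped interactive_block layers would stay randomly initialized, so results may differ from the released model:

    def load_matching_weights(model, state_dict):
        # Keep only checkpoint entries whose name and shape match the current model;
        # everything else (e.g. the mismatched interactive_block convs) is skipped.
        model_sd = model.state_dict()
        filtered = {k: v for k, v in state_dict.items()
                    if k in model_sd and v.shape == model_sd[k].shape}
        missing, unexpected = model.load_state_dict(filtered, strict=False)
        print("skipped:", sorted(set(state_dict) - set(filtered)))
        return missing, unexpected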