diff --git a/core/vision_encoder/model.py b/core/vision_encoder/model.py
index 0bb24f2..385a2d7 100644
--- a/core/vision_encoder/model.py
+++ b/core/vision_encoder/model.py
@@ -16,6 +16,8 @@ import torch.nn as nn
 from torch.nn import functional as F
 
+from huggingface_hub import PyTorchModelHubMixin
+
 from core.vision_encoder.pev1 import TextTransformer, VisionTransformer
 
 
@@ -78,7 +80,11 @@ class CLIPTextCfg:
     norm_type: str = "layernorm"  # "layernorm" or "rmsnorm"
 
 
-class CLIP(nn.Module):
+class CLIP(nn.Module,
+           PyTorchModelHubMixin,
+           repo_url="https://github.com/facebookresearch/perception_models",
+           pipeline_tag="image-feature-extraction",
+           license="apache-2.0"):
     def __init__(
         self,
         embed_dim: int,
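
For context: subclassing `PyTorchModelHubMixin` is what gives `CLIP` the `save_pretrained`, `from_pretrained`, and `push_to_hub` methods, and the class-level `repo_url` / `pipeline_tag` / `license` kwargs are stored as metadata for the auto-generated model card on the Hub. A minimal sketch of the same pattern is below; `ToyEncoder` and its single `embed_dim` argument are hypothetical stand-ins, not part of this diff (the real `CLIP` takes `embed_dim` plus vision/text configs as shown in `model.py`).

```python
# Minimal sketch of the PyTorchModelHubMixin pattern used in the diff above.
# ToyEncoder is a hypothetical toy module; the class-level kwargs mirror those
# added to CLIP and feed the auto-generated model card metadata.
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ToyEncoder(nn.Module,
                 PyTorchModelHubMixin,
                 repo_url="https://github.com/facebookresearch/perception_models",
                 pipeline_tag="image-feature-extraction",
                 license="apache-2.0"):
    def __init__(self, embed_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# The mixin records the __init__ arguments as a config and adds
# save_pretrained / from_pretrained / push_to_hub to the class:
model = ToyEncoder(embed_dim=64)
model.save_pretrained("toy-encoder")            # writes model.safetensors + config.json
reloaded = ToyEncoder.from_pretrained("toy-encoder")
# model.push_to_hub("user/toy-encoder")         # uploads weights + auto-generated model card
```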