Changes to support DINOv2 in HF
gheinrich committed Dec 10, 2024
commit 6fb0cd67dd1ca276c5d0d49cc42e4dd369140761
16 changes: 13 additions & 3 deletions hf_hub.py
@@ -21,7 +21,7 @@
 from radio.adaptor_base import RadioOutput
 from radio.adaptor_registry import adaptor_registry
 from radio.adaptor_mlp import get_mlp_info_from_state
-from radio.hf_model import RADIOConfig, RADIOModel
+from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy
 from test_hf import deterministic_grid_init

@@ -164,7 +164,7 @@ def main():
 
     feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
     feature_normalizer_config = None
-    if feat_norm_sd is not None:
+    if feat_norm_sd:
         feature_normalizer_config = {
             "embed_dim": feat_norm_sd['mean'].shape[0]
         }
@@ -219,6 +219,10 @@ def main():
     if inter_feat_norm_sd:
         radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
 
+    # Rename "gamma" parameters to "weight"
+    rename_all_gamma_to_weight_with_proxy(radio_model.radio_model)
+    radio_config.rename_gamma_to_weight = True
+
     radio_model.eval().cuda()
 
     # Sample inference with deterministic values.
@@ -240,7 +244,7 @@ def main():
         hf_summary, hf_features = v.summary, v.features
 
         print(
-            f"[{k}] Sample inference on tensor shape {x.shape} returned summary ",
+            f"[{k}] HF inference on tensor shape {x.shape} returned summary ",
             f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ",
             f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}",
         )
@@ -288,6 +292,12 @@ def main():
             torchhub_output[k].features,
         )
 
+        print(
+            f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ",
+            f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ",
+            f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}",
+        )
+
         # Make sure the shapes are the same.
         assert (
             hf_summary.shape == torchhub_summary.shape
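
Why the conversion script renames these parameters: Hugging Face Transformers rewrites state-dict keys containing "gamma" (and "beta") to "weight" (and "bias") when it loads a checkpoint, so DINOv2-style LayerScale parameters literally named `gamma` would otherwise be remapped and fail to line up with the model's attributes. Below is a minimal, hedged sketch of what `rename_all_gamma_to_weight_with_proxy` does to a toy module; `ToyLayerScale` is an illustrative stand-in, not part of the repository.

import torch
from torch import nn

from radio.hf_model import rename_all_gamma_to_weight_with_proxy


class ToyLayerScale(nn.Module):
    # Hypothetical stand-in for a DINOv2-style LayerScale block whose scale
    # parameter is named `gamma`.
    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


toy = nn.Sequential(ToyLayerScale(8), nn.Linear(8, 8))
print(sorted(toy.state_dict()))  # ['0.gamma', '1.bias', '1.weight']

rename_all_gamma_to_weight_with_proxy(toy)
print(sorted(toy.state_dict()))  # ['0.weight', '1.bias', '1.weight'] -- HF-safe key names

# The class-level proxy property keeps old-name attribute access working,
# so forward() code that reads `self.gamma` still runs.
assert toy[0].gamma is toy[0].weight
_ = toy(torch.randn(2, 8))

One side effect worth noting: the proxy property is installed on the submodule's class rather than the instance, so it applies to every live instance of that class in the process, which is acceptable for a one-shot conversion script.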
7 changes: 7 additions & 0 deletions radio/common.py
@@ -87,6 +87,13 @@ class RadioResource:
         max_resolution=2048,
         preferred_resolution=Resolution(512, 512),
     ),
+    # RADIO-DINOv2
+    "radio_dinov2-g": RadioResource(
+        None, # TODO: add URL for DINOv2 student.
+        patch_size=14,
+        max_resolution=2044,
+        preferred_resolution=Resolution(518, 518),
+    ),
 }
 
 DEFAULT_VERSION = "radio_v2.5-h"
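
A note on the numbers in the new entry (editorial reasoning, not text from the commit): the DINOv2 giant student uses a 14-pixel patch, so input resolutions must be multiples of 14. 518 = 37 × 14 matches DINOv2's usual high-resolution input, and 2044 = 146 × 14 is the largest multiple of 14 below the 2048 cap used by the 16-pixel-patch entries. The arithmetic, with a hypothetical helper (`nearest_valid_resolution` is illustrative only, not a repo function):

def nearest_valid_resolution(height: int, width: int, patch_size: int) -> tuple:
    # Round a requested resolution down to the nearest multiple of the patch size.
    return (height // patch_size) * patch_size, (width // patch_size) * patch_size


assert nearest_valid_resolution(518, 518, patch_size=14) == (518, 518)      # 37 * 14
assert nearest_valid_resolution(2048, 2048, patch_size=14) == (2044, 2044)  # hence max_resolution=2044
assert nearest_valid_resolution(512, 512, patch_size=16) == (512, 512)      # 32 * 16, as in the ViT-H entries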
33 changes: 33 additions & 0 deletions radio/hf_model.py
@@ -43,6 +43,33 @@
 from .extra_timm_models import *
 
 
+
+def rename_all_gamma_to_weight_with_proxy(module):
+    """
+    Renames all parameters named 'gamma' in a module (including submodules)
+    to 'weight' and sets up a property so that accesses to 'gamma' still work.
+    """
+    # Recursively iterate through submodules
+    for submodule_name, submodule in module.named_modules():
+        # Get all parameters within the current submodule
+        for param_name, param in list(submodule.named_parameters(recurse=False)):
+            if 'gamma' in param_name:
+                # Generate the new name by replacing 'gamma' with 'weight'
+                new_name = param_name.replace('gamma', 'weight')
+
+                # Remove the old parameter and assign it with the new name
+                delattr(submodule, param_name)
+                setattr(submodule, new_name, nn.Parameter(param.data))
+
+                # Define a property to proxy access to the renamed parameter
+                def make_property(old_name, new_name):
+                    return property(lambda self: getattr(self, new_name),
+                                    lambda self, value: setattr(self, new_name, value))
+
+                # Add the property to the submodule to proxy access to 'gamma'
+                setattr(submodule.__class__, param_name, make_property(param_name, new_name))
+
+
 class RADIOConfig(PretrainedConfig):
     """Pretrained Hugging Face configuration for RADIO models."""
 
@@ -58,6 +85,7 @@ def __init__(
         vitdet_window_size: Optional[int] = None,
         feature_normalizer_config: Optional[dict] = None,
         inter_feature_normalizer_config: Optional[dict] = None,
+        rename_gamma_to_weight: bool = False,
         **kwargs,
     ):
         self.args = args
@@ -79,9 +107,11 @@ def __init__(
         self.vitdet_window_size = vitdet_window_size
         self.feature_normalizer_config = feature_normalizer_config
         self.inter_feature_normalizer_config = inter_feature_normalizer_config
+        self.rename_gamma_to_weight = rename_gamma_to_weight
         super().__init__(**kwargs)
 
 
+
 class RADIOModel(PreTrainedModel):
     """Pretrained Hugging Face model for RADIO.
 
@@ -149,6 +179,9 @@ def __init__(self, config: RADIOConfig):
             inter_feature_normalizer=inter_feature_normalizer,
         )
 
+        if config.rename_gamma_to_weight:
+            rename_all_gamma_to_weight_with_proxy(self.radio_model)
+
     @property
     def adaptors(self) -> nn.ModuleDict:
         return self.radio_model.adaptors
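
How the new `rename_gamma_to_weight` flag plays out at load time, sketched under assumptions (the checkpoint path below is a placeholder, and this is not a documented usage example): `RADIOModel.__init__` re-applies `rename_all_gamma_to_weight_with_proxy` whenever the flag is set, so the freshly built backbone's parameter names already say "weight" and match the keys saved by the conversion script, without depending on Transformers' implicit key rewriting.

from transformers import AutoConfig, AutoModel

# Placeholder: a checkpoint directory produced by the updated hf_hub.py script.
ckpt = "/path/to/converted-radio-dinov2"

config = AutoConfig.from_pretrained(ckpt, trust_remote_code=True)
print(config.rename_gamma_to_weight)  # True for checkpoints converted with the updated script

model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)
# No parameter is named "gamma" any more; the proxy properties keep `.gamma` reads working.
assert not any("gamma" in name for name, _ in model.named_parameters())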