From e0d1fb0bbb6071c5a6a2fae2b55f40961c0afb7a Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Thu, 3 Oct 2024 08:36:08 -0700 Subject: [PATCH 01/18] Feature normalizer on HF --- hf_hub.py | 34 +++++++++++++++++++++++++++++++++- radio/feature_normalizer.py | 4 ++-- radio/hf_model.py | 9 +++++++++ test_hf.py | 18 ++++++++++++++++++ 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/hf_hub.py b/hf_hub.py index 6902501..a5b5261 100644 --- a/hf_hub.py +++ b/hf_hub.py @@ -161,11 +161,21 @@ def main(): adaptor_configs[adaptor_name] = adaptor_config + + feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.') + + feature_normalizer_config = None + if feat_norm_sd is not None: + feature_normalizer_config = { + "embed_dim": feat_norm_sd['mean'].shape[0] + } + radio_config = RADIOConfig( vars(model_args), version=args.version, adaptor_names=adaptor_names, adaptor_configs=adaptor_configs, + feature_normalizer_config=feature_normalizer_config, ) radio_model = RADIOModel(radio_config) @@ -194,6 +204,10 @@ def main(): get_prefix_state_dict(state_dict, "input_conditioner.") ) + # Restore feature normalizer. + if feat_norm_sd: + radio_model.radio_model.feature_normalizer.load_state_dict(feat_norm_sd) + radio_model.eval().cuda() # Sample inference with deterministic values. @@ -262,6 +276,25 @@ def main(): print(f"{k} outputs matched!") + x_conditioned = radio_model.input_conditioner(x) + intermediates = radio_model.radio_model.forward_intermediates( + x_conditioned, + indices=[-1], + return_prefix_tokens=True, + norm=False, + stop_early=False, + output_fmt='NLC', + intermediates_only=True, + aggregation="sparse", + ) + print( + f"Intermediates inference returned ", + f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", + ) + assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6) + + print("All outputs matched!") + if args.push: # Push to HuggingFace Hub. huggingface_repo = args.hf_repo @@ -273,7 +306,6 @@ def main(): ) print(f"Pushed to {commit}") - if __name__ == "__main__": """Call the main entrypoiny.""" main() diff --git a/radio/feature_normalizer.py b/radio/feature_normalizer.py index 7d4cd27..6e46950 100644 --- a/radio/feature_normalizer.py +++ b/radio/feature_normalizer.py @@ -28,8 +28,8 @@ class FeatureNormalizer(nn.Module): def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32): super().__init__() - self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype)) - self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype)) + self.mean = nn.Parameter(torch.zeros(embed_dim, dtype=dtype), requires_grad=False) + self.tx = nn.Parameter(torch.eye(embed_dim, dtype=dtype), requires_grad=False) def forward(self, x: torch.Tensor) -> torch.Tensor: x = _run_kernel(x, self.mean, self.tx) diff --git a/radio/hf_model.py b/radio/hf_model.py index 5ecb3cf..8629049 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -31,6 +31,7 @@ from .enable_cpe_support import enable_cpe from .enable_spectral_reparam import configure_spectral_reparam_from_args from .eradio_model import eradio +from .feature_normalizer import FeatureNormalizer from .radio_model import create_model_from_args from .radio_model import RADIOModel as RADIOModelBase, Resolution from .input_conditioner import get_default_conditioner, InputConditioner @@ -55,6 +56,7 @@ def __init__( adaptor_names: Union[str, List[str]] = None, adaptor_configs: Dict[str, Dict[str, int]] = None, vitdet_window_size: Optional[int] = None, + feature_normalizer_config: Optional[dict] = None, **kwargs, ): self.args = args @@ -74,6 +76,7 @@ def __init__( self.adaptor_names = adaptor_names self.adaptor_configs = adaptor_configs self.vitdet_window_size = vitdet_window_size + self.feature_normalizer_config = feature_normalizer_config super().__init__(**kwargs) @@ -118,6 +121,11 @@ def __init__(self, config: RADIOConfig): adaptor.head_idx = mlp_config["head_idx"] adaptors[adaptor_name] = adaptor + feature_normalizer = None + if config.feature_normalizer_config is not None: + # Actual normalization values will be restored when loading checkpoint weights. + feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"]) + self.radio_model = RADIOModelBase( model, input_conditioner, @@ -127,6 +135,7 @@ def __init__(self, config: RADIOConfig): window_size=config.vitdet_window_size, preferred_resolution=config.preferred_resolution, adaptors=adaptors, + feature_normalizer=feature_normalizer, ) @property diff --git a/test_hf.py b/test_hf.py index f07ab98..ecd6782 100644 --- a/test_hf.py +++ b/test_hf.py @@ -44,6 +44,7 @@ def main(): python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio_v2.1_bf16.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --adaptor-names siglip,sam + python3 -m test_hf --hf-repo gheinrich/RADIO-NORM --torchhub-version /lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/hero/n32_8-19-24_vit-h-16_hero-v4_s3/checkpoints/last_norm_release_half.pth.tar --torchhub-repo NVlabs/RADIO:mranzinger/ship_paper """ parser = argparse.ArgumentParser() parser.add_argument("--hf-repo", help="Path to the HuggingFace repo", required=True) @@ -126,6 +127,23 @@ def main(): assert torch.allclose(hf_summary, torchhub_summary, atol=1e-6) assert torch.allclose(hf_features, torchhub_features, atol=1e-6) + intermediates = hf_model.radio_model.forward_intermediates( + hf_model.input_conditioner(x), + indices=[-1], + return_prefix_tokens=True, + norm=False, + stop_early=False, + output_fmt='NLC', + intermediates_only=True, + aggregation="sparse", + ) + print( + f"Intermediates inference returned summary ", + f"with shape={intermediates[0].summary.shape} and std={intermediates[0].summary.std().item():.3}, ", + f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", + ) + assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6) + print("All outputs matched!") # Infer a sample image. From 35f04d64e3fb53fa30ab447aa11233109cf73909 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:43:43 -0700 Subject: [PATCH 02/18] Fix visualized_features for intermediates --- examples/visualize_features.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/visualize_features.py b/examples/visualize_features.py index c1ea0e3..12c69b4 100644 --- a/examples/visualize_features.py +++ b/examples/visualize_features.py @@ -195,9 +195,10 @@ def main(rank: int = 0, world_size: int = 1): output_fmt='NLC', intermediates_only=True, aggregation=args.intermediate_aggregation, + norm_alpha_scheme="none", ) assert args.adaptor_name is None - all_feat = [o[1] for o in outputs] + all_feat = outputs else: output = model(p_images) if args.adaptor_name: From 8a93dd36fcbdd2729390dc2b00cb7e9dd226e510 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:44:27 -0700 Subject: [PATCH 03/18] Add intermediate layer checks to hf_hub --- hf_hub.py | 49 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/hf_hub.py b/hf_hub.py index a5b5261..d915574 100644 --- a/hf_hub.py +++ b/hf_hub.py @@ -163,19 +163,28 @@ def main(): feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.') - feature_normalizer_config = None if feat_norm_sd is not None: feature_normalizer_config = { "embed_dim": feat_norm_sd['mean'].shape[0] } + inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.') + inter_feature_normalizer_config = None + if inter_feat_norm_sd: + inter_feature_normalizer_config = { + "num_intermediates": inter_feat_norm_sd['means'].shape[0], + "embed_dim": inter_feat_norm_sd['means'].shape[1], + "rot_per_layer": inter_feat_norm_sd['rotation'].ndim == 3, + } + radio_config = RADIOConfig( vars(model_args), version=args.version, adaptor_names=adaptor_names, adaptor_configs=adaptor_configs, feature_normalizer_config=feature_normalizer_config, + inter_feature_normalizer_config=inter_feature_normalizer_config, ) radio_model = RADIOModel(radio_config) @@ -207,6 +216,8 @@ def main(): # Restore feature normalizer. if feat_norm_sd: radio_model.radio_model.feature_normalizer.load_state_dict(feat_norm_sd) + if inter_feat_norm_sd: + radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd) radio_model.eval().cuda() @@ -234,6 +245,25 @@ def main(): f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}", ) + intermediates = radio_model.radio_model.forward_intermediates( + x, + indices=[-1], + return_prefix_tokens=True, + norm=False, + stop_early=False, + output_fmt='NLC', + intermediates_only=True, + aggregation="sparse", + ) + print( + f"Intermediates inference returned ", + f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", + ) + print("diff norm", (intermediates[0].features- hf_output["backbone"].features).norm()) + print("std", intermediates[0].features.std().item(), hf_output["backbone"].features.std().item()) + print("mean", intermediates[0].features.mean().item(), hf_output["backbone"].features.mean().item()) + #assert torch.allclose(intermediates[0].features, hf_output["backbone"].features, atol=1e-4) + # Infer using TorchHub model. print("Infer using TorchHub model...") torchhub_model = torch.hub.load( @@ -276,22 +306,7 @@ def main(): print(f"{k} outputs matched!") - x_conditioned = radio_model.input_conditioner(x) - intermediates = radio_model.radio_model.forward_intermediates( - x_conditioned, - indices=[-1], - return_prefix_tokens=True, - norm=False, - stop_early=False, - output_fmt='NLC', - intermediates_only=True, - aggregation="sparse", - ) - print( - f"Intermediates inference returned ", - f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", - ) - assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6) + print("All outputs matched!") From a91a8391a95981861a40a183b8c71a6eac1b390a Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:45:03 -0700 Subject: [PATCH 04/18] Fix import in feature normalizer --- radio/enable_cpe_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/radio/enable_cpe_support.py b/radio/enable_cpe_support.py index 763dafb..8860720 100644 --- a/radio/enable_cpe_support.py +++ b/radio/enable_cpe_support.py @@ -14,7 +14,7 @@ from timm.models import VisionTransformer, checkpoint_seq -from radio.feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer +from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer from .extra_models import DinoWrapper from .vit_patch_generator import ViTPatchGenerator From 69765c579776d0cd2f3130b430a906093b528416 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:45:47 -0700 Subject: [PATCH 05/18] Support for inter feature normalizer in HF model --- radio/hf_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/radio/hf_model.py b/radio/hf_model.py index 8629049..238320b 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -31,7 +31,7 @@ from .enable_cpe_support import enable_cpe from .enable_spectral_reparam import configure_spectral_reparam_from_args from .eradio_model import eradio -from .feature_normalizer import FeatureNormalizer +from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer from .radio_model import create_model_from_args from .radio_model import RADIOModel as RADIOModelBase, Resolution from .input_conditioner import get_default_conditioner, InputConditioner @@ -57,6 +57,7 @@ def __init__( adaptor_configs: Dict[str, Dict[str, int]] = None, vitdet_window_size: Optional[int] = None, feature_normalizer_config: Optional[dict] = None, + inter_feature_normalizer_config: Optional[dict] = None, **kwargs, ): self.args = args @@ -77,6 +78,7 @@ def __init__( self.adaptor_configs = adaptor_configs self.vitdet_window_size = vitdet_window_size self.feature_normalizer_config = feature_normalizer_config + self.inter_feature_normalizer_config = inter_feature_normalizer_config super().__init__(**kwargs) @@ -126,6 +128,14 @@ def __init__(self, config: RADIOConfig): # Actual normalization values will be restored when loading checkpoint weights. feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"]) + inter_feature_normalizer = None + if config.inter_feature_normalizer_config is not None: + inter_feature_normalizer = IntermediateFeatureNormalizer( + config.inter_feature_normalizer_config["num_intermediates"], + config.inter_feature_normalizer_config["embed_dim"], + rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"], + dtype=dtype) + self.radio_model = RADIOModelBase( model, input_conditioner, @@ -136,6 +146,7 @@ def __init__(self, config: RADIOConfig): preferred_resolution=config.preferred_resolution, adaptors=adaptors, feature_normalizer=feature_normalizer, + inter_feature_normalizer=inter_feature_normalizer, ) @property From 070e4bfbde683a347bc77bc12a7abb3f1f7c9f26 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:46:19 -0700 Subject: [PATCH 06/18] Support pulling HF from branch --- test_hf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test_hf.py b/test_hf.py index ecd6782..64e9804 100644 --- a/test_hf.py +++ b/test_hf.py @@ -54,6 +54,9 @@ def main(): parser.add_argument( "--torchhub-repo", help="Path to the Torchhub repo", default="NVlabs/RADIO" ) + parser.add_argument( + "--hf-revision", help="HuggingFace revision to checkout", default="main" + ) parser.add_argument( "--adaptor-names", default=None, @@ -64,13 +67,13 @@ def main(): args = parser.parse_args() - hf_config = AutoConfig.from_pretrained(args.hf_repo, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(args.hf_repo, revision=args.hf_revision, trust_remote_code=True) if args.adaptor_names is not None: # Configure adaptors if specified on the command line. # This needs to happen before we instantiate the model. hf_config.adaptor_names = args.adaptor_names hf_model = AutoModel.from_pretrained( - args.hf_repo, trust_remote_code=True, config=hf_config + args.hf_repo, revision=args.hf_revision, trust_remote_code=True, config=hf_config ) hf_model.eval().cuda() @@ -142,12 +145,12 @@ def main(): f"with shape={intermediates[0].summary.shape} and std={intermediates[0].summary.std().item():.3}, ", f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", ) - assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6) + #assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6) print("All outputs matched!") # Infer a sample image. - image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo) + image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo, revision=args.hf_revision) image = Image.open("./examples/image1.png").convert("RGB") pixel_values = image_processor(images=image, return_tensors="pt").pixel_values From b25fc4427849a90e91b2ea5533edb242c124ed0e Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:55:50 -0700 Subject: [PATCH 07/18] Revert to using buffers in feature normalizer --- radio/feature_normalizer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/radio/feature_normalizer.py b/radio/feature_normalizer.py index 6e46950..cdd0bba 100644 --- a/radio/feature_normalizer.py +++ b/radio/feature_normalizer.py @@ -28,8 +28,10 @@ class FeatureNormalizer(nn.Module): def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32): super().__init__() - self.mean = nn.Parameter(torch.zeros(embed_dim, dtype=dtype), requires_grad=False) - self.tx = nn.Parameter(torch.eye(embed_dim, dtype=dtype), requires_grad=False) +# self.mean = nn.Parameter(torch.zeros(embed_dim, dtype=dtype), requires_grad=False) +# self.tx = nn.Parameter(torch.eye(embed_dim, dtype=dtype), requires_grad=False) + self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype)) + self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: x = _run_kernel(x, self.mean, self.tx) @@ -49,15 +51,19 @@ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Opti class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32): super().__init__() +# self.alphas = nn.Parameter(torch.ones(num_intermediates, dtype=dtype), requires_grad=False) self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype)) rot = torch.eye(embed_dim, dtype=dtype) if rot_per_layer: rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1) +# self.rotation = nn.Parameter(rot.contiguous(), requires_grad=False) +# self.means = nn.Parameter(torch.zeros(num_intermediates, embed_dim, dtype=dtype), requires_grad=False) self.register_buffer('rotation', rot.contiguous()) self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype)) + def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: if rot_index is None: rot_index = index From 8c011c50d95b0811437138669fa1d982f7545d33 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:57:00 -0700 Subject: [PATCH 08/18] Pin version of albumentations --- examples/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/requirements.txt b/examples/requirements.txt index 86444c9..646e0cd 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -2,7 +2,7 @@ transformers datasets timm open_clip_torch -albumentations +albumentations==1.3.1 opencv-python==4.8.0.74 opencv-python-headless==4.8.0.74 git+https://github.com/facebookresearch/segment-anything.git From 96a11c8e97fa13c8594c179e1bb48d7c1ca1494c Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 8 Oct 2024 04:58:33 -0700 Subject: [PATCH 09/18] Fix E-RADIO mmseg --- mmseg/radio.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mmseg/radio.py b/mmseg/radio.py index 34f9cb1..59121de 100644 --- a/mmseg/radio.py +++ b/mmseg/radio.py @@ -70,6 +70,10 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: # Standard ViT case. patch_height, patch_width = self.base_model.model.patch_embed.patch_size features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous() + else: + B, _, C = features.shape + patch_height = patch_width = 16 + features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous() # IMPORTANT: prevent gradients from flowing back towards the backbone. features = features.detach() From 6fb0cd67dd1ca276c5d0d49cc42e4dd369140761 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Thu, 10 Oct 2024 04:23:32 -0700 Subject: [PATCH 10/18] Changes to support DINOv2 in HF --- hf_hub.py | 16 +++++++++++++--- radio/common.py | 7 +++++++ radio/hf_model.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/hf_hub.py b/hf_hub.py index d915574..a761760 100644 --- a/hf_hub.py +++ b/hf_hub.py @@ -21,7 +21,7 @@ from radio.adaptor_base import RadioOutput from radio.adaptor_registry import adaptor_registry from radio.adaptor_mlp import get_mlp_info_from_state -from radio.hf_model import RADIOConfig, RADIOModel +from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy from test_hf import deterministic_grid_init @@ -164,7 +164,7 @@ def main(): feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.') feature_normalizer_config = None - if feat_norm_sd is not None: + if feat_norm_sd: feature_normalizer_config = { "embed_dim": feat_norm_sd['mean'].shape[0] } @@ -219,6 +219,10 @@ def main(): if inter_feat_norm_sd: radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd) + # Rename "gamma" parameters to "weight" + rename_all_gamma_to_weight_with_proxy(radio_model.radio_model) + radio_config.rename_gamma_to_weight = True + radio_model.eval().cuda() # Sample inference with deterministic values. @@ -240,7 +244,7 @@ def main(): hf_summary, hf_features = v.summary, v.features print( - f"[{k}] Sample inference on tensor shape {x.shape} returned summary ", + f"[{k}] HF inference on tensor shape {x.shape} returned summary ", f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ", f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}", ) @@ -288,6 +292,12 @@ def main(): torchhub_output[k].features, ) + print( + f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ", + f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ", + f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}", + ) + # Make sure the shapes are the same. assert ( hf_summary.shape == torchhub_summary.shape diff --git a/radio/common.py b/radio/common.py index 0acfbae..3d393fc 100644 --- a/radio/common.py +++ b/radio/common.py @@ -87,6 +87,13 @@ class RadioResource: max_resolution=2048, preferred_resolution=Resolution(512, 512), ), + # RADIO-DINOv2 + "radio_dinov2-g": RadioResource( + None, # TODO: add URL for DINOv2 student. + patch_size=14, + max_resolution=2044, + preferred_resolution=Resolution(518, 518), + ), } DEFAULT_VERSION = "radio_v2.5-h" diff --git a/radio/hf_model.py b/radio/hf_model.py index 238320b..157cdb0 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -43,6 +43,33 @@ from .extra_timm_models import * + +def rename_all_gamma_to_weight_with_proxy(module): + """ + Renames all parameters named 'gamma' in a module (including submodules) + to 'weight' and sets up a property so that accesses to 'gamma' still work. + """ + # Recursively iterate through submodules + for submodule_name, submodule in module.named_modules(): + # Get all parameters within the current submodule + for param_name, param in list(submodule.named_parameters(recurse=False)): + if 'gamma' in param_name: + # Generate the new name by replacing 'gamma' with 'weight' + new_name = param_name.replace('gamma', 'weight') + + # Remove the old parameter and assign it with the new name + delattr(submodule, param_name) + setattr(submodule, new_name, nn.Parameter(param.data)) + + # Define a property to proxy access to the renamed parameter + def make_property(old_name, new_name): + return property(lambda self: getattr(self, new_name), + lambda self, value: setattr(self, new_name, value)) + + # Add the property to the submodule to proxy access to 'gamma' + setattr(submodule.__class__, param_name, make_property(param_name, new_name)) + + class RADIOConfig(PretrainedConfig): """Pretrained Hugging Face configuration for RADIO models.""" @@ -58,6 +85,7 @@ def __init__( vitdet_window_size: Optional[int] = None, feature_normalizer_config: Optional[dict] = None, inter_feature_normalizer_config: Optional[dict] = None, + rename_gamma_to_weight: bool = False, **kwargs, ): self.args = args @@ -79,9 +107,11 @@ def __init__( self.vitdet_window_size = vitdet_window_size self.feature_normalizer_config = feature_normalizer_config self.inter_feature_normalizer_config = inter_feature_normalizer_config + self.rename_gamma_to_weight = rename_gamma_to_weight super().__init__(**kwargs) + class RADIOModel(PreTrainedModel): """Pretrained Hugging Face model for RADIO. @@ -149,6 +179,9 @@ def __init__(self, config: RADIOConfig): inter_feature_normalizer=inter_feature_normalizer, ) + if config.rename_gamma_to_weight: + rename_all_gamma_to_weight_with_proxy(self.radio_model) + @property def adaptors(self) -> nn.ModuleDict: return self.radio_model.adaptors From b8fd92890ebf8e023ef04f9c86e11e69aa09f935 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Sat, 2 Nov 2024 10:45:08 -0700 Subject: [PATCH 11/18] Import forwarD_intermediates in hf_hub.py --- radio/hf_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/radio/hf_model.py b/radio/hf_model.py index 157cdb0..2d9dacf 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -32,6 +32,7 @@ from .enable_spectral_reparam import configure_spectral_reparam_from_args from .eradio_model import eradio from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer +from .forward_intermediates import forward_intermediates from .radio_model import create_model_from_args from .radio_model import RADIOModel as RADIOModelBase, Resolution from .input_conditioner import get_default_conditioner, InputConditioner @@ -41,6 +42,7 @@ # Register extra models from .extra_timm_models import * +from .extra_models import * From 0c681b9ee907753828d249ec1068105169b67f71 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Mon, 25 Nov 2024 00:20:39 -0800 Subject: [PATCH 12/18] Remove invalid 'enable-cudnn-attention' key --- hf_hub.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/hf_hub.py b/hf_hub.py index a761760..16e08e3 100644 --- a/hf_hub.py +++ b/hf_hub.py @@ -86,6 +86,14 @@ def main(): checkpoint = torch.load(args.checkpoint_path, map_location="cpu") model_args = checkpoint["args"] + # Remove invalid identifier. + if hasattr(model_args, "enable_cudnn_attention"): + print(f'Removing attribute: enable-cudnn-attention!') + delattr(model_args, "enable-cudnn-attention") + if hasattr(model_args, "device"): + print(f'Removing attribute: device!') + delattr(model_args, "device") + # Extract the state dict from the checkpoint. if "state_dict_ema" in checkpoint: state_dict = checkpoint["state_dict_ema"] @@ -178,8 +186,11 @@ def main(): "rot_per_layer": inter_feat_norm_sd['rotation'].ndim == 3, } + model_vars = vars(model_args) + model_vars.pop('enable-cudnn-attention', None) + radio_config = RADIOConfig( - vars(model_args), + model_vars, version=args.version, adaptor_names=adaptor_names, adaptor_configs=adaptor_configs, @@ -263,9 +274,6 @@ def main(): f"Intermediates inference returned ", f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}", ) - print("diff norm", (intermediates[0].features- hf_output["backbone"].features).norm()) - print("std", intermediates[0].features.std().item(), hf_output["backbone"].features.std().item()) - print("mean", intermediates[0].features.mean().item(), hf_output["backbone"].features.mean().item()) #assert torch.allclose(intermediates[0].features, hf_output["backbone"].features, atol=1e-4) # Infer using TorchHub model. From 5a84b7d7a4c90a05cae1d6794b02d8ee1c6edddb Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Mon, 25 Nov 2024 00:21:54 -0800 Subject: [PATCH 13/18] Option for zero-shot sweeps --- examples/zero_shot_imagenet.py | 127 ++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 57 deletions(-) diff --git a/examples/zero_shot_imagenet.py b/examples/zero_shot_imagenet.py index 181e123..c042477 100644 --- a/examples/zero_shot_imagenet.py +++ b/examples/zero_shot_imagenet.py @@ -80,6 +80,7 @@ def main(rank: int = 0, world_size: int = 1): parser.add_argument('--torchhub-repo', help="Path to the Torchhub repo", default="NVlabs/RADIO" ) + parser.add_argument('--sweep', default=False, action='store_true') parser.add_argument('--use-huggingface', default=False, action='store_true', help='Use the huggingface model') parser.add_argument('--csv-out', type=str, default=None, @@ -99,27 +100,11 @@ def main(rank: int = 0, world_size: int = 1): ds_builder.download_and_prepare() num_examples = ds_builder.info.splits[args.split].num_examples - if args.resolution is None: - args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width) + if args.resize_multiple is None: args.resize_multiple = getattr(model, 'min_resolution_step', model.patch_size) - transform = get_standard_transform(args.resolution, args.resize_multiple, preprocessor=preprocessor) - dataset = ds_builder.as_dataset(split=args.split) - dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers)) - dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) - dataset = dataset.map(lambda ex: dict(image=transform(ex['image']), label=torch.as_tensor(ex['label'], dtype=torch.int64))) - - loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, collate_fn=collate, - pin_memory=args.workers > 0, - drop_last=False, - ) - num_steps = round_up(num_examples, args.batch_size * world_size) - rank_print('Done') - rank_print(f'Description: {ds_builder.info.description}') - rank_print('Building Zero Shot Classifier...') adaptor = model.adaptors[args.adaptor_name] if hasattr(model, 'adaptors') else model classifier = get_clip_classifier( @@ -127,49 +112,77 @@ def main(rank: int = 0, world_size: int = 1): ).float() rank_print('Done') - rank_print('Classifying...') - topks = { - k: torch.tensor(0.0, dtype=torch.float32, device=device) - for k in (1, 5) - } - num_processed = 0 - with torch.inference_mode(), tqdm(total=num_examples, disable=rank > 0) as t: - for batches in loader: - for images, targets in batches: - images = images.to(device=device, non_blocking=True) - targets = targets.to(device=device, non_blocking=True) - - with torch.autocast(device.type, dtype=torch.bfloat16, enabled=args.amp): - output = model(images) - summary = output[args.adaptor_name].summary - summary = F.normalize(summary, dim=-1) - - logits = summary.to(classifier.dtype) @ classifier - - accs = accuracy(logits, targets, topk=topks.keys()) - for k, acc in zip(topks.keys(), accs): - topks[k].add_(acc * images.shape[0]) - num_processed += images.shape[0] - - t.set_postfix({'Rank': '0', **{f'Top-{k}': f'{v.item() / num_processed:.03f}' for k, v in topks.items()}}) - t.update(world_size * args.batch_size) - - if world_size > 1: - rank_print('\tWaiting for all ranks to complete...') - num_processed = torch.tensor(num_processed, device=device) - dist.reduce(num_processed, dst=0, op=dist.ReduceOp.SUM) - + # sweep through all resolutions from 224 to 1024 in steps of 32 + if args.sweep: + resolutions = list(range(224, 1024+1, 32)) + else: + if args.resolution is None: + args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width) + resolutions = [args.resolution] + for resolution in resolutions: + if isinstance(resolution, int): + resolution = (resolution, resolution) + transform = get_standard_transform(resolution, args.resize_multiple, preprocessor=preprocessor) + dataset = ds_builder.as_dataset(split=args.split) + dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers)) + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + dataset = dataset.map(lambda ex: dict(image=transform(ex['image']), label=torch.as_tensor(ex['label'], dtype=torch.int64))) + + loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, collate_fn=collate, + pin_memory=args.workers > 0, + drop_last=False, + ) + num_steps = round_up(num_examples, args.batch_size * world_size) + rank_print('Done') + rank_print(f'Description: {ds_builder.info.description}') + + rank_print(f'Classifying at resolution={resolution}...') + topks = { + k: torch.tensor(0.0, dtype=torch.float32, device=device) + for k in (1, 5) + } + num_processed = 0 + with torch.inference_mode(), tqdm(total=num_examples, disable=rank > 0) as t: + for batches in loader: + for images, targets in batches: + images = images.to(device=device, non_blocking=True) + targets = targets.to(device=device, non_blocking=True) + + with torch.autocast(device.type, dtype=torch.bfloat16, enabled=args.amp): + output = model(images) + summary = output[args.adaptor_name].summary + summary = F.normalize(summary, dim=-1) + + logits = summary.to(classifier.dtype) @ classifier + + accs = accuracy(logits, targets, topk=topks.keys()) + for k, acc in zip(topks.keys(), accs): + topks[k].add_(acc * images.shape[0]) + num_processed += images.shape[0] + + t.set_postfix({'Rank': '0', **{f'Top-{k}': f'{v.item() / num_processed:.03f}' for k, v in topks.items()}}) + t.update(world_size * args.batch_size) + + if world_size > 1: + rank_print('\tWaiting for all ranks to complete...') + num_processed = torch.tensor(num_processed, device=device) + dist.reduce(num_processed, dst=0, op=dist.ReduceOp.SUM) + + for k, acc in topks.items(): + dist.reduce(acc, dst=0, op=dist.ReduceOp.SUM) + rank_print('\tDone') + rank_print('Done') + + rank_print(f'Resolution: {args.resolution}') + rank_print('Accuracy:') for k, acc in topks.items(): - dist.reduce(acc, dst=0, op=dist.ReduceOp.SUM) - rank_print('\tDone') - rank_print('Done') + acc = (acc / num_processed).item() - rank_print(f'Resolution: {args.resolution}') - rank_print('Accuracy:') - for k, acc in topks.items(): - acc = (acc / num_processed).item() + rank_print(f'\tResolution: {resolution} Top {k}: {acc:.3f}') - rank_print(f'\tTop {k}: {acc:.3f}') + del loader + del dataset if rank == 0 and k == 1 and args.csv_out: with open(args.csv_out, 'a') as fd: From 5cd0a423dbee4853c65af07a0efc0f1f2031bcdc Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Mon, 25 Nov 2024 02:40:21 -0800 Subject: [PATCH 14/18] Print statement when renaming modules --- radio/hf_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/radio/hf_model.py b/radio/hf_model.py index 2d9dacf..5ee2f46 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -59,6 +59,8 @@ def rename_all_gamma_to_weight_with_proxy(module): # Generate the new name by replacing 'gamma' with 'weight' new_name = param_name.replace('gamma', 'weight') + print("In submodule {}: Renaming '{}' to '{}'".format(submodule_name, param_name, new_name)) + # Remove the old parameter and assign it with the new name delattr(submodule, param_name) setattr(submodule, new_name, nn.Parameter(param.data)) From d6db8bf23439010833a2f29adb08f34b4732285f Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Sat, 30 Nov 2024 22:27:32 -0800 Subject: [PATCH 15/18] Import dino arch This is needed to be able to load an HF model with code. --- radio/hf_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/radio/hf_model.py b/radio/hf_model.py index 5ee2f46..5aac522 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -28,6 +28,7 @@ from .adaptor_mlp import create_mlp_from_config from .adaptor_registry import adaptor_registry from .cls_token import ClsToken +from .dinov2_arch import dinov2_vitg14_reg from .enable_cpe_support import enable_cpe from .enable_spectral_reparam import configure_spectral_reparam_from_args from .eradio_model import eradio From 552c99d9669987eb0b29038eb26cc852adf4d677 Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 10 Dec 2024 08:28:00 -0800 Subject: [PATCH 16/18] Remove custom DINO-g version and support for renaming gamma --- hf_hub.py | 6 +----- radio/common.py | 7 ------- radio/hf_model.py | 34 ---------------------------------- 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/hf_hub.py b/hf_hub.py index 16e08e3..996789a 100644 --- a/hf_hub.py +++ b/hf_hub.py @@ -21,7 +21,7 @@ from radio.adaptor_base import RadioOutput from radio.adaptor_registry import adaptor_registry from radio.adaptor_mlp import get_mlp_info_from_state -from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy +from radio.hf_model import RADIOConfig, RADIOModel from test_hf import deterministic_grid_init @@ -230,10 +230,6 @@ def main(): if inter_feat_norm_sd: radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd) - # Rename "gamma" parameters to "weight" - rename_all_gamma_to_weight_with_proxy(radio_model.radio_model) - radio_config.rename_gamma_to_weight = True - radio_model.eval().cuda() # Sample inference with deterministic values. diff --git a/radio/common.py b/radio/common.py index 3d393fc..0acfbae 100644 --- a/radio/common.py +++ b/radio/common.py @@ -87,13 +87,6 @@ class RadioResource: max_resolution=2048, preferred_resolution=Resolution(512, 512), ), - # RADIO-DINOv2 - "radio_dinov2-g": RadioResource( - None, # TODO: add URL for DINOv2 student. - patch_size=14, - max_resolution=2044, - preferred_resolution=Resolution(518, 518), - ), } DEFAULT_VERSION = "radio_v2.5-h" diff --git a/radio/hf_model.py b/radio/hf_model.py index 5aac522..27bf97b 100644 --- a/radio/hf_model.py +++ b/radio/hf_model.py @@ -46,35 +46,6 @@ from .extra_models import * - -def rename_all_gamma_to_weight_with_proxy(module): - """ - Renames all parameters named 'gamma' in a module (including submodules) - to 'weight' and sets up a property so that accesses to 'gamma' still work. - """ - # Recursively iterate through submodules - for submodule_name, submodule in module.named_modules(): - # Get all parameters within the current submodule - for param_name, param in list(submodule.named_parameters(recurse=False)): - if 'gamma' in param_name: - # Generate the new name by replacing 'gamma' with 'weight' - new_name = param_name.replace('gamma', 'weight') - - print("In submodule {}: Renaming '{}' to '{}'".format(submodule_name, param_name, new_name)) - - # Remove the old parameter and assign it with the new name - delattr(submodule, param_name) - setattr(submodule, new_name, nn.Parameter(param.data)) - - # Define a property to proxy access to the renamed parameter - def make_property(old_name, new_name): - return property(lambda self: getattr(self, new_name), - lambda self, value: setattr(self, new_name, value)) - - # Add the property to the submodule to proxy access to 'gamma' - setattr(submodule.__class__, param_name, make_property(param_name, new_name)) - - class RADIOConfig(PretrainedConfig): """Pretrained Hugging Face configuration for RADIO models.""" @@ -90,7 +61,6 @@ def __init__( vitdet_window_size: Optional[int] = None, feature_normalizer_config: Optional[dict] = None, inter_feature_normalizer_config: Optional[dict] = None, - rename_gamma_to_weight: bool = False, **kwargs, ): self.args = args @@ -112,7 +82,6 @@ def __init__( self.vitdet_window_size = vitdet_window_size self.feature_normalizer_config = feature_normalizer_config self.inter_feature_normalizer_config = inter_feature_normalizer_config - self.rename_gamma_to_weight = rename_gamma_to_weight super().__init__(**kwargs) @@ -184,9 +153,6 @@ def __init__(self, config: RADIOConfig): inter_feature_normalizer=inter_feature_normalizer, ) - if config.rename_gamma_to_weight: - rename_all_gamma_to_weight_with_proxy(self.radio_model) - @property def adaptors(self) -> nn.ModuleDict: return self.radio_model.adaptors From a29c37f8082ae7ed3547a7748d2faee48e94f95b Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 10 Dec 2024 08:29:41 -0800 Subject: [PATCH 17/18] Remove zero shot sweep --- examples/zero_shot_imagenet.py | 127 +++++++++++++++------------------ 1 file changed, 57 insertions(+), 70 deletions(-) diff --git a/examples/zero_shot_imagenet.py b/examples/zero_shot_imagenet.py index c042477..181e123 100644 --- a/examples/zero_shot_imagenet.py +++ b/examples/zero_shot_imagenet.py @@ -80,7 +80,6 @@ def main(rank: int = 0, world_size: int = 1): parser.add_argument('--torchhub-repo', help="Path to the Torchhub repo", default="NVlabs/RADIO" ) - parser.add_argument('--sweep', default=False, action='store_true') parser.add_argument('--use-huggingface', default=False, action='store_true', help='Use the huggingface model') parser.add_argument('--csv-out', type=str, default=None, @@ -100,11 +99,27 @@ def main(rank: int = 0, world_size: int = 1): ds_builder.download_and_prepare() num_examples = ds_builder.info.splits[args.split].num_examples - + if args.resolution is None: + args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width) if args.resize_multiple is None: args.resize_multiple = getattr(model, 'min_resolution_step', model.patch_size) + transform = get_standard_transform(args.resolution, args.resize_multiple, preprocessor=preprocessor) + dataset = ds_builder.as_dataset(split=args.split) + dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers)) + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + dataset = dataset.map(lambda ex: dict(image=transform(ex['image']), label=torch.as_tensor(ex['label'], dtype=torch.int64))) + + loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, collate_fn=collate, + pin_memory=args.workers > 0, + drop_last=False, + ) + num_steps = round_up(num_examples, args.batch_size * world_size) + rank_print('Done') + rank_print(f'Description: {ds_builder.info.description}') + rank_print('Building Zero Shot Classifier...') adaptor = model.adaptors[args.adaptor_name] if hasattr(model, 'adaptors') else model classifier = get_clip_classifier( @@ -112,77 +127,49 @@ def main(rank: int = 0, world_size: int = 1): ).float() rank_print('Done') - # sweep through all resolutions from 224 to 1024 in steps of 32 - if args.sweep: - resolutions = list(range(224, 1024+1, 32)) - else: - if args.resolution is None: - args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width) - resolutions = [args.resolution] - for resolution in resolutions: - if isinstance(resolution, int): - resolution = (resolution, resolution) - transform = get_standard_transform(resolution, args.resize_multiple, preprocessor=preprocessor) - dataset = ds_builder.as_dataset(split=args.split) - dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers)) - dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) - dataset = dataset.map(lambda ex: dict(image=transform(ex['image']), label=torch.as_tensor(ex['label'], dtype=torch.int64))) - - loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, collate_fn=collate, - pin_memory=args.workers > 0, - drop_last=False, - ) - num_steps = round_up(num_examples, args.batch_size * world_size) - rank_print('Done') - rank_print(f'Description: {ds_builder.info.description}') - - rank_print(f'Classifying at resolution={resolution}...') - topks = { - k: torch.tensor(0.0, dtype=torch.float32, device=device) - for k in (1, 5) - } - num_processed = 0 - with torch.inference_mode(), tqdm(total=num_examples, disable=rank > 0) as t: - for batches in loader: - for images, targets in batches: - images = images.to(device=device, non_blocking=True) - targets = targets.to(device=device, non_blocking=True) - - with torch.autocast(device.type, dtype=torch.bfloat16, enabled=args.amp): - output = model(images) - summary = output[args.adaptor_name].summary - summary = F.normalize(summary, dim=-1) - - logits = summary.to(classifier.dtype) @ classifier - - accs = accuracy(logits, targets, topk=topks.keys()) - for k, acc in zip(topks.keys(), accs): - topks[k].add_(acc * images.shape[0]) - num_processed += images.shape[0] - - t.set_postfix({'Rank': '0', **{f'Top-{k}': f'{v.item() / num_processed:.03f}' for k, v in topks.items()}}) - t.update(world_size * args.batch_size) - - if world_size > 1: - rank_print('\tWaiting for all ranks to complete...') - num_processed = torch.tensor(num_processed, device=device) - dist.reduce(num_processed, dst=0, op=dist.ReduceOp.SUM) - - for k, acc in topks.items(): - dist.reduce(acc, dst=0, op=dist.ReduceOp.SUM) - rank_print('\tDone') - rank_print('Done') - - rank_print(f'Resolution: {args.resolution}') - rank_print('Accuracy:') + rank_print('Classifying...') + topks = { + k: torch.tensor(0.0, dtype=torch.float32, device=device) + for k in (1, 5) + } + num_processed = 0 + with torch.inference_mode(), tqdm(total=num_examples, disable=rank > 0) as t: + for batches in loader: + for images, targets in batches: + images = images.to(device=device, non_blocking=True) + targets = targets.to(device=device, non_blocking=True) + + with torch.autocast(device.type, dtype=torch.bfloat16, enabled=args.amp): + output = model(images) + summary = output[args.adaptor_name].summary + summary = F.normalize(summary, dim=-1) + + logits = summary.to(classifier.dtype) @ classifier + + accs = accuracy(logits, targets, topk=topks.keys()) + for k, acc in zip(topks.keys(), accs): + topks[k].add_(acc * images.shape[0]) + num_processed += images.shape[0] + + t.set_postfix({'Rank': '0', **{f'Top-{k}': f'{v.item() / num_processed:.03f}' for k, v in topks.items()}}) + t.update(world_size * args.batch_size) + + if world_size > 1: + rank_print('\tWaiting for all ranks to complete...') + num_processed = torch.tensor(num_processed, device=device) + dist.reduce(num_processed, dst=0, op=dist.ReduceOp.SUM) + for k, acc in topks.items(): - acc = (acc / num_processed).item() + dist.reduce(acc, dst=0, op=dist.ReduceOp.SUM) + rank_print('\tDone') + rank_print('Done') - rank_print(f'\tResolution: {resolution} Top {k}: {acc:.3f}') + rank_print(f'Resolution: {args.resolution}') + rank_print('Accuracy:') + for k, acc in topks.items(): + acc = (acc / num_processed).item() - del loader - del dataset + rank_print(f'\tTop {k}: {acc:.3f}') if rank == 0 and k == 1 and args.csv_out: with open(args.csv_out, 'a') as fd: From fcf1dbfa2b9605435670462bc008c6b665a79f8b Mon Sep 17 00:00:00 2001 From: Greg Heinrich Date: Tue, 10 Dec 2024 08:31:40 -0800 Subject: [PATCH 18/18] Remove commented code in feature normalizer --- radio/feature_normalizer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/radio/feature_normalizer.py b/radio/feature_normalizer.py index cdd0bba..7d4cd27 100644 --- a/radio/feature_normalizer.py +++ b/radio/feature_normalizer.py @@ -28,8 +28,6 @@ class FeatureNormalizer(nn.Module): def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32): super().__init__() -# self.mean = nn.Parameter(torch.zeros(embed_dim, dtype=dtype), requires_grad=False) -# self.tx = nn.Parameter(torch.eye(embed_dim, dtype=dtype), requires_grad=False) self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype)) self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype)) @@ -51,19 +49,15 @@ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Opti class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32): super().__init__() -# self.alphas = nn.Parameter(torch.ones(num_intermediates, dtype=dtype), requires_grad=False) self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype)) rot = torch.eye(embed_dim, dtype=dtype) if rot_per_layer: rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1) -# self.rotation = nn.Parameter(rot.contiguous(), requires_grad=False) -# self.means = nn.Parameter(torch.zeros(num_intermediates, embed_dim, dtype=dtype), requires_grad=False) self.register_buffer('rotation', rot.contiguous()) self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype)) - def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: if rot_index is None: rot_index = index