diff --git a/avex/models/utils/load.py b/avex/models/utils/load.py
index 7341278..6bfb899 100644
--- a/avex/models/utils/load.py
+++ b/avex/models/utils/load.py
@@ -245,15 +245,14 @@ def _load_from_modelspec(
             num_classes = len(label_mapping["label_to_index"])
             model_kwargs["num_classes"] = num_classes
             logger.info(f"Extracted num_classes={num_classes} from label mapping")
-        else:
-            # Checkpoint exists but no num_classes found - likely a backbone-only checkpoint
-            # Automatically enable embedding mode for models that support it
-            if supports_return_features_only:
-                return_features_only = True
-                model_kwargs["return_features_only"] = True
-                logger.info(
-                    f"Checkpoint found but no classifier detected; loading {model_type} in embedding extraction mode"
-                )
+        # If we still couldn't determine num_classes, treat this as a backbone-only
+        # checkpoint and fall back to embedding extraction for models that support it.
+        if "num_classes" not in model_kwargs and supports_return_features_only:
+            return_features_only = True
+            model_kwargs["return_features_only"] = True
+            logger.info(
+                f"Checkpoint found but no classifier detected; loading {model_type} in embedding extraction mode"
+            )
 
     # If pretrained=True, pretrained weights are typically backbone-only (no classifier)
     # Automatically enable embedding mode for models that support it
diff --git a/tests/unittests/test_api_load.py b/tests/unittests/test_api_load.py
index 2baac22..da6796c 100644
--- a/tests/unittests/test_api_load.py
+++ b/tests/unittests/test_api_load.py
@@ -812,3 +812,35 @@ def forward(self, x: torch.Tensor, padding_mask: torch.Tensor | None = None) ->
 
         registry._MODEL_CLASSES.clear()
         registry._MODEL_REGISTRY.clear()
+
+    def test_falls_back_to_embedding_mode_when_checkpoint_has_no_classifier(
+        self,
+        tmp_path: Path,
+    ) -> None:
+        """Test fallback to embedding mode for backbone-only checkpoints.
+
+        This matches the failure mode of some official EAT models where a checkpoint
+        exists but does not contain a classifier head and there is no label mapping.
+        In that case, we should automatically enable return_features_only=True for
+        models that support it.
+        """
+        checkpoint_path = tmp_path / "backbone_only.pt"
+        # No classifier/head keys: _extract_num_classes_from_checkpoint should return None.
+        torch.save({"backbone.some_weight": torch.randn(2, 2)}, checkpoint_path)
+
+        model_spec = ModelSpec(
+            name="test_model_type",
+            pretrained=False,
+            device="cpu",
+        )
+
+        model = _load_from_modelspec(
+            model_spec,
+            device="cpu",
+            checkpoint_path=str(checkpoint_path),
+            registry_key="test_model",
+            return_features_only=False,
+        )
+
+        assert isinstance(model, ModelBase)
+        assert getattr(model, "return_features_only", False) is True