Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions avex/models/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,14 @@ def _load_from_modelspec(
num_classes = len(label_mapping["label_to_index"])
model_kwargs["num_classes"] = num_classes
logger.info(f"Extracted num_classes={num_classes} from label mapping")
else:
# Checkpoint exists but no num_classes found - likely a backbone-only checkpoint
# Automatically enable embedding mode for models that support it
if supports_return_features_only:
return_features_only = True
model_kwargs["return_features_only"] = True
logger.info(
f"Checkpoint found but no classifier detected; loading {model_type} in embedding extraction mode"
)
# If we still couldn't determine num_classes, treat this as a backbone-only
# checkpoint and fall back to embedding extraction for models that support it.
if "num_classes" not in model_kwargs and supports_return_features_only:
return_features_only = True
model_kwargs["return_features_only"] = True
logger.info(
f"Checkpoint found but no classifier detected; loading {model_type} in embedding extraction mode"
)

# If pretrained=True, pretrained weights are typically backbone-only (no classifier)
# Automatically enable embedding mode for models that support it
Expand Down
32 changes: 32 additions & 0 deletions tests/unittests/test_api_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,3 +812,35 @@ def forward(self, x: torch.Tensor, padding_mask: torch.Tensor | None = None) ->

registry._MODEL_CLASSES.clear()
registry._MODEL_REGISTRY.clear()

def test_falls_back_to_embedding_mode_when_checkpoint_has_no_classifier(
self,
tmp_path: Path,
) -> None:
"""Test fallback to embedding mode for backbone-only checkpoints.

This matches the failure mode of some official EAT models where a checkpoint
exists but does not contain a classifier head and there is no label mapping.
In that case, we should automatically enable return_features_only=True for
models that support it.
"""
checkpoint_path = tmp_path / "backbone_only.pt"
# No classifier/head keys: _extract_num_classes_from_checkpoint should return None.
torch.save({"backbone.some_weight": torch.randn(2, 2)}, checkpoint_path)

model_spec = ModelSpec(
name="test_model_type",
pretrained=False,
device="cpu",
)

model = _load_from_modelspec(
model_spec,
device="cpu",
checkpoint_path=str(checkpoint_path),
registry_key="test_model",
return_features_only=False,
)

assert isinstance(model, ModelBase)
assert getattr(model, "return_features_only", False) is True
Loading