Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions fmpose3d/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import math
import json
from dataclasses import dataclass, field, fields, asdict
from enum import Enum
from typing import Dict, List
Expand Down Expand Up @@ -36,6 +37,16 @@ class ModelConfig:
"""Model architecture configuration."""
model_type: str = "fmpose3d_humans"

def to_json(self, filename: str | None = None, **kwargs) -> str:
json_str = json.dumps(asdict(self), **kwargs)
with open(filename, "w") as f:
f.write(json_str)

@classmethod
def from_json(cls, filename: str, **kwargs) -> "ModelConfig":
with open(filename, "r") as f:
return cls(**json.loads(f.read(), **kwargs))


# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
# Also consumed by PipelineConfig.for_model_type to set cross-config
Expand Down
45 changes: 30 additions & 15 deletions fmpose3d/fmpose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
ProgressCallback = Callable[[int, int], None]


#: HuggingFace repository hosting the official FMPose3D checkpoints.
_HF_REPO_ID: str = "deruyter92/fmpose_temp"

# Default camera-to-world rotation quaternion (from the demo script).
_DEFAULT_CAM_ROTATION = np.array(
[0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
Expand Down Expand Up @@ -560,7 +563,7 @@ def __init__(
self,
model_cfg: FMPose3DConfig | None = None,
inference_cfg: InferenceConfig | None = None,
model_weights_path: str | Path | None = SKIP_WEIGHTS_VALIDATION,
model_weights_path: str | Path | None = None,
device: str | torch.device | None = None,
*,
estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
Expand Down Expand Up @@ -601,7 +604,7 @@ def __init__(
@classmethod
def for_animals(
cls,
model_weights_path: str = SKIP_WEIGHTS_VALIDATION,
model_weights_path: str | None = None,
*,
device: str | torch.device | None = None,
inference_cfg: InferenceConfig | None = None,
Expand Down Expand Up @@ -958,15 +961,11 @@ def _load_weights(self) -> None:
# Private helpers – input resolution
# ------------------------------------------------------------------

def _resolve_model_weights_path(self) -> None:
# TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
if self.model_weights_path is SKIP_WEIGHTS_VALIDATION:
return SKIP_WEIGHTS_VALIDATION

if not self.model_weights_path:
def _resolve_model_weights_path(self) -> None:
if self.model_weights_path is None:
self._download_model_weights()
self.model_weights_path = Path(self.model_weights_path).resolve()
if not self.model_weights_path.exists():
if not self.model_weights_path.is_file():
raise ValueError(
f"Model weights file not found: {self.model_weights_path}. "
"Please provide a valid path to a .pth checkpoint file in the "
Expand All @@ -976,12 +975,28 @@ def _resolve_model_weights_path(self) -> None:
return self.model_weights_path

def _download_model_weights(self) -> None:
"""Download model weights from huggingface."""
# TODO @deruyter92: Implement download from huggingface
raise NotImplementedError(
"Downloading model weights from huggingface is not implemented yet."
"Please provide a valid path to a .pth checkpoint file in the "
"FMPose3DInference constructor."
"""Download model weights from HuggingFace Hub.

The weight file is determined by the current ``model_cfg.model_type``
(e.g. ``"fmpose3d_humans"`` -> ``fmpose3d_humans.pth``). Files are
cached locally by :func:`huggingface_hub.hf_hub_download` so
subsequent calls are instant.

Sets ``self.model_weights_path`` to the local cached file path.
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
raise ImportError(
"huggingface_hub is required to download model weights. "
"Install it with: pip install huggingface_hub. Or download "
"the weights manually and set model_weights_path to the weights file."
) from None

filename = f"{self.model_cfg.model_type.value}.pth"
self.model_weights_path = hf_hub_download(
repo_id=_HF_REPO_ID,
filename=filename,
)

def _ingest_input(self, source: Source) -> _IngestedInput:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"filterpy>=1.4.5",
"pandas>=1.0.1",
"deeplabcut==3.0.0rc13",
"huggingface_hub>=0.20.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -69,6 +70,7 @@ line_length = 100
[tool.pytest.ini_options]
markers = [
"functional: marks tests that require pretrained weights (deselect with '-m \"not functional\"')",
"network: marks tests that may need internet access on first run (deselect with '-m \"not network\"')",
]

[tool.codespell]
Expand Down
Empty file added tests/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions tests/fmpose3d_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0
"""
104 changes: 104 additions & 0 deletions tests/fmpose3d_api/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
FMPose3D: monocular 3D Pose Estimation via Flow Matching

Official implementation of the paper:
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0

Shared fixtures, markers, and skip-helpers for the ``fmpose3d_api`` test suite.

Skip logic
----------
* **weights_ready(filename)** – ``True`` when the HuggingFace-cached file
already exists on disk *or* we can reach ``huggingface.co`` so that
``hf_hub_download`` will succeed.
* **has_internet** – evaluated once at collection time via a quick TCP probe.
* **HF_HUB_OFFLINE** – if set to ``"1"`` in the environment the network
check is skipped entirely (consistent with how ``huggingface_hub``
itself behaves).
"""

from __future__ import annotations

import os
import socket

import pytest

# ---------------------------------------------------------------------------
# HuggingFace repo & filenames (must match fmpose3d.fmpose3d._HF_REPO_ID)
# ---------------------------------------------------------------------------

HF_REPO_ID: str = "deruyter92/fmpose_temp"

HUMAN_WEIGHTS_FILENAME: str = "fmpose3d_humans.pth"
ANIMAL_WEIGHTS_FILENAME: str = "fmpose3d_animals.pth"

# ---------------------------------------------------------------------------
# Connectivity helpers
# ---------------------------------------------------------------------------


def _has_internet(host: str = "huggingface.co", port: int = 443, timeout: float = 3) -> bool:
"""Return ``True`` if *host* is reachable via TCP."""
if os.environ.get("HF_HUB_OFFLINE", "0") == "1":
return False
try:
socket.create_connection((host, port), timeout=timeout)
return True
except OSError:
return False


def _weights_cached(filename: str) -> bool:
"""Return ``True`` if *filename* already lives in the local HF cache."""
try:
from huggingface_hub import try_to_load_from_cache

result = try_to_load_from_cache(HF_REPO_ID, filename)
return isinstance(result, str)
except Exception:
return False


def weights_ready(filename: str) -> bool:
"""``True`` when we can obtain *filename* — either from cache or network."""
return _weights_cached(filename) or _has_internet()


# Evaluate once at collection time.
HAS_INTERNET: bool = _has_internet()
HUMAN_WEIGHTS_READY: bool = weights_ready(HUMAN_WEIGHTS_FILENAME)
ANIMAL_WEIGHTS_READY: bool = weights_ready(ANIMAL_WEIGHTS_FILENAME)

try:
import deeplabcut # noqa: F401

DLC_AVAILABLE: bool = True
except ImportError:
DLC_AVAILABLE = False

# ---------------------------------------------------------------------------
# Reusable skip markers
# ---------------------------------------------------------------------------

requires_network = pytest.mark.skipif(
not HAS_INTERNET,
reason="No internet connection (cannot reach huggingface.co)",
)

requires_human_weights = pytest.mark.skipif(
not HUMAN_WEIGHTS_READY,
reason="Human weights not cached and no internet connection",
)

requires_animal_weights = pytest.mark.skipif(
not ANIMAL_WEIGHTS_READY,
reason="Animal weights not cached and no internet connection",
)

requires_dlc = pytest.mark.skipif(
not DLC_AVAILABLE,
reason="DeepLabCut is not installed",
)
Loading
Loading