34 changes: 21 additions & 13 deletions supervision/metrics/mean_average_recall.py
@@ -219,9 +219,8 @@ def _compute(
             large_objects=None,
         )
 
-        concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
         recall_scores_per_k, recall_per_class, unique_classes = (
-            self._compute_average_recall_for_classes(*concatenated_stats)
+            self._compute_average_recall_for_classes(stats)
         )
 
         return MeanAverageRecallResult(
@@ -238,25 +237,34 @@ def _compute(

     def _compute_average_recall_for_classes(
         self,
-        matches: np.ndarray,
-        prediction_confidence: np.ndarray,
-        prediction_class_ids: np.ndarray,
-        true_class_ids: np.ndarray,
+        stats: list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]],
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        sorted_indices = np.argsort(-prediction_confidence)
-        matches = matches[sorted_indices]
-        prediction_class_ids = prediction_class_ids[sorted_indices]
-        unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
-
         recalls_at_k = []
 
         for max_detections in self.max_detections:
+            filtered_stats = []
+            for matches, confidence, class_id, true_class_id in stats:
+                sorted_indices = np.argsort(-confidence)[:max_detections]
+                filtered_stats.append(
+                    (
+                        matches[sorted_indices],
+                        class_id[sorted_indices],
+                        true_class_id,
+                    )
+                )
+            concatenated_stats = [
+                np.concatenate(items, 0) for items in zip(*filtered_stats)
+            ]
+
+            filtered_matches, prediction_class_ids, true_class_ids = concatenated_stats
+            unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
+
             # Shape: PxTh,P,C,C -> CxThx3
             confusion_matrix = self._compute_confusion_matrix(
-                matches,
+                filtered_matches,
                 prediction_class_ids,
                 unique_classes,
                 class_counts,
                 max_detections=max_detections,
             )
 
             # Shape: CxThx3 -> CxTh
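For context, the refactor moves the confidence sort and the max_detections cap from one pass over the dataset-wide concatenated predictions into a per-image pass: each image's predictions are sorted by confidence and truncated to max_detections before images are concatenated, so one detection-heavy image can no longer consume the whole budget. A minimal standalone sketch of that per-image capping step (plain NumPy; cap_top_k and the toy arrays are illustrative, not supervision's API):

import numpy as np

def cap_top_k(matches, confidence, class_id, k):
    # matches: (P, Th) boolean matrix of per-IoU-threshold matches for one image;
    # confidence, class_id: (P,) arrays for the same P predictions.
    keep = np.argsort(-confidence)[:k]  # indices of the k most confident predictions
    return matches[keep], class_id[keep]

# Two toy images, two predictions each, capped at k=1.
img_a = (np.array([[True], [False]]), np.array([0.9, 0.8]), np.array([0, 1]))
img_b = (np.array([[False], [True]]), np.array([0.4, 0.3]), np.array([1, 0]))

capped = [cap_top_k(m, conf, cls, k=1) for m, conf, cls in (img_a, img_b)]
matches_all, class_ids_all = (np.concatenate(items, 0) for items in zip(*capped))
# Per-image capping keeps the best prediction of each image (0.9 from A, 0.4
# from B); a dataset-wide sort-then-cap would keep 0.9 and 0.8, both from A.
print(matches_all.shape, class_ids_all)  # (2, 1) [0 1]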
308 changes: 308 additions & 0 deletions test/metrics/test_mean_average_recall.py
@@ -0,0 +1,308 @@
from contextlib import AbstractContextManager, ExitStack
from typing import Any

import numpy as np
import pytest

from supervision.detection.core import Detections
from supervision.metrics import MeanAverageRecall, MetricTarget

# Totals:
# class 0 GT count = 17
# class 1 GT count = 19

TARGETS = [
# img 0 (2 GT: c0, c1)
np.array(
[
[100, 120, 260, 400, 1.0, 0],
[500, 200, 760, 640, 1.0, 1],
],
dtype=np.float32,
),
# img 1 (3 GT: c0, c0, c1)
np.array(
[
[50, 60, 180, 300, 1.0, 0],
[210, 70, 340, 310, 1.0, 0],
[400, 90, 620, 360, 1.0, 1],
],
dtype=np.float32,
),
# img 2 (1 GT: c1)
np.array(
[
[320, 200, 540, 520, 1.0, 1],
],
dtype=np.float32,
),
# img 3 (4 GT: c0, c1, c0, c1)
np.array(
[
[100, 100, 240, 340, 1.0, 0],
[260, 110, 410, 350, 1.0, 1],
[430, 120, 580, 360, 1.0, 0],
[600, 130, 760, 370, 1.0, 1],
],
dtype=np.float32,
),
# img 4 (2 GT: c0, c0)
np.array(
[
[120, 400, 260, 700, 1.0, 0],
[300, 420, 480, 720, 1.0, 0],
],
dtype=np.float32,
),
# img 5 (3 GT: c1, c1, c1)
np.array(
[
[50, 50, 200, 260, 1.0, 1],
[230, 60, 380, 270, 1.0, 1],
[410, 70, 560, 280, 1.0, 1],
],
dtype=np.float32,
),
# img 6 (1 GT: c0)
np.array(
[
[600, 60, 780, 300, 1.0, 0],
],
dtype=np.float32,
),
# img 7 (5 GT: c0, c1, c1, c0, c1)
np.array(
[
[60, 360, 180, 600, 1.0, 0],
[200, 350, 340, 590, 1.0, 1],
[360, 340, 500, 580, 1.0, 1],
[520, 330, 660, 570, 1.0, 0],
[680, 320, 820, 560, 1.0, 1],
],
dtype=np.float32,
),
# img 8 (2 GT: c1, c1)
np.array(
[
[100, 100, 220, 300, 1.0, 1],
[260, 110, 380, 310, 1.0, 1],
],
dtype=np.float32,
),
# img 9 (1 GT: c0)
np.array(
[
[420, 400, 600, 700, 1.0, 0],
],
dtype=np.float32,
),
# img 10 (4 GT: c0, c1, c1, c0)
np.array(
[
[50, 500, 180, 760, 1.0, 0],
[200, 500, 350, 760, 1.0, 1],
[370, 500, 520, 760, 1.0, 1],
[540, 500, 690, 760, 1.0, 0],
],
dtype=np.float32,
),
# img 11 (2 GT: c1, c0)
np.array(
[
[150, 150, 300, 420, 1.0, 1],
[330, 160, 480, 430, 1.0, 0],
],
dtype=np.float32,
),
# img 12 (3 GT: c0, c1, c1)
np.array(
[
[600, 200, 760, 460, 1.0, 0],
[100, 220, 240, 480, 1.0, 1],
[260, 230, 400, 490, 1.0, 1],
],
dtype=np.float32,
),
# img 13 (1 GT: c0)
np.array(
[
[50, 50, 190, 250, 1.0, 0],
],
dtype=np.float32,
),
# img 14 (2 GT: c1, c0)
np.array(
[
[420, 80, 560, 300, 1.0, 1],
[580, 90, 730, 310, 1.0, 0],
],
dtype=np.float32,
),
]

PREDICTIONS = [
# img 0: 2 TP + 1 class mismatch FP
np.array(
[
[102, 118, 258, 398, 0.94, 0], # TP (c0)
[500, 200, 760, 640, 0.90, 1], # TP (c1)
[100, 120, 260, 400, 0.55, 1], # FP (class mismatch)
],
dtype=np.float32,
),
# img 1: TPs for two c0, miss c1 (FN) + background FP
np.array(
[
[50, 60, 180, 300, 0.91, 0], # TP (c0)
[210, 70, 340, 310, 0.88, 0], # TP (c0)
[600, 400, 720, 560, 0.42, 1], # FP (no GT nearby)
],
dtype=np.float32,
),
# img 2: 1 low-IoU prediction (IoU ≈ 0.52 vs the GT) + random FP
np.array(
[
[300, 180, 500, 430, 0.83, 1], # matches only at the loosest IoU thresholds
[50, 50, 140, 140, 0.30, 0], # FP
],
dtype=np.float32,
),
# img 3: 2 TP, 2 FN + 1 class-mismatch FP
np.array(
[
[100, 100, 240, 340, 0.90, 0], # TP (c0)
[260, 110, 410, 350, 0.87, 1], # TP (c1)
[430, 120, 580, 360, 0.70, 1], # FP (class mismatch; GT is c0)
],
dtype=np.float32,
),
# img 4: No predictions (2 FN)
np.array([], dtype=np.float32).reshape(0, 6),
# img 5: All three matched + class mismatch
np.array(
[
[50, 50, 200, 260, 0.95, 1], # TP (c1)
[230, 60, 380, 270, 0.92, 1], # TP (c1)
[410, 70, 560, 280, 0.90, 1], # TP (c1)
[50, 50, 200, 260, 0.40, 0], # FP (class mismatch)
],
dtype=np.float32,
),
# img 6: Wrong class over GT (0 recall)
np.array(
[
[600, 60, 780, 300, 0.89, 1], # FP (class mismatch)
],
dtype=np.float32,
),
# img 7: 3 TP, 2 FN (only 3/5 recalled) + 1 class-mismatch FP
np.array(
[
[60, 360, 180, 600, 0.93, 0], # TP (c0)
[200, 350, 340, 590, 0.90, 1], # TP (c1)
[360, 340, 500, 580, 0.88, 1], # TP (c1)
[520, 330, 660, 570, 0.50, 1], # FP (class mismatch; GT is c0)
],
dtype=np.float32,
),
# img 8: 2 TP
np.array(
[
[100, 100, 220, 300, 0.96, 1], # TP
[262, 112, 378, 308, 0.89, 1], # TP
],
dtype=np.float32,
),
# img 9: 1 TP + 1 FP
np.array(
[
[418, 398, 602, 702, 0.86, 0], # TP
[100, 100, 140, 160, 0.33, 1], # FP
],
dtype=np.float32,
),
# img 10: Perfect (all 4 TP)
np.array(
[
[50, 500, 180, 760, 0.94, 0], # TP
[200, 500, 350, 760, 0.93, 1], # TP
[370, 500, 520, 760, 0.92, 1], # TP
[540, 500, 690, 760, 0.91, 0], # TP
],
dtype=np.float32,
),
# img 11: 2 TP + 1 class-mismatch FP
np.array(
[
[150, 150, 300, 420, 0.90, 1], # TP (c1)
[332, 162, 478, 428, 0.58, 0], # TP (c0; slight shift, IoU ≈ 0.96)
[148, 148, 298, 415, 0.52, 0], # FP (class mismatch over c1 GT)
],
dtype=np.float32,
),
# img 12: 2 TP + 1 class-mismatch FP (one c1 GT missed)
np.array(
[
[600, 200, 760, 460, 0.92, 0], # TP
[100, 220, 240, 480, 0.90, 1], # TP
[260, 230, 400, 490, 0.40, 0], # FP (class mismatch; GT is c1)
],
dtype=np.float32,
),
# img 13: No predictions (1 FN)
np.array([], dtype=np.float32).reshape(0, 6),
# img 14: classes swapped (0 recall; both predictions are class-mismatch FPs)
np.array(
[
[420, 80, 560, 300, 0.88, 0], # FP (class mismatch; GT is c1)
[580, 90, 730, 310, 0.86, 1], # FP (class mismatch; GT is c0)
],
dtype=np.float32,
),
]


# Expected mAR at K = 1, 10, 100
EXPECTED_RESULT = np.array([0.2874613, 0.63622291, 0.63622291])
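# With max_detections = (1, 10, 100), the metric's default here: mAR@1 keeps
# only each image's single highest-confidence prediction, hence the much lower
# first value; no image above has more than four predictions, so mAR@10 and
# mAR@100 coincide.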


def mock_detections_list(boxes_list):
"""Build one Detections per image from rows of [x1, y1, x2, y2, confidence, class_id]."""
return [
Detections(
xyxy=boxes[:, :4], confidence=boxes[:, 4], class_id=boxes[:, 5].astype(int)
)
for boxes in boxes_list
]


@pytest.mark.parametrize(
"predictions_list, targets_list, expected_result, exception",
[
(
mock_detections_list(PREDICTIONS),
mock_detections_list(TARGETS),
EXPECTED_RESULT,
ExitStack(),  # no exception expected
),
],
)
def test_recall(
predictions_list: list[Detections],
targets_list: list[Detections],
expected_result: np.ndarray,
exception: AbstractContextManager[Any],
):
mar_metrics = MeanAverageRecall(metric_target=MetricTarget.BOXES)
mar_result = mar_metrics._compute(predictions_list, targets_list)

with exception:
np.testing.assert_almost_equal(
mar_result.recall_scores, expected_result, decimal=5
)
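

The expected values above are pinned to supervision's own implementation. If an independent cross-check is ever wanted, COCO-style mAR@K can be recomputed with torchmetrics; the sketch below assumes torchmetrics with its pycocotools backend is available, is not part of the PR, and may not agree to the last decimal since class-averaging details differ between implementations:

import torch
from torchmetrics.detection import MeanAveragePrecision

def cross_check():
    # Convert the fixture arrays into torchmetrics' list-of-dicts format.
    metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
    metric.update(
        preds=[
            {
                "boxes": torch.tensor(p[:, :4]),
                "scores": torch.tensor(p[:, 4]),
                "labels": torch.tensor(p[:, 5], dtype=torch.long),
            }
            for p in PREDICTIONS
        ],
        target=[
            {
                "boxes": torch.tensor(t[:, :4]),
                "labels": torch.tensor(t[:, 5], dtype=torch.long),
            }
            for t in TARGETS
        ],
    )
    result = metric.compute()  # includes "mar_1", "mar_10", "mar_100"
    return result["mar_1"], result["mar_10"], result["mar_100"]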