diff --git a/docs/detection/detection.md b/docs/detection/detection.md new file mode 100644 index 000000000..deb90456c --- /dev/null +++ b/docs/detection/detection.md @@ -0,0 +1,48 @@ +### `transform` + +Transform detections to match the dataset's class names and IDs. + +This method performs the following steps: + +1. **Remaps class names** using the provided `class_mapping` dictionary. +2. **Filters out predictions** that are not present in the dataset's classes. +3. **Remaps class IDs** to match the dataset's class IDs. + +#### Parameters + +- **dataset**: The dataset object containing class names and IDs. +- **class_mapping** (`Optional[Dict[str, str]]`): A dictionary to map model class names to dataset class names. If `None`, no remapping is performed. + +#### Returns + +- **Detections**: A new `Detections` object with transformed class names and IDs. + +#### Raises + +- **ValueError**: If the dataset does not contain the required class names. + +#### Example + +```python +# Example dataset with class names +class DatasetMock: + def __init__(self): + self.classes = ["animal", "bird"] + +# Example detections +detections = Detections( + xyxy=np.array([[10, 10, 50, 50], [60, 60, 100, 100]]), + confidence=np.array([0.9, 0.8]), + class_id=np.array([0, 1]), + data={"class_name": ["dog", "eagle"]} +) + +# Class mapping +class_mapping = {"dog": "animal", "eagle": "bird"} + +# Transform detections +transformed_detections = detections.transform(DatasetMock(), class_mapping) + +print(transformed_detections.class_id) # Output: [0, 1] +print(transformed_detections.data["class_name"]) # Output: ["animal", "bird"] +``` diff --git a/supervision/detection/core.py b/supervision/detection/core.py index ca4ded1ff..01987abc3 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1530,3 +1530,65 @@ def validate_fields_both_defined_or_none( f"Field '{attribute}' should be consistently None or not None in both " "Detections." ) + + +def transform( + self, dataset, class_mapping: Optional[Dict[str, str]] = None +) -> Detections: + """ + Transform detections to match the dataset's class names and IDs. + + Args: + dataset: The dataset object containing class names and IDs. + class_mapping (Optional[Dict[str, str]]): A dictionary to map model class names + to dataset class names. If None, no remapping is performed. + + Returns: + Detections: A new Detections object with transformed class names and IDs. + + Raises: + ValueError: If the dataset does not contain the required class names. + """ + if self.is_empty(): + return self + + if class_mapping is not None: + if self.class_id is None or self.data.get(CLASS_NAME_DATA_FIELD) is None: + raise ValueError( + "Class names must be available in the data field for remapping." + ) + + current_class_names = self.data[CLASS_NAME_DATA_FIELD] + + remapped_class_names = np.array( + [class_mapping.get(name, name) for name in current_class_names] + ) + + if not all(name in dataset.classes for name in np.unique(remapped_class_names)): + raise ValueError("All mapped class names must be in the dataset's classes.") + + self.data[CLASS_NAME_DATA_FIELD] = remapped_class_names + + if self.class_id is not None and self.data.get(CLASS_NAME_DATA_FIELD) is not None: + class_names = self.data[CLASS_NAME_DATA_FIELD] + + mask = np.isin(class_names, dataset.classes) + + self.xyxy = self.xyxy[mask] + self.mask = self.mask[mask] if self.mask is not None else None + self.confidence = self.confidence[mask] if self.confidence is not None else None + self.class_id = self.class_id[mask] if self.class_id is not None else None + self.tracker_id = self.tracker_id[mask] if self.tracker_id is not None else None + + for key, value in self.data.items(): + if isinstance(value, np.ndarray): + self.data[key] = value[mask] + elif isinstance(value, list): + self.data[key] = [value[i] for i in np.where(mask)[0]] + + if self.class_id is not None and self.data.get(CLASS_NAME_DATA_FIELD) is not None: + class_names = self.data[CLASS_NAME_DATA_FIELD] + + self.class_id = np.array([dataset.classes.index(name) for name in class_names]) + + return self diff --git a/test/detection/test_detection.py b/test/detection/test_detection.py new file mode 100644 index 000000000..311be33a8 --- /dev/null +++ b/test/detection/test_detection.py @@ -0,0 +1,37 @@ +import unittest + +import numpy as np + +from supervision.detections.core import Detections + + +class TestDetectionsTransform(unittest.TestCase): + def test_transform(self): + # Mock dataset + class DatasetMock: + def __init__(self): + self.classes = ["animal", "bird"] + + # Example detections + detections = Detections( + xyxy=np.array([[10, 10, 50, 50], [60, 60, 100, 100]]), + confidence=np.array([0.9, 0.8]), + class_id=np.array([0, 1]), + data={"class_name": ["dog", "eagle"]}, + ) + + # Class mapping + class_mapping = {"dog": "animal", "eagle": "bird"} + + # Transform detections + transformed_detections = detections.transform(DatasetMock(), class_mapping) + + # Verify results + self.assertEqual(transformed_detections.class_id.tolist(), [0, 1]) + self.assertEqual( + transformed_detections.data["class_name"].tolist(), ["animal", "bird"] + ) + + +if __name__ == "__main__": + unittest.main()