Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attribute support for auto-annotation functions #9090

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog.d/20250211_145801_roman_aa_attributes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
### Added

- \[SDK\] Auto-annotation detection functions can now output shape/keypoint attributes
(<https://github.com/cvat-ai/cvat/pull/9090>)

- \{SDK\] Added a utility module for working with label attributes,
`cvat_sdk.attributes`
(<https://github.com/cvat-ai/cvat/pull/9090>)
45 changes: 30 additions & 15 deletions cvat-cli/src/cvat_cli/_internal/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from cvat_sdk.auto_annotation.driver import (
_AnnotationMapper,
_DetectionFunctionContextImpl,
_LabelNameMapping,
_SpecNameMapping,
)
from cvat_sdk.exceptions import ApiException
Expand Down Expand Up @@ -145,28 +144,42 @@ def _validate_detection_function_compatibility(self, remote_function: dict) -> N
labels_by_name = {label.name: label for label in self._function_spec.labels}

for remote_label in remote_function["labels_v2"]:
label_desc = f"label {remote_label['name']!r}"
label = labels_by_name.get(remote_label["name"])

if not label:
raise CriticalError(
incompatible_msg + f"label {remote_label['name']!r} is not supported."
)
raise CriticalError(incompatible_msg + f"{label_desc} is not supported.")

if (
remote_label["type"] not in {"any", "unknown"}
and remote_label["type"] != label.type
):
raise CriticalError(
incompatible_msg
+ f"label {remote_label['name']!r} has type {remote_label['type']!r}, "
f"but the function object expects type {label.type!r}."
incompatible_msg + f"{label_desc} has type {remote_label['type']!r}, "
f"but the function object declares type {label.type!r}."
)

if remote_label["attributes"]:
raise CriticalError(
incompatible_msg
+ f"label {remote_label['name']!r} has attributes, which is not supported."
)
attrs_by_name = {attr.name: attr for attr in getattr(label, "attributes", [])}

for remote_attr in remote_label["attributes"]:
attr_desc = f"attribute {remote_attr['name']!r} of {label_desc}"
attr = attrs_by_name.get(remote_attr["name"])

if not attr:
raise CriticalError(incompatible_msg + f"{attr_desc} is not supported.")

if remote_attr["input_type"] != attr.input_type.value:
raise CriticalError(
incompatible_msg
+ f"{attr_desc} has input type {remote_attr['input_type']!r},"
f" but the function object declares input type {attr.input_type.value!r}."
)

if remote_attr["values"] != attr.values:
raise CriticalError(
incompatible_msg + f"{attr_desc} has values {remote_attr['values']!r},"
f" but the function object declares values {attr.values!r}."
)

def _wait_between_polls(self):
# offset the interval randomly to avoid synchronization between workers
Expand Down Expand Up @@ -288,11 +301,13 @@ def _calculate_result_for_detection_ar(
self._update_ar(ar_id, 0)
last_update_timestamp = datetime.now(tz=timezone.utc)

mapping = ar_params["mapping"]
conv_mask_to_poly = ar_params["conv_mask_to_poly"]

spec_nm = _SpecNameMapping(
labels={k: _LabelNameMapping(v["name"]) for k, v in mapping.items()}
spec_nm = _SpecNameMapping.from_api(
{
k: models.LabelMappingEntryRequest._from_openapi_data(**v)
for k, v in ar_params["mapping"].items()
}
)

mapper = _AnnotationMapper(
Expand Down
8 changes: 8 additions & 0 deletions cvat-cli/src/cvat_cli/_internal/commands_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def execute(
remote_function["labels_v2"].append(
{
"name": label_spec.name,
"attributes": [
{
"name": attribute_spec.name,
"input_type": attribute_spec.input_type,
"values": attribute_spec.values,
}
for attribute_spec in getattr(label_spec, "attributes", [])
],
}
)

Expand Down
2 changes: 1 addition & 1 deletion cvat-cli/src/cvat_cli/_internal/commands_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--allow-unmatched-labels",
action="store_true",
help="Allow the function to declare labels not configured in the task",
help="Allow the function to declare labels/sublabels/attributes not configured in the task",
)

parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion cvat-sdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ The SDK API includes several layers:
- Server API wrappers (`ApiClient`). Located in at `cvat_sdk.api_client`.
- High-level tools (`Core`). Located at `cvat_sdk.core`.
- PyTorch adapter. Located at `cvat_sdk.pytorch`.
* Auto-annotation support. Located at `cvat_sdk.auto_annotation`.
- Auto-annotation support. Located at `cvat_sdk.auto_annotation`.
- Miscellaneous utilities, grouped by topic.
Located at `cvat_sdk.attributes` and `cvat_sdk.masks`.

Package documentation is available [here](https://docs.cvat.ai/docs/api_sdk/sdk).

Expand Down
103 changes: 103 additions & 0 deletions cvat-sdk/cvat_sdk/attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

from typing import Callable

from . import models


class _CheckboxAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
pass

def __call__(self, value: str) -> bool:
return value in {"true", "false"}


class _NumberAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
if len(values) != 3:
raise ValueError(f"wrong number of values: expected 3, got {len(values)}")

(self._min_value, self._max_value, self._step) = map(int, values)

try:
number_attribute_values(self._min_value, self._max_value, self._step)
except ValueError as ex:
raise ValueError(f"invalid values: {ex}") from ex

def __call__(self, value: str) -> bool:
try:
value = int(value)
except ValueError:
return False

return (
self._min_value <= value <= self._max_value
and (value - self._min_value) % self._step == 0
)


class _SelectAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
if len(values) == 0:
raise ValueError("empty list of allowed values")

self._values = frozenset(values)

def __call__(self, value: str) -> bool:
return value in self._values


class _TextAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
pass

def __call__(self, value: str) -> bool:
return True


_VALIDATOR_CLASSES = {
"checkbox": _CheckboxAttributeValueValidator,
"number": _NumberAttributeValueValidator,
"radio": _SelectAttributeValueValidator,
"select": _SelectAttributeValueValidator,
"text": _TextAttributeValueValidator,
}

# make sure all possible types are covered
assert set(models.InputTypeEnum.allowed_values[("value",)].values()) == _VALIDATOR_CLASSES.keys()


def attribute_value_validator(spec: models.IAttributeRequest) -> Callable[[str], bool]:
"""
Returns a callable that can be used to verify
whether an attribute value is suitable for an attribute with the given spec.
The resulting callable takes a single argument (the attribute value as a string)
and returns True if and only if the value is suitable.

The spec's `values` attribute must be consistent with its `input_type` attribute,
otherwise ValueError will be raised.
"""
return _VALIDATOR_CLASSES[spec.input_type.value](spec.values)


def number_attribute_values(min_value: int, max_value: int, /, step: int = 1) -> list[str]:
"""
Returns a list suitable as the value of the "values" field of an `AttributeRequest`
with `input_type="number"`.
"""

if min_value > max_value:
raise ValueError("min_value must be less than or equal to max_value")

if step <= 0:
raise ValueError("step must be positive")

if (max_value - min_value) % step != 0:
raise ValueError("step must be a divisor of max_value - min_value")

return [str(min_value), str(max_value), str(step)]
14 changes: 14 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,44 @@
DetectionFunction,
DetectionFunctionContext,
DetectionFunctionSpec,
attribute_spec,
attribute_val,
checkbox_attribute_spec,
keypoint,
keypoint_spec,
label_spec,
mask,
number_attribute_spec,
polygon,
radio_attribute_spec,
rectangle,
select_attribute_spec,
shape,
skeleton,
skeleton_label_spec,
text_attribute_spec,
)

__all__ = [
"annotate_task",
"attribute_spec",
"attribute_val",
"BadFunctionError",
"checkbox_attribute_spec",
"DetectionFunction",
"DetectionFunctionContext",
"DetectionFunctionSpec",
"keypoint_spec",
"keypoint",
"label_spec",
"mask",
"number_attribute_spec",
"polygon",
"radio_attribute_spec",
"rectangle",
"select_attribute_spec",
"shape",
"skeleton_label_spec",
"skeleton",
"text_attribute_spec",
]
Loading
Loading