Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Marishka17 committed Feb 10, 2025
1 parent b8f8a88 commit 6208de4
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 75 deletions.
4 changes: 2 additions & 2 deletions cvat/apps/engine/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def setup_background_job(
result_url = self.make_result_url()

with get_rq_lock_by_user(queue, user_id):
meta = ExportRQMeta.build(
meta = ExportRQMeta.build_for(
request=self.request,
db_obj=self.db_instance,
result_url=result_url,
Expand Down Expand Up @@ -758,7 +758,7 @@ def setup_background_job(
user_id = self.request.user.id

with get_rq_lock_by_user(queue, user_id):
meta = ExportRQMeta.build(
meta = ExportRQMeta.build_for(
request=self.request,
db_obj=self.db_instance,
result_url=result_url,
Expand Down
2 changes: 1 addition & 1 deletion cvat/apps/engine/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ def _import(importer, request: PatchedRequest, queue, rq_id, Serializer, file_fi
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
meta = ImportRQMeta.build(
meta = ImportRQMeta.build_for(
request=request,
db_obj=None,
tmp_file=filename,
Expand Down
71 changes: 13 additions & 58 deletions cvat/apps/engine/rq_job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def reset_meta_on_retry(self) -> dict[RQJobMetaField, Any]:

@attrs.define(kw_only=True)
class RQMetaWithFailureInfo(AbstractRQMeta):
# immutable and optional fields
# mutable && optional fields
formatted_exception: str | None = attrs.field(
validator=[optional_str_validator],
default=None,
Expand All @@ -98,7 +98,6 @@ class RQMetaWithFailureInfo(AbstractRQMeta):

@staticmethod
def _get_resettable_fields() -> list[RQJobMetaField]:
"""Return a list of fields that must be reset on retry"""
return [
RQJobMetaField.FORMATTED_EXCEPTION,
RQJobMetaField.EXCEPTION_TYPE,
Expand All @@ -111,12 +110,12 @@ class BaseRQMeta(RQMetaWithFailureInfo):
# immutable and required fields
user: UserInfo = attrs.field(
validator=[attrs.validators.instance_of(UserInfo)],
converter=lambda d: UserInfo(**d),
converter=lambda x: x if isinstance(x, UserInfo) else UserInfo(**x),
on_setattr=attrs.setters.frozen,
)
request: RequestInfo = attrs.field(
validator=[attrs.validators.instance_of(RequestInfo)],
converter=lambda d: RequestInfo(**d),
converter=lambda x: x if isinstance(x, RequestInfo) else RequestInfo(**x),
on_setattr=attrs.setters.frozen,
)

Expand All @@ -137,17 +136,19 @@ class BaseRQMeta(RQMetaWithFailureInfo):
validator=[optional_int_validator], default=None, on_setattr=attrs.setters.frozen
)

# import && lambda
# mutable fields
progress: float | None = attrs.field(
validator=[optional_float_validator],
default=None,
on_setattr=_update_value,
)
status: str = attrs.field(
validator=[str_validator], default="", on_setattr=_update_value
)

@staticmethod
def _get_resettable_fields() -> list[RQJobMetaField]:
"""Return a list of fields that must be reset on retry"""
return RQMetaWithFailureInfo._get_resettable_fields() + [RQJobMetaField.PROGRESS]
return RQMetaWithFailureInfo._get_resettable_fields() + [RQJobMetaField.PROGRESS, RQJobMetaField.STATUS]

@classmethod
def build(
Expand Down Expand Up @@ -189,16 +190,15 @@ def build(
@attrs.define(kw_only=True)
class ExportRQMeta(BaseRQMeta):
# will be changed to ExportResultInfo in the next PR
result_url: str | None = attrs.field(validator=[optional_str_validator])
result_url: str | None = attrs.field(validator=[optional_str_validator], default=None)

@staticmethod
def _get_resettable_fields() -> list[RQJobMetaField]:
"""Return a list of fields that must be reset on retry"""
base_fields = BaseRQMeta._get_resettable_fields()
return base_fields + [RQJobMetaField.RESULT]

@classmethod
def build(
def build_for(
cls,
*,
request: PatchedRequest,
Expand All @@ -221,27 +221,18 @@ class ImportRQMeta(BaseRQMeta):
)

# mutable fields
# TODO: move into base?
status: str = attrs.field(
validator=[optional_str_validator], default="", on_setattr=_update_value
)
task_progress: float | None = attrs.field(
validator=[optional_float_validator], default=None, on_setattr=_update_value
)
) # used when importing project dataset

@staticmethod
def _get_resettable_fields() -> list[RQJobMetaField]:
"""Return a list of fields that must be reset on retry"""
base_fields = BaseRQMeta._get_resettable_fields()

return base_fields + [
RQJobMetaField.PROGRESS,
RQJobMetaField.TASK_PROGRESS,
RQJobMetaField.STATUS,
]
return base_fields + [RQJobMetaField.TASK_PROGRESS]

@classmethod
def build(
def build_for(
cls,
*,
request: PatchedRequest,
Expand All @@ -255,42 +246,6 @@ def build(
tmp_file=tmp_file,
).to_dict()


@attrs.define(kw_only=True)
class LambdaRQMeta(BaseRQMeta):
# immutable fields
function_id: int | None = attrs.field(
validator=[optional_int_validator], default=None, on_setattr=attrs.setters.frozen
)
lambda_: bool | None = attrs.field(
validator=[optional_bool_validator],
init=False,
default=True,
on_setattr=attrs.setters.frozen,
)

def to_dict(self) -> dict:
d = asdict(self)
if v := d.pop(RQJobMetaField.LAMBDA + "_", None) is not None:
d[RQJobMetaField.LAMBDA] = v

return d

@classmethod
def build(
cls,
*,
request: PatchedRequest,
db_obj: Model,
function_id: int,
):
base_meta = BaseRQMeta.build(request=request, db_obj=db_obj)
return cls(
**base_meta,
function_id=function_id,
).to_dict()


class RQJobMetaField:
# common fields
FORMATTED_EXCEPTION = "formatted_exception"
Expand Down
20 changes: 11 additions & 9 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
BaseRQMeta,
ExportRQMeta,
ImportRQMeta,
LambdaRQMeta,
RequestAction,
RQId,
)
Expand All @@ -60,6 +59,7 @@
reverse,
take_by,
)
from cvat.apps.lambda_manager.rq import LambdaRQMeta
from utils.dataset_manifest import ImageManifestManager

slogger = ServerLogManager(__name__)
Expand Down Expand Up @@ -3553,17 +3553,19 @@ class RequestSerializer(serializers.Serializer):
result_url = serializers.URLField(required=False, allow_null=True)
result_id = serializers.IntegerField(required=False, allow_null=True)

def __init__(self, *args, **kwargs):
self._base_rq_job_meta: BaseRQMeta | None = None
super().__init__(*args, **kwargs)

@extend_schema_field(UserIdentifiersSerializer())
def get_owner(self, rq_job: RQJob) -> dict[str, Any]:
# TODO: define parsed meta once
rq_job_meta = BaseRQMeta.from_job(rq_job)
return UserIdentifiersSerializer(rq_job_meta.user.to_dict()).data
assert self._base_rq_job_meta
return UserIdentifiersSerializer(self._base_rq_job_meta.user.to_dict()).data

@extend_schema_field(
serializers.FloatField(min_value=0, max_value=1, required=False, allow_null=True)
)
def get_progress(self, rq_job: RQJob) -> Decimal:
# TODO: define parsed meta once
rq_job_meta = ImportRQMeta.from_job(rq_job)
# progress of task creation is stored in "task_progress" field
# progress of project import is stored in "progress" field
Expand All @@ -3585,19 +3587,19 @@ def get_expiry_date(self, rq_job: RQJob) -> Optional[str]:

@extend_schema_field(serializers.CharField(allow_blank=True))
def get_message(self, rq_job: RQJob) -> str:
# TODO: define parsed meta once
rq_job_meta = ImportRQMeta.from_job(rq_job)
assert self._base_rq_job_meta
rq_job_status = rq_job.get_status()
message = ''

if RQJobStatus.STARTED == rq_job_status:
message = rq_job_meta.status
message = self._base_rq_job_meta.status
elif RQJobStatus.FAILED == rq_job_status:
message = rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error"))
message = self._base_rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error"))

return message

def to_representation(self, rq_job: RQJob) -> dict[str, Any]:
self._base_rq_job_meta = BaseRQMeta.from_job(rq_job)
representation = super().to_representation(rq_job)

# FUTURE-TODO: support such statuses on UI
Expand Down
2 changes: 1 addition & 1 deletion cvat/apps/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create(
func=_create_thread,
args=(db_task.pk, data),
job_id=rq_id,
meta=ImportRQMeta.build(request=request, db_obj=db_task),
meta=ImportRQMeta.build_for(request=request, db_obj=db_task),
depends_on=define_dependent_job(q, user_id),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds(),
)
Expand Down
4 changes: 2 additions & 2 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3446,7 +3446,7 @@ def _import_annotations(request, rq_id_factory, rq_func, db_obj, format_name,
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
meta = ImportRQMeta.build(request=request, db_obj=db_obj, tmp_file=filename)
meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename)
rq_job = queue.enqueue_call(
func=func,
args=func_args,
Expand Down Expand Up @@ -3548,7 +3548,7 @@ def _import_project_dataset(
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
meta = ImportRQMeta.build(request=request, db_obj=db_obj, tmp_file=filename)
meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename)
rq_job = queue.enqueue_call(
func=func,
args=func_args,
Expand Down
47 changes: 47 additions & 0 deletions cvat/apps/lambda_manager/rq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import attrs
from attrs import asdict
from django.db.models import Model

from cvat.apps.engine.middleware import PatchedRequest
from cvat.apps.engine.rq_job_handler import BaseRQMeta, RQJobMetaField


@attrs.define(kw_only=True)
class LambdaRQMeta(BaseRQMeta):
# immutable fields
function_id: int = attrs.field(
validator=[attrs.validators.instance_of(int)], default=None, on_setattr=attrs.setters.frozen
)
lambda_: bool = attrs.field(
validator=[attrs.validators.instance_of(bool)],
init=False,
default=True,
on_setattr=attrs.setters.frozen,
)

def to_dict(self) -> dict:
d = asdict(self)
if v := d.pop(RQJobMetaField.LAMBDA + "_", None) is not None:
d[RQJobMetaField.LAMBDA] = v

return d

@classmethod
def build_for(
cls,
*,
request: PatchedRequest,
db_obj: Model,
function_id: int,
):
base_meta = BaseRQMeta.build(request=request, db_obj=db_obj)
return cls(
**base_meta,
function_id=function_id,
).to_dict()
5 changes: 3 additions & 2 deletions cvat/apps/lambda_manager/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@
SourceType,
Task,
)
from cvat.apps.engine.rq_job_handler import LambdaRQMeta, RQId
from cvat.apps.engine.rq_job_handler import RQId
from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.engine.utils import define_dependent_job, get_rq_lock_by_user
from cvat.apps.events.handlers import handle_function_call
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.lambda_manager.models import FunctionKind
from cvat.apps.lambda_manager.permissions import LambdaPermission
from cvat.apps.lambda_manager.rq import LambdaRQMeta
from cvat.apps.lambda_manager.serializers import (
FunctionCallRequestSerializer,
FunctionCallSerializer,
Expand Down Expand Up @@ -640,7 +641,7 @@ def enqueue(
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
meta = LambdaRQMeta.build(
meta = LambdaRQMeta.build_for(
request=request,
db_obj=Job.objects.get(pk=job) if job else Task.objects.get(pk=task),
function_id=lambda_func.id,
Expand Down

0 comments on commit 6208de4

Please sign in to comment.