Reduce critical section && deprecate some API endpoints
Marishka17 committed Feb 14, 2025
1 parent 803cfd8 commit 694c5f4
Showing 1 changed file with 138 additions and 108 deletions.
246 changes: 138 additions & 108 deletions cvat/apps/engine/views.py
@@ -167,11 +167,11 @@
define_dependent_job,
get_rq_job_meta,
get_rq_lock_by_user,
get_rq_lock_for_job,
import_resource_with_clean_up_after,
parse_exception_message,
process_failed_job,
sendfile,
get_rq_lock_for_job,
)
from cvat.apps.engine.view_utils import tus_chunk_action
from cvat.apps.events.handlers import handle_dataset_import
@@ -1515,19 +1515,26 @@ def get_export_callback(self, save_images: bool) -> Callable:
'400': OpenApiResponse(description='Exporting without data is not allowed'),
'405': OpenApiResponse(description='Format is not available'),
})
@extend_schema(methods=['PUT'], summary='Replace task annotations / Get annotation import status',
@extend_schema(methods=['PUT'], summary='Replace task annotations',
description=textwrap.dedent("""
To check the status of an import request:
Utilizing this endpoint to check the status of the import process is deprecated
in favor of the new requests API:
After initiating the annotation import, you will receive an rq_id parameter.
Make sure to include this parameter as a query parameter in your subsequent
PUT /api/tasks/id/annotations requests to track the status of the import.
GET /api/requests/<rq_id>, where the `rq_id` parameter is returned in the response
to the initiating request.
"""),
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
OpenApiParameter('rq_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='rq id'),
# deprecated parameters
OpenApiParameter(
'format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats',
deprecated=True,
),
OpenApiParameter(
'rq_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='rq id',
deprecated=True,
),
],
request=PolymorphicProxySerializer('TaskAnnotationsUpdate',
# TODO: refactor to use required=False when possible
@@ -1542,9 +1549,10 @@ def get_export_callback(self, save_images: bool) -> Callable:
@extend_schema(methods=['POST'],
summary="Import annotations into a task",
description=textwrap.dedent("""
The request POST /api/tasks/id/annotations will initiate the import and will create
the rq job on the server in which the import will be carried out.
Please, use the PUT /api/tasks/id/annotations endpoint for checking status of the process.
The request POST /api/tasks/id/annotations initiates the import and creates
the rq job on the server in which the import is carried out.
Please use the GET /api/requests/<rq_id> endpoint to check the status of the process.
The `rq_id` parameter can be found in the response to the initiating request.
"""),
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
@@ -1610,6 +1618,7 @@ def annotations(self, request, pk):
)
elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
# deprecated logic, will be removed in one of the next releases
if format_name:
# NOTE: continue process of import annotations
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
@@ -2116,29 +2125,36 @@ def upload_finished(self, request):
'405': OpenApiResponse(description='Format is not available'),
})
@extend_schema(methods=['PUT'],
summary='Replace job annotations / Get annotation import status',
summary='Replace job annotations',
description=textwrap.dedent("""
To check the status of an import request:
After initiating the annotation import, you will receive an rq_id parameter.
Make sure to include this parameter as a query parameter in your subsequent
PUT /api/jobs/id/annotations requests to track the status of the import.
Utilizing this endpoint to check the status of the import process is deprecated
in favor of the new requests API:
GET /api/requests/<rq_id>, where the `rq_id` parameter is returned in the response
to the initiating request.
"""),
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats',
deprecated=True,
),
OpenApiParameter('location', description='where to import the annotation from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
enum=Location.list(),
deprecated=True,
),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotation',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True, deprecated=True),
location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
deprecated=True,
),
OpenApiParameter('filename', description='Annotation file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
deprecated=True,
),
OpenApiParameter('rq_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='rq id'),
description='rq id',
deprecated=True,
),
],
request=PolymorphicProxySerializer(
component_name='JobAnnotationsUpdate',
@@ -2189,6 +2205,7 @@ def annotations(self, request, pk):
elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
if format_name:
# deprecated logic, will be removed in one of the next releases
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
location_conf = get_location_configuration(
db_instance=self._object, query_params=request.query_params, field_name=StorageType.SOURCE
@@ -3382,16 +3399,16 @@ def _import_annotations(request, rq_id_factory, rq_func, db_obj, format_name,
with get_rq_lock_for_job(queue, rq_id):
rq_job = queue.fetch_job(rq_id)

if rq_job and rq_id_should_be_checked and not is_rq_job_owner(rq_job, request.user.id):
return Response(status=status.HTTP_403_FORBIDDEN)
if rq_job:
if rq_id_should_be_checked and not is_rq_job_owner(rq_job, request.user.id):
return Response(status=status.HTTP_403_FORBIDDEN)

if request.method == 'POST':
if rq_job.get_status(refresh=False) not in (RQJobStatus.FINISHED, RQJobStatus.FAILED):
return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists')

if rq_job and request.method == 'POST':
# If there is a previous job that has not been deleted
if rq_job.is_finished or rq_job.is_failed:
rq_job.delete()
rq_job = queue.fetch_job(rq_id)
else:
return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists')
rq_job = None

if not rq_job:
# If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem
@@ -3443,7 +3460,7 @@ def _import_annotations(request, rq_id_factory, rq_func, db_obj, format_name,
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
rq_job = queue.enqueue_call(
queue.enqueue_call(
func=func,
args=func_args,
job_id=rq_id,
@@ -3456,28 +3473,38 @@ def _import_annotations(request, rq_id_factory, rq_func, db_obj, format_name,
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)

handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None)
# log events after releasing Redis lock
if not rq_job:
handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None)

serializer = RqIdSerializer(data={'rq_id': rq_id})
serializer.is_valid(raise_exception=True)
serializer = RqIdSerializer(data={'rq_id': rq_id})
serializer.is_valid(raise_exception=True)

return Response(serializer.data, status=status.HTTP_202_ACCEPTED)
return Response(serializer.data, status=status.HTTP_202_ACCEPTED)
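For the "Reduce critical section" part of the change, here is a generic sketch of the pattern applied above: only queue inspection and enqueueing happen while the per-job Redis lock is held, and slower side effects such as event logging run after the lock is released. It deliberately uses plain django_rq/redis primitives rather than CVAT's internal get_rq_lock_* helpers, so it illustrates the idea rather than the project's actual code.

import django_rq  # assumes a standard django_rq / Django settings setup

def enqueue_import_once(queue_name: str, job_id: str, func, *args) -> bool:
    """Enqueue an import job at most once, keeping the Redis critical section small."""
    queue = django_rq.get_queue(queue_name)
    enqueued = False
    # Critical section: just the existence check and the enqueue call.
    with queue.connection.lock(f"lock:{job_id}", timeout=30):
        if queue.fetch_job(job_id) is None:
            queue.enqueue_call(func=func, args=args, job_id=job_id)
            enqueued = True
    # Side effects (e.g. event logging) happen after the lock is released,
    # mirroring the "log events after releasing Redis lock" comment above.
    if enqueued:
        print(f"import job {job_id} enqueued")
    return enqueued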

# Deprecated logic, /api/requests API should be used instead
# https://greenbytes.de/tech/webdav/draft-ietf-httpapi-deprecation-header-latest.html#the-deprecation-http-response-header-field
deprecation_timestamp = int(datetime(2025, 2, 14, tzinfo=timezone.utc).timestamp())
response_headers = {
"Deprecation": f"@{deprecation_timestamp}"
}

rq_job_status = rq_job.get_status(refresh=False)
if RQJobStatus.FINISHED == rq_job_status:
rq_job.delete()
return Response(status=status.HTTP_201_CREATED, headers=response_headers)
elif RQJobStatus.FAILED == rq_job_status:
exc_info = process_failed_job(rq_job)

import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:'
if exc_info.startswith("Traceback") and import_error_prefix in exc_info:
exc_message = exc_info.split(import_error_prefix)[-1].strip()
return Response(data=exc_message, status=status.HTTP_400_BAD_REQUEST, headers=response_headers)
else:
if rq_job.is_finished:
rq_job.delete()
return Response(status=status.HTTP_201_CREATED)
elif rq_job.is_failed:
exc_info = process_failed_job(rq_job)

import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:'
if exc_info.startswith("Traceback") and import_error_prefix in exc_info:
exc_message = exc_info.split(import_error_prefix)[-1].strip()
return Response(data=exc_message, status=status.HTTP_400_BAD_REQUEST)
else:
return Response(data=exc_info,
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response(data=exc_info,
status=status.HTTP_500_INTERNAL_SERVER_ERROR, headers=response_headers)

return Response(status=status.HTTP_202_ACCEPTED)
return Response(status=status.HTTP_202_ACCEPTED, headers=response_headers)
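A small hedged sketch of how a client might read the Deprecation header attached by the branch above. The "@<unix-timestamp>" value format follows the IETF httpapi deprecation-header draft linked in the code; the helper itself is purely illustrative.

from datetime import datetime, timezone

import requests

def deprecation_date(response: requests.Response) -> datetime | None:
    """Return the deprecation date advertised via the Deprecation header, if any."""
    value = response.headers.get("Deprecation")
    if value and value.startswith("@"):
        # "@<unix-timestamp>", e.g. "@1739491200" for 2025-02-14 UTC
        return datetime.fromtimestamp(int(value[1:]), tz=timezone.utc)
    return None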

def _import_project_dataset(
request, rq_id_factory, rq_func, db_obj, format_name,
@@ -3499,73 +3526,76 @@ def _import_project_dataset(
with get_rq_lock_for_job(queue, rq_id):
rq_job = queue.fetch_job(rq_id)

if not rq_job or rq_job.is_finished or rq_job.is_failed:
if rq_job and (rq_job.is_finished or rq_job.is_failed):
# for some reason the previous job has not been deleted
# (e.g the user closed the browser tab when job has been created
# but no one requests for checking status were not made)
rq_job.delete()
if rq_job:
rq_job_status = rq_job.get_status(refresh=False)
if rq_job_status not in (RQJobStatus.FINISHED, RQJobStatus.FAILED):
return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists')

location = location_conf.get('location') if location_conf else None
db_storage = None
# for some reason the previous job has not been deleted
# (e.g. the user closed the browser tab when the job was created,
# but no status check requests were made)
rq_job.delete()
rq_job = None

if not filename and location != Location.CLOUD_STORAGE:
serializer = DatasetFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
dataset_file = serializer.validated_data['dataset_file']
with NamedTemporaryFile(
prefix='cvat_{}'.format(db_obj.pk),
dir=settings.TMP_FILES_ROOT,
delete=False) as tf:
filename = tf.name
for chunk in dataset_file.chunks():
tf.write(chunk)
location = location_conf.get('location') if location_conf else None
db_storage = None

elif location == Location.CLOUD_STORAGE:
assert filename, 'The filename was not specified'
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected as the source,'
' but cloud storage id was not specified')
db_storage = get_cloud_storage_for_import_or_export(
storage_id=storage_id, request=request,
is_default=location_conf['is_default'])

key = filename
if not filename and location != Location.CLOUD_STORAGE:
serializer = DatasetFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
dataset_file = serializer.validated_data['dataset_file']
with NamedTemporaryFile(
prefix='cvat_{}'.format(db_obj.pk),
dir=settings.TMP_FILES_ROOT,
delete=False) as tf:
filename = tf.name
for chunk in dataset_file.chunks():
tf.write(chunk)

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)
elif location == Location.CLOUD_STORAGE:
assert filename, 'The filename was not specified'
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected as the source,'
' but cloud storage id was not specified')
db_storage = get_cloud_storage_for_import_or_export(
storage_id=storage_id, request=request,
is_default=location_conf['is_default'])

key = filename
with NamedTemporaryFile(
prefix='cvat_{}'.format(db_obj.pk),
dir=settings.TMP_FILES_ROOT,
delete=False) as tf:
filename = tf.name

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)

if location == Location.CLOUD_STORAGE:
func_args = (db_storage, key, func) + func_args
func = import_resource_from_cloud_storage

if location == Location.CLOUD_STORAGE:
func_args = (db_storage, key, func) + func_args
func = import_resource_from_cloud_storage
user_id = request.user.id

user_id = request.user.id
with get_rq_lock_by_user(queue, user_id):
queue.enqueue_call(
func=func,
args=func_args,
job_id=rq_id,
meta={
'tmp_file': filename,
**get_rq_job_meta(request=request, db_obj=db_obj),
},
depends_on=define_dependent_job(queue, user_id, rq_id=rq_id),
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)

with get_rq_lock_by_user(queue, user_id):
rq_job = queue.enqueue_call(
func=func,
args=func_args,
job_id=rq_id,
meta={
'tmp_file': filename,
**get_rq_job_meta(request=request, db_obj=db_obj),
},
depends_on=define_dependent_job(queue, user_id, rq_id=rq_id),
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)

handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None)
else:
return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists')
handle_dataset_import(db_obj, format_name=format_name, cloud_storage_id=db_storage.id if db_storage else None)

serializer = RqIdSerializer(data={'rq_id': rq_id})
serializer.is_valid(raise_exception=True)
