Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backup task created from cloud #8972

Merged
merged 16 commits on Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- Tasks created from cloud storage can be backed up now
(<https://github.com/cvat-ai/cvat/pull/8972>)
38 changes: 34 additions & 4 deletions cvat/apps/engine/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import re
import shutil
import tempfile
import uuid
from abc import ABCMeta, abstractmethod
from collections.abc import Collection, Iterable
Expand Down Expand Up @@ -47,7 +48,10 @@
retry_current_rq_job,
)
from cvat.apps.engine import models
from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage
from cvat.apps.engine.cloud_provider import (
db_storage_to_storage_instance,
import_resource_from_cloud_storage,
)
from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.models import (
Expand Down Expand Up @@ -395,14 +399,14 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
def __init__(self, pk, version=Version.V1):
super().__init__(logger=slogger.task[pk])

self._db_task = (
self._db_task: models.Task = (
models.Task.objects
.prefetch_related('data__images', 'annotation_guide__assets')
.select_related('data__video', 'data__validation_layout', 'annotation_guide')
.get(pk=pk)
)

self._db_data = self._db_task.data
self._db_data: models.Data = self._db_task.data
self._version = version

db_labels = (self._db_task.project if self._db_task.project_id else self._db_task).label_set.all().prefetch_related(
Expand Down Expand Up @@ -443,8 +447,31 @@ def _write_data(self, zip_object, target_dir=None):
files=[self._db_data.get_manifest_path()],
target_dir=target_data_dir,
)
elif self._db_data.storage == StorageChoice.CLOUD_STORAGE:
assert self._db_task.dimension != models.DimensionType.DIM_3D, "Cloud storage cannot contain 3d images"
assert not hasattr(self._db_data, 'video'), "Only images can be stored in cloud storage"
assert self._db_data.related_files.count() == 0, "No related images can be stored in cloud storage"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related: #9071

media_files = [im.path for im in self._db_data.images.all()]
cloud_storage_instance = db_storage_to_storage_instance(self._db_data.cloud_storage)
with tempfile.TemporaryDirectory() as tmp_dir:
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
cloud_storage_instance.bulk_download_to_dir(files=media_files, upload_dir=tmp_dir)
self._write_files(
source_dir=tmp_dir,
zip_object=zip_object,
files=[
os.path.join(tmp_dir, file)
for file in media_files
],
target_dir=target_data_dir,
)
self._write_files(
source_dir=self._db_data.get_upload_dirname(),
zip_object=zip_object,
files=[self._db_data.get_manifest_path()],
target_dir=target_data_dir,
)
else:
raise NotImplementedError("We don't currently support backing up tasks with data from cloud storage")
raise NotImplementedError

def _write_task(self, zip_object, target_dir=None):
task_dir = self._db_task.get_dirname()
Expand Down Expand Up @@ -556,6 +583,9 @@ def serialize_data():
]
data['validation_layout'] = validation_params

if self._db_data.storage == StorageChoice.CLOUD_STORAGE:
data["storage"] = StorageChoice.LOCAL

return self._prepare_data_meta(data)

task = serialize_task()
Expand Down
13 changes: 8 additions & 5 deletions cvat/apps/engine/media_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,17 @@ def __getitem__(self, idx: int):
value = super().__getitem__(idx)
value_size = self._get_object_size(value)

while (
len(self._cache) + 1 > self.max_cache_entries or
self.used_cache_memory + value_size > self.max_cache_memory
):
def can_put_item_in_cache():
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
return (
len(self._cache) + 1 <= self.max_cache_entries and
self.used_cache_memory + value_size <= self.max_cache_memory
)

while len(self._cache) > 0 and not can_put_item_in_cache():
min_key = min(self._cache.keys())
self._cache.pop(min_key)

if self.used_cache_memory + value_size <= self.max_cache_memory:
if can_put_item_in_cache():
self._cache[idx] = self._CacheItem(value, value_size)

return value
Expand Down
141 changes: 140 additions & 1 deletion cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.http import HttpResponse
from django.test import override_settings
from pdf2image import convert_from_bytes
from PIL import Image
from pycocotools import coco as coco_loader
Expand All @@ -42,6 +43,7 @@

from cvat.apps.dataset_manager.tests.utils import TestDir
from cvat.apps.dataset_manager.util import current_function_name
from cvat.apps.engine.cloud_provider import AWS_S3, Status
from cvat.apps.engine.media_extractors import ValidateDimension, sort
from cvat.apps.engine.models import (
AttributeSpec,
Expand Down Expand Up @@ -1317,6 +1319,7 @@ def test_api_v2_projects_id_tasks_no_auth(self):
response = self._run_api_v2_projects_id_tasks(None, project.id)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)


class ProjectBackupAPITestCase(ApiTestBase):
@classmethod
def setUpTestData(cls):
Expand Down Expand Up @@ -1633,6 +1636,12 @@ def _run_api_v2_projects_id(self, pid, user):

return response.data

def _get_tasks_for_project(self, user, pid):
    """Return the tasks of project *pid* via the REST API, sorted by task name.

    Logs in as *user* for the duration of the request so the listing is
    performed with that user's permissions.
    """
    url = '/api/tasks?project_id={}'.format(pid)
    with ForceLogin(user, self.client):
        api_response = self.client.get(url)

    tasks = api_response.data["results"]
    return sorted(tasks, key=lambda t: t["name"])

def _run_api_v2_projects_id_export_import(self, user):
for project in self.projects:
if user:
Expand Down Expand Up @@ -1669,7 +1678,7 @@ def _run_api_v2_projects_id_export_import(self, user):
}
response = self._run_api_v2_projects_import(user, uploaded_data)
self.assertEqual(response.status_code, HTTP_202_ACCEPTED)
if response.status_code == status.HTTP_200_OK:
if response.status_code == status.HTTP_202_ACCEPTED:
rq_id = response.data["rq_id"]
response = self._run_api_v2_projects_import(user, {"rq_id": rq_id})
self.assertEqual(response.status_code, HTTP_201_CREATED)
Expand All @@ -1691,6 +1700,26 @@ def _run_api_v2_projects_id_export_import(self, user):
"tasks",
),
)
self.assertEqual(original_project["tasks"]["count"], imported_project["tasks"]["count"])
original_tasks = self._get_tasks_for_project(user, original_project["id"])
imported_tasks = self._get_tasks_for_project(user, imported_project["id"])
for original_task, imported_task in zip(original_tasks, imported_tasks):
compare_objects(
self=self,
obj1=original_task,
obj2=imported_task,
ignore_keys=(
"id",
"url",
"created_date",
"updated_date",
"username",
"project_id",
"data",
# backup does not have this info for some reason
"overlap",
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
),
)

def test_api_v2_projects_id_export_admin(self):
self._run_api_v2_projects_id_export_import(self.admin)
Expand All @@ -1704,6 +1733,116 @@ def test_api_v2_projects_id_export_somebody(self):
def test_api_v2_projects_id_export_no_auth(self):
self._run_api_v2_projects_id_export_import(None)


@override_settings(MEDIA_CACHE_ALLOW_STATIC_CACHE=False)
class ProjectCloudBackupAPINoStaticChunksTestCase(ProjectBackupAPITestCase):
    """Re-runs the project backup export/import tests with task media stored
    in a (mocked) cloud storage instead of locally.

    Inherits all test methods from ProjectBackupAPITestCase; only the test
    data setup differs: an in-memory AWS S3 mock is registered and the media
    specs point at it via ``cloud_storage_id``. Static chunk caching is
    disabled via the class decorator.
    """

    @classmethod
    def setUpTestData(cls):
        # Order matters: the cloud storage must exist before media specs
        # referencing cls.cloud_storage_id are built, and media must exist
        # before projects/tasks are created from them.
        create_db_users(cls)
        cls.client = APIClient()
        cls._create_cloud_storage()
        cls._create_media()
        cls._create_projects()

    @classmethod
    def _create_cloud_storage(cls):
        """Register an in-memory S3 mock and create a cloud storage via the API.

        Stores the mock class on ``cls.mock_aws`` (so tests can seed files into
        it) and the created storage's id on ``cls.cloud_storage_id``.
        """
        data = {
            "provider_type": "AWS_S3_BUCKET",
            "resource": "test",
            "display_name": "Bucket",
            "credentials_type": "KEY_SECRET_KEY_PAIR",
            "key": "minio_access_key",
            "secret_key": "minio_secret_key",
            "specific_attributes": "endpoint_url=http://minio:9000",
            "description": "Some description",
            "manifests": [],
        }

        class MockAWS(AWS_S3):
            # key -> raw file bytes; class-level so it is shared by every
            # instance the patched provider factory creates
            _files = {}

            def get_status(self):
                return Status.AVAILABLE

            @classmethod
            def create_file(cls, key, _bytes):
                cls._files[key] = _bytes

            def get_file_status(self, key):
                return Status.AVAILABLE if key in self._files else Status.NOT_FOUND

            def _download_range_of_bytes(self, key, stop_byte, start_byte):
                return self._files[key][start_byte:stop_byte]

            def _download_fileobj_to_stream(self, key, stream):
                stream.write(self._files[key])

        cls.mock_aws = MockAWS

        # Patch the real provider class so every cloud-storage access in the
        # server code under test hits the in-memory mock. Stopped in
        # tearDownClass.
        cls.aws_patch = mock.patch("cvat.apps.engine.cloud_provider.AWS_S3", MockAWS)
        cls.aws_patch.start()

        with ForceLogin(cls.owner, cls.client):
            response = cls.client.post('/api/cloudstorages', data=data, format="json")
            assert response.status_code == status.HTTP_201_CREATED, (response.status_code, response.content)
            cls.cloud_storage_id = response.json()["id"]

    @classmethod
    def tearDownClass(cls):
        # Undo the AWS_S3 patch before the parent class tears down.
        cls.aws_patch.stop()
        super().tearDownClass()

    @classmethod
    def _create_media(cls):
        """Seed the mocked bucket with sample files and build cloud media specs.

        Unlike the parent class, all media specs here use
        ``storage: CLOUD_STORAGE`` and reference files that were uploaded to
        the in-memory mock rather than to local upload directories.
        """
        cls.media_data = []
        cls.media = {'files': [], 'dirs': []}
        # Upload the generated fixtures into the mocked bucket so the server
        # can download them when creating tasks.
        for file in [
            generate_random_image_file("test_1.jpg")[1],
            generate_random_image_file("test_2.jpg")[1],
            generate_pdf_file("test_pdf_1.pdf", 7)[1],
            generate_zip_archive_file("test_archive_1.zip", 10)[1],
            generate_video_file("test_video.mp4")[1],
        ]:
            cls.mock_aws.create_file(file.name, file.getvalue())

        cls.media_data.extend([
            # image list cloud
            {
                "server_files[0]": "test_1.jpg",
                "server_files[1]": "test_2.jpg",
                "image_quality": 75,
                "cloud_storage_id": cls.cloud_storage_id,
                "storage": StorageChoice.CLOUD_STORAGE,
            },
            # video cloud
            {
                "server_files[0]": "test_video.mp4",
                "image_quality": 75,
                "cloud_storage_id": cls.cloud_storage_id,
                "storage": StorageChoice.CLOUD_STORAGE,
            },
            # zip archive cloud
            {
                "server_files[0]": "test_archive_1.zip",
                "image_quality": 50,
                "cloud_storage_id": cls.cloud_storage_id,
                "storage": StorageChoice.CLOUD_STORAGE,
            },
            # pdf cloud
            {
                "server_files[0]": "test_pdf_1.pdf",
                "image_quality": 54,
                "cloud_storage_id": cls.cloud_storage_id,
                "storage": StorageChoice.CLOUD_STORAGE,
            },
        ])


@override_settings(MEDIA_CACHE_ALLOW_STATIC_CACHE=True)
class ProjectCloudBackupAPIStaticChunksTestCase(ProjectCloudBackupAPINoStaticChunksTestCase):
    """Same cloud-backup test suite, but with static chunk caching enabled.

    The empty body is intentional: inheriting re-runs every inherited test
    method under the MEDIA_CACHE_ALLOW_STATIC_CACHE=True setting.
    """
    pass


class ProjectExportAPITestCase(ApiTestBase):
@classmethod
def setUpTestData(cls):
Expand Down
28 changes: 28 additions & 0 deletions tests/python/rest_api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4162,6 +4162,34 @@ def test_cannot_export_backup_for_task_without_data(self, tasks):
assert exc.status == HTTPStatus.BAD_REQUEST
assert "Backup of a task without data is not allowed" == exc.body.encode()

@pytest.mark.with_external_services
def test_can_export_and_import_backup_task_with_cloud_storage(self, tasks):
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
cloud_storage_content = ["image_case_65_1.png", "image_case_65_2.png"]
task_spec = {
"name": "Task with files from cloud storage",
"labels": [
{
"name": "car",
}
],
}
data_spec = {
"image_quality": 75,
"use_cache": False,
"cloud_storage_id": 1,
"server_files": cloud_storage_content,
}
task_id, _ = create_task(self.user, task_spec, data_spec)

task = self.client.tasks.retrieve(task_id)

filename = self.tmp_dir / f"cloud_task_{task.id}_backup.zip"
task.download_backup(filename)

assert filename.is_file()
assert filename.stat().st_size > 0
self._test_can_restore_task_from_backup(task_id)

@pytest.mark.parametrize("mode", ["annotation", "interpolation"])
def test_can_import_backup(self, tasks, mode):
task_id = next(t for t in tasks if t["mode"] == mode)["id"]
Expand Down
Loading