
Refactor bulk_create calls (#9047)
- Added a default batch size in `bulk_create()` calls
- Removed the SQLite workaround in `bulk_create()` calls
zhiltsov-max authored Feb 13, 2025
1 parent a31a782 commit 7959be8
Showing 10 changed files with 133 additions and 139 deletions.
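
Most of the diff is a mechanical call-site cleanup. A representative before/after, excerpted from the task.py hunks below (`flt_param` was the filter the old helper used to re-read rows on backends where `bulk_create()` could not return primary keys):

```python
# Before: the caller had to pass a filter so the helper could re-query
# the inserted rows on backends without pk-returning bulk inserts.
db_tracks = bulk_create(
    db_model=models.LabeledTrack,
    objects=db_tracks,
    flt_param={"job_id": self.db_job.id}
)

# After: the engine helper relies on the database returning pks and applies
# the configured default batch size, so the call shrinks to:
db_tracks = bulk_create(models.LabeledTrack, db_tracks)
```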
2 changes: 1 addition & 1 deletion cvat/apps/dataset_manager/bindings.py
@@ -28,10 +28,10 @@
 from django.utils import timezone
 
 from cvat.apps.dataset_manager.formats.utils import get_label_color
-from cvat.apps.dataset_manager.util import add_prefetch_fields
 from cvat.apps.engine import models
 from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, TaskFrameProvider
 from cvat.apps.engine.lazy_list import LazyList
+from cvat.apps.engine.model_utils import add_prefetch_fields
 from cvat.apps.engine.models import (
     AttributeSpec,
     AttributeType,
3 changes: 2 additions & 1 deletion cvat/apps/dataset_manager/project.py
@@ -17,6 +17,7 @@
 from cvat.apps.dataset_manager.util import TmpDirManager
 from cvat.apps.engine import models
 from cvat.apps.engine.log import DatasetLogManager
+from cvat.apps.engine.model_utils import bulk_create
 from cvat.apps.engine.rq_job_handler import RQJobMetaField
 from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer
 from cvat.apps.engine.task import _create_thread as create_task
@@ -128,7 +129,7 @@ def add_labels(self, labels: list[models.Label], attributes: list[tuple[str, mod
             label, = filter(lambda l: l.name == label_name, labels)
             attribute.label = label
         if attributes:
-            models.AttributeSpec.objects.bulk_create([a[1] for a in attributes])
+            bulk_create(models.AttributeSpec, [a[1] for a in attributes])
 
     def init_from_db(self):
         self.reset()
57 changes: 10 additions & 47 deletions cvat/apps/dataset_manager/task.py
@@ -25,15 +25,10 @@
     TaskData,
 )
 from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer
-from cvat.apps.dataset_manager.util import (
-    TmpDirManager,
-    add_prefetch_fields,
-    bulk_create,
-    faster_deepcopy,
-    get_cached,
-)
+from cvat.apps.dataset_manager.util import TmpDirManager, faster_deepcopy
 from cvat.apps.engine import models, serializers
 from cvat.apps.engine.log import DatasetLogManager
+from cvat.apps.engine.model_utils import add_prefetch_fields, bulk_create, get_cached
 from cvat.apps.engine.plugins import plugin_decorator
 from cvat.apps.engine.utils import take_by
 from cvat.apps.events.handlers import handle_annotations_change
@@ -284,38 +279,22 @@ def create_tracks(tracks, parent_track=None):
             if elements or parent_track is None:
                 track["elements"] = elements
 
-        db_tracks = bulk_create(
-            db_model=models.LabeledTrack,
-            objects=db_tracks,
-            flt_param={"job_id": self.db_job.id}
-        )
+        db_tracks = bulk_create(models.LabeledTrack, db_tracks)
 
         for db_attr_val in db_track_attr_vals:
             db_attr_val.track_id = db_tracks[db_attr_val.track_id].id
 
-        bulk_create(
-            db_model=models.LabeledTrackAttributeVal,
-            objects=db_track_attr_vals,
-            flt_param={}
-        )
+        bulk_create(models.LabeledTrackAttributeVal, db_track_attr_vals)
 
         for db_shape in db_shapes:
            db_shape.track_id = db_tracks[db_shape.track_id].id
 
-        db_shapes = bulk_create(
-            db_model=models.TrackedShape,
-            objects=db_shapes,
-            flt_param={"track__job_id": self.db_job.id}
-        )
+        db_shapes = bulk_create(models.TrackedShape, db_shapes)
 
         for db_attr_val in db_shape_attr_vals:
             db_attr_val.shape_id = db_shapes[db_attr_val.shape_id].id
 
-        bulk_create(
-            db_model=models.TrackedShapeAttributeVal,
-            objects=db_shape_attr_vals,
-            flt_param={}
-        )
+        bulk_create(models.TrackedShapeAttributeVal, db_shape_attr_vals)
 
         shape_idx = 0
         for track, db_track in zip(tracks, db_tracks):
@@ -355,20 +334,12 @@ def create_shapes(shapes, parent_shape=None):
             if shape_elements or parent_shape is None:
                 shape["elements"] = shape_elements
 
-        db_shapes = bulk_create(
-            db_model=models.LabeledShape,
-            objects=db_shapes,
-            flt_param={"job_id": self.db_job.id}
-        )
+        db_shapes = bulk_create(models.LabeledShape, db_shapes)
 
         for db_attr_val in db_attr_vals:
             db_attr_val.shape_id = db_shapes[db_attr_val.shape_id].id
 
-        bulk_create(
-            db_model=models.LabeledShapeAttributeVal,
-            objects=db_attr_vals,
-            flt_param={}
-        )
+        bulk_create(models.LabeledShapeAttributeVal, db_attr_vals)
 
         for shape, db_shape in zip(shapes, db_shapes):
             shape["id"] = db_shape.id
@@ -399,20 +370,12 @@ def _save_tags_to_db(self, tags):
             db_tags.append(db_tag)
             tag["attributes"] = attributes
 
-        db_tags = bulk_create(
-            db_model=models.LabeledImage,
-            objects=db_tags,
-            flt_param={"job_id": self.db_job.id}
-        )
+        db_tags = bulk_create(models.LabeledImage, db_tags)
 
         for db_attr_val in db_attr_vals:
            db_attr_val.image_id = db_tags[db_attr_val.tag_id].id
 
-        bulk_create(
-            db_model=models.LabeledImageAttributeVal,
-            objects=db_attr_vals,
-            flt_param={}
-        )
+        bulk_create(models.LabeledImageAttributeVal, db_attr_vals)
 
         for tag, db_tag in zip(tags, db_tags):
             tag["id"] = db_tag.id
61 changes: 2 additions & 59 deletions cvat/apps/dataset_manager/util.py
@@ -9,20 +9,19 @@
 import re
 import tempfile
 import zipfile
-from collections.abc import Generator, Iterable, Sequence
+from collections.abc import Generator
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from enum import Enum
 from threading import Lock
-from typing import Any, TypeVar
+from typing import Any
 
 import attrs
 import django_rq
 from datumaro.util import to_snake_case
 from datumaro.util.os_util import make_file_name
 from django.conf import settings
-from django.db import models
 from pottery import Redlock
 
 
@@ -38,62 +37,6 @@ def make_zip_archive(src_path, dst_path):
         archive.write(path, osp.relpath(path, src_path))
 
 
-_ModelT = TypeVar("_ModelT", bound=models.Model)
-
-def bulk_create(
-    db_model: type[_ModelT],
-    objects: Iterable[_ModelT],
-    *,
-    flt_param: dict[str, Any] | None = None,
-    batch_size: int | None = 10000
-) -> list[_ModelT]:
-    if objects:
-        if flt_param:
-            if "postgresql" in settings.DATABASES["default"]["ENGINE"]:
-                return db_model.objects.bulk_create(objects, batch_size=batch_size)
-            else:
-                ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True))
-                db_model.objects.bulk_create(objects, batch_size=batch_size)
-
-                return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
-        else:
-            return db_model.objects.bulk_create(objects, batch_size=batch_size)
-
-    return []
-
-
-def is_prefetched(queryset: models.QuerySet, field: str) -> bool:
-    return field in queryset._prefetch_related_lookups
-
-
-def add_prefetch_fields(queryset: models.QuerySet, fields: Sequence[str]) -> models.QuerySet:
-    for field in fields:
-        if not is_prefetched(queryset, field):
-            queryset = queryset.prefetch_related(field)
-
-    return queryset
-
-
-def get_cached(queryset: models.QuerySet, pk: int) -> models.Model:
-    """
-    Like regular queryset.get(), but checks for the cached values first
-    instead of just making a request.
-    """
-
-    # Read more about caching insights:
-    # https://www.mattduck.com/2021-01-django-orm-result-cache.html
-    # The field is initialized on accessing the query results, eg. on iteration
-    if getattr(queryset, '_result_cache'):
-        result = next((obj for obj in queryset if obj.pk == pk), None)
-    else:
-        result = None
-
-    if result is None:
-        result = queryset.get(id=pk)
-
-    return result
-
-
 def faster_deepcopy(v):
     "A slightly optimized version of the default deepcopy, can be used as a drop-in replacement."
     # Default deepcopy is very slow, here we do shallow copy for primitive types and containers
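
For context on the deletion above: the `flt_param` branch existed because, on backends without `INSERT ... RETURNING` support (notably older SQLite setups), Django's `bulk_create()` could leave primary keys unset, so the helper re-queried the rows by filter to recover them; PostgreSQL returns pks here, which is why the old code short-circuited on it. A minimal sketch of the difference, assuming a hypothetical Django model `MyModel`:

```python
# A minimal sketch, assuming a hypothetical Django model `MyModel`.
objs = MyModel.objects.bulk_create([MyModel(name="a"), MyModel(name="b")])

# On PostgreSQL (INSERT ... RETURNING), pks are set on the returned objects:
assert all(o.pk is not None for o in objs)

# On backends without RETURNING support, o.pk could be None — which is what
# the removed filter/exclude round-trip in the old helper worked around.
```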
2 changes: 2 additions & 0 deletions cvat/apps/engine/default_settings.py
@@ -95,3 +95,5 @@
 MAX_CONSENSUS_REPLICAS = int(os.getenv("CVAT_MAX_CONSENSUS_REPLICAS", 11))
 if MAX_CONSENSUS_REPLICAS < 1:
     raise ImproperlyConfigured(f"MAX_CONSENSUS_REPLICAS must be >= 1, got {MAX_CONSENSUS_REPLICAS}")
+
+DEFAULT_DB_BULK_CREATE_BATCH_SIZE = int(os.getenv("CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE", 5000))
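
The new setting caps how many rows go into each INSERT statement. A quick sketch of the effect (only the `CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE` name and its 5000 default come from the diff; the helper below is illustrative):

```python
from math import ceil

def num_insert_batches(n_objects: int, batch_size: int) -> int:
    """How many INSERT statements a batched bulk_create() issues."""
    return ceil(n_objects / batch_size) if n_objects else 0

assert num_insert_batches(12_000, 5000) == 3   # default batch size
assert num_insert_batches(12_000, 1000) == 12  # e.g. lowered via the env var
assert num_insert_batches(0, 5000) == 0        # empty input issues no statements
```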
77 changes: 76 additions & 1 deletion cvat/apps/engine/model_utils.py
@@ -4,7 +4,11 @@
 
 from __future__ import annotations
 
-from typing import TypeVar, Union
+from collections.abc import Iterable
+from typing import Sequence, TypeVar, Union
+
+from django.conf import settings
+from django.db import models
 
 _T = TypeVar("_T")
 
@@ -18,3 +22,74 @@ class Undefined:
     The reverse side of one-to-one relationship.
     May be undefined in the object, should be accessed via getattr().
     """
+
+
+_ModelT = TypeVar("_ModelT", bound=models.Model)
+_unspecified = object()
+
+
+def bulk_create(
+    db_model: type[_ModelT],
+    objs: Iterable[_ModelT],
+    *,
+    batch_size: int | None = _unspecified,
+    ignore_conflicts: bool = False,
+    update_conflicts: bool | None = False,
+    update_fields: Sequence[str] | None = None,
+    unique_fields: Sequence[str] | None = None,
+) -> list[_ModelT]:
+    """
+    Like Django's Model.objects.bulk_create(), but applies the default batch size configured by
+    the DEFAULT_DB_BULK_CREATE_BATCH_SIZE setting.
+    """
+
+    if batch_size is _unspecified:
+        batch_size = settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE
+
+    if not objs:
+        return []
+
+    return db_model.objects.bulk_create(
+        objs,
+        batch_size=batch_size,
+        ignore_conflicts=ignore_conflicts,
+        update_conflicts=update_conflicts,
+        update_fields=update_fields,
+        unique_fields=unique_fields,
+    )
+
+
+def is_prefetched(queryset: models.QuerySet, field: str) -> bool:
+    "Checks if a field is being prefetched in the queryset"
+    return field in queryset._prefetch_related_lookups
+
+
+_QuerysetT = TypeVar("_QuerysetT", bound=models.QuerySet)
+
+
+def add_prefetch_fields(queryset: _QuerysetT, fields: Sequence[str]) -> _QuerysetT:
+    for field in fields:
+        if not is_prefetched(queryset, field):
+            queryset = queryset.prefetch_related(field)
+
+    return queryset
+
+
+def get_cached(queryset: _QuerysetT, pk: int) -> _ModelT:
+    """
+    Like regular queryset.get(), but checks for the cached values first
+    instead of just making a request.
+    """
+
+    # Read more about caching insights:
+    # https://www.mattduck.com/2021-01-django-orm-result-cache.html
+    # The field is initialized on accessing the query results, eg. on iteration
+    if getattr(queryset, "_result_cache"):
+        result = next((obj for obj in queryset if obj.pk == pk), None)
+    else:
+        result = None
+
+    if result is None:
+        result = queryset.get(id=pk)
+
+    return result
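
A hedged usage sketch of the relocated helper; the `LabeledImage` field names (`job_id`, `frame`, `label_id`) and the ids are assumptions for illustration, not taken from this diff:

```python
from cvat.apps.engine import models
from cvat.apps.engine.model_utils import bulk_create

# Hypothetical objects; field names and ids are assumed for illustration.
tags = [
    models.LabeledImage(job_id=42, frame=i, label_id=7)
    for i in range(12_000)
]

# Inserted in batches of settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE (5000 unless
# CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE overrides it), with pks set on return.
db_tags = bulk_create(models.LabeledImage, tags)

# An explicit batch_size still wins over the setting, mirroring Django's API:
db_tags = bulk_create(models.LabeledImage, tags, batch_size=1000)
```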
