7 changes: 6 additions & 1 deletion radis-client/radis_client/client.py
@@ -102,13 +102,17 @@ def update_report(
return response.json()

def update_reports_bulk(
self, reports: list[ReportData], upsert: bool = True
self,
reports: list[ReportData],
upsert: bool = True,
timeout: float | tuple[float, float] | None = None,
) -> dict[str, Any]:
"""Bulk upsert reports using a single request.

Args:
reports: The report payloads to upsert.
upsert: Whether to create reports that do not exist yet instead of failing.
timeout: Optional requests timeout in seconds, either a single value or a (connect, read) tuple.

Returns:
The response as JSON.
@@ -119,6 +123,7 @@ def update_reports_bulk(
json=payload,
headers=self._headers,
params={"upsert": upsert},
timeout=timeout,
)
response.raise_for_status()
return response.json()
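
With the timeout exposed, callers can bound how long a bulk request may block. A minimal usage sketch; the client class name, server URL, and token are placeholders:

from radis_client.client import RadisClient  # class name assumed for this sketch

client = RadisClient(server_url="https://radis.example.com", auth_token="...")
# The tuple form maps directly to requests: 3.05 s to connect, 60 s to read.
result = client.update_reports_bulk(reports, upsert=True, timeout=(3.05, 60))
print(result["created"])  # count of newly created reports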
26 changes: 26 additions & 0 deletions radis/pgsearch/tasks.py
@@ -0,0 +1,26 @@
import logging
from typing import Any

from procrastinate.contrib.django import app

from .utils.indexing import bulk_upsert_report_search_vectors

logger = logging.getLogger(__name__)


@app.task
def bulk_index_reports(report_ids: list[int]) -> None:
if not report_ids:
return
logger.info("Indexing %s reports in bulk.", len(report_ids))
bulk_upsert_report_search_vectors(report_ids)


def enqueue_bulk_index_reports(report_ids: list[int]) -> int | None:
if not report_ids:
return None
payload: list[Any] = [int(report_id) for report_id in report_ids]
return app.configure_task(
"radis.pgsearch.tasks.bulk_index_reports",
allow_unknown=False,
).defer(report_ids=payload)
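
Both entry points can be exercised directly. A short sketch of the two call styles, assuming a Procrastinate worker is consuming the default queue for the deferred case:

from radis.pgsearch.tasks import bulk_index_reports, enqueue_bulk_index_reports

# Inline execution, e.g. in a test: the task function body runs synchronously.
bulk_index_reports(report_ids=[1, 2, 3])

# Deferred execution: returns the Procrastinate job id, or None for an empty list.
job_id = enqueue_bulk_index_reports([1, 2, 3])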
33 changes: 33 additions & 0 deletions radis/pgsearch/tests/test_indexing.py
@@ -0,0 +1,33 @@
import pytest

from radis.pgsearch.models import ReportSearchVector
from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors
from radis.reports.models import Language, Report


@pytest.mark.django_db
def test_bulk_index_matches_signal_vector() -> None:
language = Language.objects.create(code="en")
report = Report.objects.create(
document_id="DOC-INDEX",
pacs_aet="PACS",
pacs_name="PACS",
pacs_link="",
patient_id="P1",
patient_birth_date="1980-01-01",
patient_sex="M",
study_description="Study",
study_datetime="2024-01-01T00:00:00Z",
study_instance_uid="1.2.3.4",
accession_number="ACC1",
body="Findings: No acute abnormality.",
language=language,
)

signal_vector = ReportSearchVector.objects.get(report=report).search_vector
ReportSearchVector.objects.filter(report=report).delete()

bulk_upsert_report_search_vectors([report.pk])
bulk_vector = ReportSearchVector.objects.get(report=report).search_vector

assert signal_vector == bulk_vector
62 changes: 62 additions & 0 deletions radis/pgsearch/utils/indexing.py
@@ -0,0 +1,62 @@
from __future__ import annotations

from collections.abc import Iterable

from django.conf import settings
from django.db import connection

from radis.reports.models import Report

from ..models import ReportSearchVector
from .language_utils import code_to_language


def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
for index in range(0, len(items), size):
yield items[index : index + size]


def bulk_upsert_report_search_vectors(
report_ids: Iterable[int],
chunk_size: int | None = None,
) -> None:
ids = sorted({int(report_id) for report_id in report_ids if report_id is not None})
if not ids:
return
resolved_chunk_size = (
settings.PGSEARCH_BULK_INDEX_CHUNK_SIZE if chunk_size is None else chunk_size
)

for chunk in _chunked(ids, resolved_chunk_size):
reports = (
Report.objects.filter(id__in=chunk)
.select_related("language")
.only("id", "language__code")
)
config_to_ids: dict[str, list[int]] = {}
config_cache: dict[str, str] = {}
for report in reports:
language_code = report.language.code
config = config_cache.get(language_code)
if config is None:
config = code_to_language(language_code)
config_cache[language_code] = config
config_to_ids.setdefault(config, []).append(report.pk)

for config, config_ids in config_to_ids.items():
ReportSearchVector.objects.bulk_create(
[ReportSearchVector(report_id=report_id) for report_id in config_ids],
ignore_conflicts=True,
batch_size=settings.PGSEARCH_BULK_INSERT_BATCH_SIZE,
)

with connection.cursor() as cursor:
cursor.execute(
"""
UPDATE pgsearch_reportsearchvector v
SET search_vector = to_tsvector(%s::regconfig, r.body)
FROM reports_report r
WHERE v.report_id = r.id AND r.id = ANY(%s)
""",
[config, config_ids],
)
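
With these helpers, a full re-index reduces to streaming every report id through the bulk function, which deduplicates, sorts, and chunks internally. A sketch:

from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors
from radis.reports.models import Report

# Rebuild search vectors for all reports, chunked per PGSEARCH_BULK_INDEX_CHUNK_SIZE.
bulk_upsert_report_search_vectors(Report.objects.values_list("id", flat=True).iterator())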
2 changes: 1 addition & 1 deletion radis/pgsearch/utils/language_utils.py
@@ -20,7 +20,7 @@ def get_available_search_configs() -> set[str]:
try:
return _get_available_search_configs_cached()
except DatabaseError as exc:
logger.warning("Failed to read pg_ts_config; falling back to simple. %s", exc)
logger.error("Failed to read pg_ts_config; falling back to simple. %s", exc)
return set()


10 changes: 10 additions & 0 deletions radis/reports/api/serializers.py
@@ -3,6 +3,7 @@
from django.db import transaction
from rest_framework import serializers, validators
from rest_framework.exceptions import ValidationError
from rest_framework.relations import PrimaryKeyRelatedField

from ..models import Language, Metadata, Modality, Report

@@ -50,6 +51,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if self.context.get("skip_document_id_unique"):
self._strip_unique_validator("document_id")
request = self.context.get("request")
if request is not None and "groups" in self.fields:
groups_field = self.fields["groups"]
if isinstance(groups_field, PrimaryKeyRelatedField):
if groups_field.queryset is not None:
if request.user.is_superuser:
groups_field.queryset = groups_field.queryset.all()
else:
groups_field.queryset = request.user.groups.all()

class Meta:
model = Report
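
The queryset narrowing means a regular user can only attach groups they belong to; any other group id fails DRF's PrimaryKeyRelatedField validation. A behavioral sketch, with the serializer and payload names assumed:

# ReportSerializer is an assumed name for the serializer class shown above.
serializer = ReportSerializer(data=payload, context={"request": request})
if not serializer.is_valid():
    # A group outside request.user.groups triggers DRF's standard
    # 'Invalid pk "<id>" - object does not exist.' error on the groups field.
    print(serializer.errors.get("groups"))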
103 changes: 97 additions & 6 deletions radis/reports/api/viewsets.py
@@ -1,6 +1,7 @@
import logging
from typing import Any

from django.conf import settings
from django.db import transaction
from django.http import Http404
from django.utils import timezone
@@ -12,6 +13,9 @@
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer

from radis.pgsearch.tasks import enqueue_bulk_index_reports
from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors

from ..models import Language, Metadata, Modality, Report
from ..site import (
document_fetchers,
Expand All @@ -26,10 +30,59 @@
BULK_DB_BATCH_SIZE = 1000


def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[str], list[str]]:
def _bulk_upsert_reports(
validated_reports: list[dict[str, Any]],
replace: bool = True,
) -> tuple[list[str], list[str]]:
if not validated_reports:
return [], []

deduped_reports: dict[str, dict[str, Any]] = {}
duplicate_count = 0
for report in validated_reports:
document_id = report["document_id"]
if document_id in deduped_reports:
duplicate_count += 1
deduped_reports[document_id] = report
if duplicate_count:
logger.warning(
"Bulk upsert payload contained %s duplicate document_ids; keeping last occurrence.",
duplicate_count,
)
validated_reports = list(deduped_reports.values())

def _dedupe_by_key(
items: list[dict[str, Any]], key_name: str
) -> tuple[list[dict[str, Any]], int]:
if not items:
return [], 0
by_key: dict[str, dict[str, Any]] = {}
for item in items:
key = item[key_name]
by_key[key] = item
return list(by_key.values()), len(items) - len(by_key)

def _dedupe_metadata(items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int]:
if not items:
return [], 0
by_key: dict[str, dict[str, Any]] = {}
duplicates = 0
for item in items:
key = item["key"]
if key in by_key:
duplicates += 1
by_key[key] = item
return list(by_key.values()), duplicates

def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]:
if not items:
return [], 0
by_id: dict[int, int] = {}
for group in items:
group_id = int(getattr(group, "pk", group))
by_id[group_id] = group_id
return list(by_id.values()), len(items) - len(by_id)

document_ids = [report["document_id"] for report in validated_reports]

language_codes = {report["language"]["code"] for report in validated_reports}
@@ -135,9 +188,12 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
Metadata.objects.filter(report_id__in=report_ids).delete()

metadata_rows: list[Metadata] = []
metadata_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for item in report_data.get("metadata", []):
metadata_items, duplicates = _dedupe_metadata(report_data.get("metadata", []))
metadata_duplicate_count += duplicates
for item in metadata_items:
metadata_rows.append(
Metadata(report_id=report_id, key=item["key"], value=item["value"])
)
@@ -148,9 +204,14 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
modality_through.objects.filter(report_id__in=report_ids).delete()

modality_rows = []
modality_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for modality in report_data.get("modalities", []):
modality_items, duplicates = _dedupe_by_key(
report_data.get("modalities", []), "code"
)
modality_duplicate_count += duplicates
for modality in modality_items:
modality_id = modality_by_code[modality["code"]].pk
modality_rows.append(
modality_through(report_id=report_id, modality_id=modality_id)
@@ -162,13 +223,31 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
group_through.objects.filter(report_id__in=report_ids).delete()

group_rows = []
group_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for group in report_data.get("groups", []):
group_rows.append(group_through(report_id=report_id, group_id=group.pk))
group_items, duplicates = _dedupe_groups(report_data.get("groups", []))
group_duplicate_count += duplicates
for group_id in group_items:
group_rows.append(group_through(report_id=report_id, group_id=group_id))
if group_rows:
group_through.objects.bulk_create(group_rows, batch_size=BULK_DB_BATCH_SIZE)

if metadata_duplicate_count or modality_duplicate_count or group_duplicate_count:
logger.warning(
"Bulk upsert payload contained duplicate metadata/modality/group entries "
"(metadata=%s modalities=%s groups=%s); duplicates were dropped.",
metadata_duplicate_count,
modality_duplicate_count,
group_duplicate_count,
)

touched_report_ids = [
report_id_by_document_id[document_id]
for document_id in [*created_ids, *updated_ids]
if document_id in report_id_by_document_id
]

def on_commit():
if created_ids:
created_reports = list(Report.objects.filter(document_id__in=created_ids))
@@ -178,6 +257,11 @@ def on_commit():
updated_reports = list(Report.objects.filter(document_id__in=updated_ids))
for handler in reports_updated_handlers:
handler.handle(updated_reports)
if touched_report_ids:
if settings.PGSEARCH_SYNC_INDEXING:
bulk_upsert_report_search_vectors(touched_report_ids)
else:
enqueue_bulk_index_reports(touched_report_ids)

transaction.on_commit(on_commit)
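
The indexing path is driven by the three Django settings referenced in this PR. A sketch of a settings module; the values are illustrative, not project defaults:

# settings.py
PGSEARCH_SYNC_INDEXING = False  # False: defer indexing to the Procrastinate worker
PGSEARCH_BULK_INDEX_CHUNK_SIZE = 500  # reports fetched and grouped per chunk
PGSEARCH_BULK_INSERT_BATCH_SIZE = 1000  # batch size for bulk_create of vector rows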

@@ -268,6 +352,13 @@ def bulk_upsert(self, request: Request) -> Response:
status=status.HTTP_400_BAD_REQUEST,
)

replace = request.GET.get("replace", "true").lower() in ["true", "1", "yes"]
if not replace:
return Response(
{"detail": "replace=false is not supported for bulk upsert. Use replace=true."},
status=status.HTTP_400_BAD_REQUEST,
)

valid_payloads: list[dict[str, Any]] = []
errors: list[dict[str, Any]] = []
for index, payload in enumerate(request.data):
@@ -305,7 +396,7 @@ def bulk_upsert(self, request: Request) -> Response:
created_ids: list[str] = []
updated_ids: list[str] = []
if valid_payloads:
created_ids, updated_ids = _bulk_upsert_reports(valid_payloads)
created_ids, updated_ids = _bulk_upsert_reports(valid_payloads, replace=replace)

response_body: dict[str, Any] = {
"created": len(created_ids),
1 change: 1 addition & 0 deletions radis/reports/tests/__init__.py
@@ -0,0 +1 @@
