diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml
index f41c15fc..94946a8a 100644
--- a/docker-compose.prod.yml
+++ b/docker-compose.prod.yml
@@ -1,3 +1,10 @@
+x-app-env: &default-app-env
+  DJANGO_EMAIL_URL: ${DJANGO_EMAIL_URL:?}
+  DJANGO_SECURE_SSL_REDIRECT: ${DJANGO_SECURE_SSL_REDIRECT:-true}
+  DJANGO_SETTINGS_MODULE: radis.settings.production
+  DJANGO_STATIC_ROOT: /var/www/web/static/
+  POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?}
+
 x-app: &default-app
   image: ghcr.io/openradx/radis:latest
   volumes:
@@ -5,11 +12,7 @@ x-app: &default-app
     - ${SSL_SERVER_CERT_FILE:?}:/etc/web/ssl/cert.pem
     - ${SSL_SERVER_KEY_FILE:?}:/etc/web/ssl/key.pem
   environment:
-    DJANGO_EMAIL_URL: ${DJANGO_EMAIL_URL:?}
-    DJANGO_SECURE_SSL_REDIRECT: ${DJANGO_SECURE_SSL_REDIRECT:-true}
-    DJANGO_SETTINGS_MODULE: radis.settings.production
-    DJANGO_STATIC_ROOT: /var/www/web/static/
-    POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?}
+    <<: *default-app-env

 x-deploy: &deploy
   replicas: 1
@@ -69,6 +72,15 @@ services:

   llm_worker:
     <<: *default-app
+    volumes:
+      - web_data:/var/www/web
+      - ${SSL_SERVER_CERT_FILE:?}:/etc/web/ssl/cert.pem
+      - ${SSL_SERVER_KEY_FILE:?}:/etc/web/ssl/key.pem
+      - ${RADIS_LLM_CA_BUNDLE:-/etc/ssl/certs/ca-certificates.crt}:/etc/ssl/certs/radis-ca-bundle.pem:ro
+    environment:
+      <<: *default-app-env
+      SSL_CERT_FILE: /etc/ssl/certs/radis-ca-bundle.pem
+      REQUESTS_CA_BUNDLE: /etc/ssl/certs/radis-ca-bundle.pem
     command: >
       bash -c "
         wait-for-it -s postgres.local:5432 -t 60 &&
diff --git a/example.env b/example.env
index 687dc324..fa454540 100644
--- a/example.env
+++ b/example.env
@@ -73,6 +73,9 @@ SSL_IP_ADDRESSES=127.0.0.1
 SSL_SERVER_CERT_FILE="./cert.pem"
 SSL_SERVER_KEY_FILE="./key.pem"
 SSL_SERVER_CHAIN_FILE="./chain.pem"
+# Optional: custom CA bundle for outbound HTTPS (e.g., private LLM endpoints).
+# Defaults to the host system CA bundle if not set.
+# RADIS_LLM_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"

 # The timezone used by the server.
 TIME_ZONE="Europe/Berlin"
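Note on the CA bundle wiring: `SSL_CERT_FILE` is honored by Python's `ssl` module and `REQUESTS_CA_BUNDLE` by the `requests` library, so both code paths inside the `llm_worker` container validate outbound TLS against the mounted bundle. A minimal sanity check that could be run inside the container (the endpoint URL is a placeholder):

```python
import os
import ssl

import requests

# Both variables are set in docker-compose.prod.yml for the llm_worker service.
assert os.environ["SSL_CERT_FILE"] == "/etc/ssl/certs/radis-ca-bundle.pem"
assert os.environ["REQUESTS_CA_BUNDLE"] == "/etc/ssl/certs/radis-ca-bundle.pem"

# ssl.create_default_context() picks up SSL_CERT_FILE via the default verify paths.
ssl.create_default_context()
print(ssl.get_default_verify_paths())

# requests reads REQUESTS_CA_BUNDLE; this raises SSLError if the private CA
# that signed the endpoint certificate is missing from the bundle.
requests.get("https://llm.internal.example/v1/models", timeout=10)  # placeholder URL
```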
diff --git a/radis-client/radis_client/client.py b/radis-client/radis_client/client.py
index 05df9a55..cf5ec7f6 100644
--- a/radis-client/radis_client/client.py
+++ b/radis-client/radis_client/client.py
@@ -102,13 +102,17 @@ def update_report(
         return response.json()

     def update_reports_bulk(
-        self, reports: list[ReportData], upsert: bool = True
+        self,
+        reports: list[ReportData],
+        upsert: bool = True,
+        timeout: float | tuple[float, float] | None = None,
     ) -> dict[str, Any]:
         """Bulk upsert reports using a single request.

         Args:
             reports: The report payloads to upsert.
             upsert: Whether to perform upsert behavior when a report is missing.
+            timeout: Optional request timeout in seconds; a float or a (connect, read) tuple.

         Returns:
             The response as JSON.
@@ -119,6 +123,7 @@ def update_reports_bulk(
             json=payload,
             headers=self._headers,
             params={"upsert": upsert},
+            timeout=timeout,
         )
         response.raise_for_status()
         return response.json()
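Usage of the new parameter follows `requests` semantics: a single float bounds both the connect and read phases, a tuple bounds them separately, and `None` (the default) preserves the old wait-indefinitely behavior. A sketch, assuming a client class name and constructor that are not shown in this diff:

```python
from radis_client.client import RadisClient  # assumed class name

client = RadisClient(server_url="https://radis.example.org", auth_token="<token>")

client.update_reports_bulk(reports, timeout=300.0)  # one bound for connect + read
client.update_reports_bulk(reports, upsert=True, timeout=(3.05, 600.0))  # fail fast on connect
client.update_reports_bulk(reports)  # timeout=None: previous behavior, wait indefinitely
```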
diff --git a/radis/pgsearch/tasks.py b/radis/pgsearch/tasks.py
new file mode 100644
index 00000000..2645a28d
--- /dev/null
+++ b/radis/pgsearch/tasks.py
@@ -0,0 +1,30 @@
+import logging
+
+from procrastinate.contrib.django import app
+from procrastinate.types import JSONValue
+
+from .utils.indexing import bulk_upsert_report_search_vectors
+
+logger = logging.getLogger(__name__)
+
+
+@app.task
+def bulk_index_reports(report_ids: list[int]) -> None:
+    if not report_ids:
+        return
+    logger.info("Indexing %s reports in bulk.", len(report_ids))
+    bulk_upsert_report_search_vectors(report_ids)
+
+
+def enqueue_bulk_index_reports(report_ids: list[int]) -> int | None:
+    if not report_ids:
+        return None
+    try:
+        payload: list[JSONValue] = [int(report_id) for report_id in report_ids]
+    except (TypeError, ValueError) as exc:
+        logger.error("Invalid report_id in bulk index request: %s", exc)
+        return None
+    return app.configure_task(
+        "radis.pgsearch.tasks.bulk_index_reports",
+        allow_unknown=False,
+    ).defer(report_ids=payload)
diff --git a/radis/pgsearch/tests/test_indexing.py b/radis/pgsearch/tests/test_indexing.py
new file mode 100644
index 00000000..344018f5
--- /dev/null
+++ b/radis/pgsearch/tests/test_indexing.py
@@ -0,0 +1,33 @@
+import pytest
+
+from radis.pgsearch.models import ReportSearchVector
+from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors
+from radis.reports.models import Language, Report
+
+
+@pytest.mark.django_db
+def test_bulk_index_matches_signal_vector() -> None:
+    language = Language.objects.create(code="en")
+    report = Report.objects.create(
+        document_id="DOC-INDEX",
+        pacs_aet="PACS",
+        pacs_name="PACS",
+        pacs_link="",
+        patient_id="P1",
+        patient_birth_date="1980-01-01",
+        patient_sex="M",
+        study_description="Study",
+        study_datetime="2024-01-01T00:00:00Z",
+        study_instance_uid="1.2.3.4",
+        accession_number="ACC1",
+        body="Findings: No acute abnormality.",
+        language=language,
+    )
+
+    signal_vector = ReportSearchVector.objects.get(report=report).search_vector
+    ReportSearchVector.objects.filter(report=report).delete()
+
+    bulk_upsert_report_search_vectors([report.pk])
+    bulk_vector = ReportSearchVector.objects.get(report=report).search_vector
+
+    assert signal_vector == bulk_vector
diff --git a/radis/pgsearch/utils/indexing.py b/radis/pgsearch/utils/indexing.py
new file mode 100644
index 00000000..882ba4b3
--- /dev/null
+++ b/radis/pgsearch/utils/indexing.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+
+from django.conf import settings
+from django.db import connection
+
+from radis.reports.models import Report
+
+from ..models import ReportSearchVector
+from .language_utils import code_to_language
+
+logger = logging.getLogger(__name__)
+
+
+def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
+    for index in range(0, len(items), size):
+        yield items[index : index + size]
+
+
+def bulk_upsert_report_search_vectors(
+    report_ids: Iterable[int],
+    chunk_size: int | None = None,
+) -> None:
+    ids = sorted({int(report_id) for report_id in report_ids if report_id is not None})
+    if not ids:
+        return
+    resolved_chunk_size = (
+        settings.PGSEARCH_BULK_INDEX_CHUNK_SIZE if chunk_size is None else chunk_size
+    )
+
+    for chunk in _chunked(ids, resolved_chunk_size):
+        reports = (
+            Report.objects.filter(id__in=chunk)
+            .select_related("language")
+            .only("id", "language__code")
+        )
+        report_ids_found: set[int] = set()
+        config_to_ids: dict[str, list[int]] = {}
+        config_cache: dict[str, str] = {}
+        for report in reports:
+            report_ids_found.add(report.pk)
+            language_code = report.language.code
+            config = config_cache.get(language_code)
+            if config is None:
+                config = code_to_language(language_code)
+                config_cache[language_code] = config
+            config_to_ids.setdefault(config, []).append(report.pk)
+        missing_ids = set(chunk) - report_ids_found
+        if missing_ids:
+            logger.warning(
+                "Skipping %s missing reports during bulk index (ids=%s).",
+                len(missing_ids),
+                sorted(missing_ids)[:10],
+            )
+
+        for config, config_ids in config_to_ids.items():
+            ReportSearchVector.objects.bulk_create(
+                [ReportSearchVector(report_id=report_id) for report_id in config_ids],
+                ignore_conflicts=True,
+                batch_size=settings.PGSEARCH_BULK_INSERT_BATCH_SIZE,
+            )
+
+            with connection.cursor() as cursor:
+                cursor.execute(
+                    """
+                    UPDATE pgsearch_reportsearchvector v
+                    SET search_vector = to_tsvector(%s::regconfig, r.body)
+                    FROM reports_report r
+                    WHERE v.report_id = r.id AND r.id = ANY(%s)
+                    """,
+                    [config, config_ids],
+                )
diff --git a/radis/pgsearch/utils/language_utils.py b/radis/pgsearch/utils/language_utils.py
index 73a97685..4fd27f7b 100644
--- a/radis/pgsearch/utils/language_utils.py
+++ b/radis/pgsearch/utils/language_utils.py
@@ -20,7 +20,11 @@ def get_available_search_configs() -> set[str]:
     try:
         return _get_available_search_configs_cached()
     except DatabaseError as exc:
-        logger.warning("Failed to read pg_ts_config; falling back to simple. %s", exc)
+        logger.error(
+            "Failed to read pg_ts_config; falling back to simple. %s",
+            exc,
+            exc_info=True,
+        )
         return set()
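The id normalization in `bulk_upsert_report_search_vectors` makes the helper safe to call with messy input: duplicates collapse, `None` entries are dropped, and ordering is deterministic. For example (chunk size overridden for illustration):

```python
from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors

# Input is deduplicated, None-filtered, and sorted to [3, 7, 12], then
# processed in chunks [3, 7] and [12]; each chunk is grouped by language
# config before the per-config UPDATE runs.
bulk_upsert_report_search_vectors([12, 7, 7, None, 3], chunk_size=2)
```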
diff --git a/radis/reports/api/serializers.py b/radis/reports/api/serializers.py
index 1a27afb2..6d3f03f6 100644
--- a/radis/reports/api/serializers.py
+++ b/radis/reports/api/serializers.py
@@ -3,6 +3,7 @@
 from django.db import transaction
 from rest_framework import serializers, validators
 from rest_framework.exceptions import ValidationError
+from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField

 from ..models import Language, Metadata, Modality, Report

@@ -50,6 +51,20 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         if self.context.get("skip_document_id_unique"):
             self._strip_unique_validator("document_id")
+        request = self.context.get("request")
+        if request is not None and "groups" in self.fields:
+            groups_field = self.fields["groups"]
+            # many=True wraps the relation in a ManyRelatedField; the queryset
+            # lives on its child_relation.
+            if isinstance(groups_field, ManyRelatedField):
+                groups_field = groups_field.child_relation
+            if (
+                isinstance(groups_field, PrimaryKeyRelatedField)
+                and groups_field.queryset is not None
+                and not request.user.is_superuser
+            ):
+                # Non-superusers may only attach reports to groups they belong to.
+                groups_field.queryset = request.user.groups.all()

     class Meta:
         model = Report
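The intended effect of the queryset swap is that non-superusers can only attach reports to groups they are members of, while superusers keep the unrestricted queryset. A sketch of a test for this, following the factory conventions used in the tests below (the serializer class name is assumed; it is not visible in this hunk):

```python
import pytest
from adit_radis_shared.accounts.factories import GroupFactory, UserFactory
from django.test import RequestFactory

from radis.reports.api.serializers import ReportSerializer  # assumed class name


@pytest.mark.django_db
def test_groups_field_is_scoped_to_request_user() -> None:
    user = UserFactory.create(is_active=True)
    own_group = GroupFactory.create()
    GroupFactory.create()  # a group the user is NOT a member of
    user.groups.add(own_group)

    request = RequestFactory().get("/")
    request.user = user

    serializer = ReportSerializer(context={"request": request})
    # many=True wraps the PK field in a ManyRelatedField.
    groups_field = serializer.fields["groups"].child_relation
    # Only the user's own groups remain selectable.
    assert list(groups_field.queryset) == [own_group]
```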
diff --git a/radis/reports/api/viewsets.py b/radis/reports/api/viewsets.py
index e5323945..bb684b15 100644
--- a/radis/reports/api/viewsets.py
+++ b/radis/reports/api/viewsets.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Any

+from django.conf import settings
 from django.db import transaction
 from django.http import Http404
 from django.utils import timezone
@@ -12,6 +13,9 @@
 from rest_framework.response import Response
 from rest_framework.serializers import BaseSerializer

+from radis.pgsearch.tasks import enqueue_bulk_index_reports
+from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors
+
 from ..models import Language, Metadata, Modality, Report
 from ..site import (
     document_fetchers,
@@ -26,10 +30,58 @@
 BULK_DB_BATCH_SIZE = 1000


-def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[str], list[str]]:
+def _bulk_upsert_reports(
+    validated_reports: list[dict[str, Any]],
+) -> tuple[list[str], list[str]]:
     if not validated_reports:
         return [], []

+    deduped_reports: dict[str, dict[str, Any]] = {}
+    duplicate_count = 0
+    for report in validated_reports:
+        document_id = report["document_id"]
+        if document_id in deduped_reports:
+            duplicate_count += 1
+        deduped_reports[document_id] = report
+    if duplicate_count:
+        logger.warning(
+            "Bulk upsert payload contained %s duplicate document_ids; keeping last occurrence.",
+            duplicate_count,
+        )
+    validated_reports = list(deduped_reports.values())
+
+    def _dedupe_by_key(
+        items: list[dict[str, Any]], key_name: str
+    ) -> tuple[list[dict[str, Any]], int]:
+        if not items:
+            return [], 0
+        by_key: dict[str, dict[str, Any]] = {}
+        for item in items:
+            key = item[key_name]
+            by_key[key] = item
+        return list(by_key.values()), len(items) - len(by_key)
+
+    def _dedupe_metadata(items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int]:
+        if not items:
+            return [], 0
+        by_key: dict[str, dict[str, Any]] = {}
+        duplicates = 0
+        for item in items:
+            key = item["key"]
+            if key in by_key:
+                duplicates += 1
+            by_key[key] = item
+        return list(by_key.values()), duplicates
+
+    def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]:
+        if not items:
+            return [], 0
+        by_id: dict[int, int] = {}
+        for group in items:
+            group_id = int(getattr(group, "pk", group))
+            by_id[group_id] = group_id
+        return list(by_id.values()), len(items) - len(by_id)
+
     document_ids = [report["document_id"] for report in validated_reports]
     language_codes = {report["language"]["code"] for report in validated_reports}
@@ -135,9 +187,12 @@ def _bulk_upsert_reports(
         Metadata.objects.filter(report_id__in=report_ids).delete()

     metadata_rows: list[Metadata] = []
+    metadata_duplicate_count = 0
     for report_data in validated_reports:
         report_id = report_id_by_document_id[report_data["document_id"]]
-        for item in report_data.get("metadata", []):
+        metadata_items, duplicates = _dedupe_metadata(report_data.get("metadata", []))
+        metadata_duplicate_count += duplicates
+        for item in metadata_items:
             metadata_rows.append(
                 Metadata(report_id=report_id, key=item["key"], value=item["value"])
             )
@@ -148,9 +203,14 @@ def _bulk_upsert_reports(
         modality_through.objects.filter(report_id__in=report_ids).delete()

     modality_rows = []
+    modality_duplicate_count = 0
     for report_data in validated_reports:
         report_id = report_id_by_document_id[report_data["document_id"]]
-        for modality in report_data.get("modalities", []):
+        modality_items, duplicates = _dedupe_by_key(
+            report_data.get("modalities", []), "code"
+        )
+        modality_duplicate_count += duplicates
+        for modality in modality_items:
             modality_id = modality_by_code[modality["code"]].pk
             modality_rows.append(
                 modality_through(report_id=report_id, modality_id=modality_id)
             )
@@ -162,13 +222,31 @@ def _bulk_upsert_reports(
         group_through.objects.filter(report_id__in=report_ids).delete()

     group_rows = []
+    group_duplicate_count = 0
     for report_data in validated_reports:
         report_id = report_id_by_document_id[report_data["document_id"]]
-        for group in report_data.get("groups", []):
-            group_rows.append(group_through(report_id=report_id, group_id=group.pk))
+        group_items, duplicates = _dedupe_groups(report_data.get("groups", []))
+        group_duplicate_count += duplicates
+        for group_id in group_items:
+            group_rows.append(group_through(report_id=report_id, group_id=group_id))
     if group_rows:
         group_through.objects.bulk_create(group_rows, batch_size=BULK_DB_BATCH_SIZE)

+    if metadata_duplicate_count or modality_duplicate_count or group_duplicate_count:
+        logger.warning(
+            "Bulk upsert payload contained duplicate metadata/modality/group entries "
+            "(metadata=%s modalities=%s groups=%s); duplicates were dropped.",
+            metadata_duplicate_count,
+            modality_duplicate_count,
+            group_duplicate_count,
+        )
+
+    touched_report_ids = [
+        report_id_by_document_id[document_id]
+        for document_id in [*created_ids, *updated_ids]
+        if document_id in report_id_by_document_id
+    ]
+
     def on_commit():
         if created_ids:
             created_reports = list(Report.objects.filter(document_id__in=created_ids))
@@ -178,6 +256,11 @@ def on_commit():
             updated_reports = list(Report.objects.filter(document_id__in=updated_ids))
             for handler in reports_updated_handlers:
                 handler.handle(updated_reports)
+        if touched_report_ids:
+            if settings.PGSEARCH_SYNC_INDEXING:
+                bulk_upsert_report_search_vectors(touched_report_ids)
+            else:
+                enqueue_bulk_index_reports(touched_report_ids)

     transaction.on_commit(on_commit)

@@ -268,6 +351,13 @@ def bulk_upsert(self, request: Request) -> Response:
             status=status.HTTP_400_BAD_REQUEST,
         )

+        replace = request.GET.get("replace", "true").lower() in ["true", "1", "yes"]
+        if not replace:
+            return Response(
+                {"detail": "replace=false is not supported for bulk upsert. Use replace=true."},
+                status=status.HTTP_400_BAD_REQUEST,
+            )
+
         valid_payloads: list[dict[str, Any]] = []
         errors: list[dict[str, Any]] = []
         for index, payload in enumerate(request.data):
diff --git a/radis/reports/tests/__init__.py b/radis/reports/tests/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/radis/reports/tests/__init__.py
@@ -0,0 +1 @@
+
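On the wire, the new guard means callers must either omit `replace` or pass a truthy value. A minimal client-side call (host and token are placeholders; the endpoint path and response shape match the tests below):

```python
import requests

response = requests.post(
    "https://radis.example.org/api/reports/bulk-upsert/",
    json=reports_payload,  # a list of report dicts, as in the tests below
    headers={"Authorization": "Token <token>"},
    params={"replace": "true"},  # optional; "false" now yields HTTP 400
    timeout=(3.05, 600.0),
)
response.raise_for_status()
print(response.json())  # {"created": ..., "updated": ..., "invalid": ...}
```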
"patient_sex": "F", + "study_description": "Study 2", + "study_datetime": "2024-01-02T00:00:00Z", + "study_instance_uid": "2.3.4.5", + "accession_number": "ACC2", + "modalities": ["MR"], + "metadata": {"ris_filename": "file2"}, + "body": "Report body 2", + }, + ] + + response = client.post( + "/api/reports/bulk-upsert/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 200 + assert response.json() == {"created": 2, "updated": 0, "invalid": 0} + + assert Report.objects.count() == 2 + assert Language.objects.filter(code="en").exists() + assert Language.objects.filter(code="de").exists() + assert Modality.objects.filter(code="CT").exists() + assert Modality.objects.filter(code="MR").exists() + + payload[0]["body"] = "Updated body" + payload[0]["metadata"] = {"ris_filename": "file1", "extra": "value"} + + response = client.post( + "/api/reports/bulk-upsert/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 200 + assert response.json() == {"created": 0, "updated": 2, "invalid": 0} + + report = Report.objects.get(document_id="DOC-1") + assert report.body == "Updated body" + assert Metadata.objects.filter(report=report).count() == 2 + + +@pytest.mark.django_db +def test_bulk_upsert_dedupes_payload_entries(client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + group = GroupFactory.create() + user.groups.add(group) + _, token = Token.objects.create_token(user, "bulk upsert dedupe test", None) + + payload = [ + { + "document_id": "DOC-1", + "language": "en", + "groups": [group.pk, group.pk], + "pacs_aet": "PACS", + "pacs_name": "Test PACS", + "pacs_link": "", + "patient_id": "P1", + "patient_birth_date": "1980-01-01", + "patient_sex": "M", + "study_description": "Study 1", + "study_datetime": "2024-01-01T00:00:00Z", + "study_instance_uid": "1.2.3.4", + "accession_number": "ACC1", + "modalities": ["CT", "CT"], + "metadata": {"ris_filename": "file1", "extra": "value"}, + "body": "First version", + }, + { + "document_id": "DOC-1", + "language": "en", + "groups": [group.pk], + "pacs_aet": "PACS", + "pacs_name": "Test PACS", + "pacs_link": "", + "patient_id": "P1", + "patient_birth_date": "1980-01-01", + "patient_sex": "M", + "study_description": "Study 1", + "study_datetime": "2024-01-01T00:00:00Z", + "study_instance_uid": "1.2.3.4", + "accession_number": "ACC1", + "modalities": ["CT"], + "metadata": {"ris_filename": "file2", "extra": "value"}, + "body": "Second version", + }, + ] + + response = client.post( + "/api/reports/bulk-upsert/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 200 + assert response.json() == {"created": 1, "updated": 0, "invalid": 0} + + report = Report.objects.get(document_id="DOC-1") + assert report.body == "Second version" + assert report.modalities.count() == 1 + assert report.groups.count() == 1 + assert Metadata.objects.filter(report=report).count() == 2 + + +@pytest.mark.django_db +def test_bulk_upsert_dedupes_metadata_keys(): + group = GroupFactory.create() + + validated_reports = [ + { + "document_id": "DOC-1", + "language": {"code": "en"}, + "groups": [group], + "pacs_aet": "PACS", + "pacs_name": "Test PACS", + "pacs_link": "", + "patient_id": "P1", + "patient_birth_date": date(1980, 1, 1), + "patient_sex": "M", + "study_description": "Study 1", + 
"study_datetime": "2024-01-01T00:00:00Z", + "study_instance_uid": "1.2.3.4", + "accession_number": "ACC1", + "modalities": [{"code": "CT"}], + "metadata": [ + {"key": "ris_filename", "value": "file1"}, + {"key": "ris_filename", "value": "file2"}, + ], + "body": "Report body 1", + }, + ] + + created_ids, updated_ids = _bulk_upsert_reports(validated_reports) + assert created_ids == ["DOC-1"] + assert updated_ids == [] + + report = Report.objects.get(document_id="DOC-1") + metadata = Metadata.objects.get(report=report, key="ris_filename") + assert metadata.value == "file2" diff --git a/radis/settings/base.py b/radis/settings/base.py index 1f9b1c6b..e838bcf0 100644 --- a/radis/settings/base.py +++ b/radis/settings/base.py @@ -154,6 +154,11 @@ database_url = f"postgres://postgres:postgres@localhost:{postgres_dev_port}/postgres" DATABASES = {"default": env.dj_db_url("DATABASE_URL", default=database_url)} +# pgsearch indexing tuning (bulk upsert/backfill) +PGSEARCH_BULK_INDEX_CHUNK_SIZE = env.int("PGSEARCH_BULK_INDEX_CHUNK_SIZE", default=5000) +PGSEARCH_BULK_INSERT_BATCH_SIZE = env.int("PGSEARCH_BULK_INSERT_BATCH_SIZE", default=1000) +PGSEARCH_SYNC_INDEXING = env.bool("PGSEARCH_SYNC_INDEXING", default=False) + # Default primary key field type # https://docs.djangoproject.com/en/5.0/ref/settings/#default-auto-field DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"