diff --git a/example.env b/example.env index 687dc324..7eb81355 100644 --- a/example.env +++ b/example.env @@ -109,6 +109,10 @@ REPORT_LLM_PROVIDER_URL="http://host.docker.internal:11434/v1" # 'cli generate-example-reports'. REPORT_LLM_PROVIDER_API_KEY="ollama" +# Labels +# Automatically enqueue a backfill when a new label question is created. +LABELS_AUTO_BACKFILL_ON_NEW_QUESTION=true + # Docker swarm mode does not respect the Docker Proxy client configuration # (see https://docs.docker.com/network/proxy/#configure-the-docker-client), # but we can set those environment variables manually. diff --git a/radis/core/static/core/core.js b/radis/core/static/core/core.js index cfff69a0..034a425b 100644 --- a/radis/core/static/core/core.js +++ b/radis/core/static/core/core.js @@ -70,10 +70,18 @@ function FormSet(rootEl) { console.log(this.formCount); }, addForm() { + if (!template || !container || !totalForms) { + return; + } const newForm = template.content.cloneNode(true); const idx = totalForms.value; container.append(newForm); - const lastForm = container.querySelector(".formset-form:last-child"); + const lastForm = + container.querySelector(".formset-form:last-child") ?? + container.querySelector("c-formset-form:last-child"); + if (!lastForm) { + return; + } lastForm.innerHTML = lastForm.innerHTML.replace(/__prefix__/g, idx); totalForms.value = (parseInt(idx) + 1).toString(); this.formCount = parseInt(totalForms.value); @@ -82,7 +90,12 @@ function FormSet(rootEl) { * @param {HTMLElement} btnEl - The delete button element that was clicked */ removeForm(btnEl) { - btnEl.closest(".formset-form").remove(); + const formEl = + btnEl.closest(".formset-form") ?? 
btnEl.closest("c-formset-form"); + if (!formEl) { + return; + } + formEl.remove(); const idx = totalForms.value; totalForms.value = (parseInt(idx) - 1).toString(); this.formCount = parseInt(totalForms.value); diff --git a/radis/core/templates/cotton/formset.html b/radis/core/templates/cotton/formset.html index 20c6af7b..d4edff6b 100644 --- a/radis/core/templates/cotton/formset.html +++ b/radis/core/templates/cotton/formset.html @@ -7,7 +7,7 @@ {% crispy formset.empty_form %}
- {% for form in formset %}{{ form|crispy }}{% endfor %} + {% for form in formset %}{% crispy form %}{% endfor %}
{% if add_form_label %}
diff --git a/radis/labels/__init__.py b/radis/labels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/admin.py b/radis/labels/admin.py new file mode 100644 index 00000000..34429366 --- /dev/null +++ b/radis/labels/admin.py @@ -0,0 +1,47 @@ +from django.contrib import admin + +from .models import LabelBackfillJob, LabelChoice, LabelGroup, LabelQuestion, ReportLabel + + +class LabelChoiceInline(admin.TabularInline): + model = LabelChoice + extra = 0 + + +@admin.register(LabelGroup) +class LabelGroupAdmin(admin.ModelAdmin): + list_display = ("name", "is_active", "order") + list_filter = ("is_active",) + search_fields = ("name",) + ordering = ("order", "name") + + +@admin.register(LabelQuestion) +class LabelQuestionAdmin(admin.ModelAdmin): + list_display = ("label", "question", "group", "is_active", "order") + list_filter = ("group", "is_active") + search_fields = ("label", "question") + ordering = ("group__order", "order", "label") + inlines = (LabelChoiceInline,) + + +@admin.register(ReportLabel) +class ReportLabelAdmin(admin.ModelAdmin): + list_display = ("report", "question", "choice", "confidence", "verified", "created_at") + list_filter = ("verified", "question__group") + search_fields = ("report__document_id", "question__label", "choice__label") + ordering = ("-created_at",) + + +@admin.register(LabelBackfillJob) +class LabelBackfillJobAdmin(admin.ModelAdmin): + list_display = ( + "id", + "label_group", + "status", + "processed_reports", + "total_reports", + "created_at", + ) + list_filter = ("status",) + ordering = ("-created_at",) diff --git a/radis/labels/apps.py b/radis/labels/apps.py new file mode 100644 index 00000000..7b84a671 --- /dev/null +++ b/radis/labels/apps.py @@ -0,0 +1,42 @@ +from django.apps import AppConfig + + +class LabelsConfig(AppConfig): + name = "radis.labels" + + def ready(self) -> None: + register_app() + + from radis.reports.site import ( + ReportsCreatedHandler, + ReportsUpdatedHandler, + 
register_reports_created_handler, + register_reports_updated_handler, + ) + + from . import signals # noqa: F401 + from .site import handle_reports_created, handle_reports_updated + + register_reports_created_handler( + ReportsCreatedHandler( + name="Labels", + handle=handle_reports_created, + ) + ) + register_reports_updated_handler( + ReportsUpdatedHandler( + name="Labels", + handle=handle_reports_updated, + ) + ) + + +def register_app() -> None: + from adit_radis_shared.common.site import MainMenuItem, register_main_menu_item + + register_main_menu_item( + MainMenuItem( + url_name="label_group_list", + label="Auto Labels", + ) + ) diff --git a/radis/labels/constants.py b/radis/labels/constants.py new file mode 100644 index 00000000..ad213cc8 --- /dev/null +++ b/radis/labels/constants.py @@ -0,0 +1,5 @@ +DEFAULT_LABEL_CHOICES = [ + {"value": "yes", "label": "Yes", "is_unknown": False, "order": 1}, + {"value": "no", "label": "No", "is_unknown": False, "order": 2}, + {"value": "cannot_decide", "label": "Cannot decide", "is_unknown": True, "order": 3}, +] diff --git a/radis/labels/forms.py b/radis/labels/forms.py new file mode 100644 index 00000000..915a9138 --- /dev/null +++ b/radis/labels/forms.py @@ -0,0 +1,66 @@ +from crispy_forms.helper import FormHelper +from crispy_forms.layout import Column, Layout, Row +from django import forms + +from .models import LabelGroup, LabelQuestion + + +class LabelGroupForm(forms.ModelForm): + class Meta: + model = LabelGroup + fields = [ + "name", + "description", + "is_active", + "order", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.helper = FormHelper() + self.helper.form_tag = False + self.helper.layout = Layout( + Row( + Column("name", "description"), + Column("is_active", "order", css_class="col-3"), + ) + ) + + +class LabelQuestionForm(forms.ModelForm): + class Meta: + model = LabelQuestion + fields = [ + "label", + "question", + "is_active", + "order", + ] + + def 
__init__(self, *args, **kwargs): + self.group = kwargs.pop("group", None) + super().__init__(*args, **kwargs) + + self.helper = FormHelper() + self.helper.form_tag = False + self.helper.layout = Layout( + "label", + "question", + Row(Column("is_active"), Column("order", css_class="col-3")), + ) + + self.fields["question"].required = False + self.fields["question"].help_text = "Optional. If left empty, the label is used." + + def clean_label(self): + label = self.cleaned_data.get("label", "") + if not label or not self.group: + return label + + existing = LabelQuestion.objects.filter(group=self.group, label__iexact=label) + if self.instance and self.instance.pk: + existing = existing.exclude(pk=self.instance.pk) + if existing.exists(): + raise forms.ValidationError("A question with this label already exists in this group.") + return label diff --git a/radis/labels/management/__init__.py b/radis/labels/management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/management/commands/__init__.py b/radis/labels/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/management/commands/labels_backfill.py b/radis/labels/management/commands/labels_backfill.py new file mode 100644 index 00000000..b5a10b9b --- /dev/null +++ b/radis/labels/management/commands/labels_backfill.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from itertools import batched + +from django.core.management.base import BaseCommand, CommandError +from django.utils import timezone + +from radis.reports.models import Report + +from ...models import LabelBackfillJob, LabelGroup +from ...tasks import process_label_group + + +class Command(BaseCommand): + help = "Enqueue labeling tasks for existing reports." + + def add_arguments(self, parser): + parser.add_argument( + "--group", + dest="group", + help="Label group name or ID. 
If omitted, all active groups are used.", + ) + parser.add_argument( + "--batch-size", + dest="batch_size", + type=int, + default=None, + help="Override the task batch size.", + ) + parser.add_argument( + "--limit", + dest="limit", + type=int, + default=None, + help="Limit the number of reports to enqueue.", + ) + + def handle(self, *args, **options): + group_value = options.get("group") + batch_size = options.get("batch_size") + limit = options.get("limit") + + if group_value: + group = self._get_group(group_value) + groups = [group] + else: + groups = list(LabelGroup.objects.filter(is_active=True)) + + if not groups: + self.stdout.write(self.style.WARNING("No active label groups found.")) + return + + report_ids = Report.objects.order_by("id").values_list("id", flat=True) + if limit is not None: + report_ids = report_ids[:limit] + report_ids = list(report_ids) + + if not report_ids: + self.stdout.write(self.style.WARNING("No reports found.")) + return + + if batch_size is None: + from django.conf import settings + + batch_size = settings.LABELING_TASK_BATCH_SIZE + + for group in groups: + backfill_job = LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.IN_PROGRESS, + started_at=timezone.now(), + total_reports=len(report_ids), + ) + + for report_batch in batched(report_ids, batch_size): + process_label_group.defer( + label_group_id=group.id, + report_ids=list(report_batch), + backfill_job_id=backfill_job.id, + ) + + self.stdout.write( + self.style.SUCCESS( + f"Enqueued labeling for {len(report_ids)} reports " + f"in group '{group.name}' (backfill job #{backfill_job.id})." + ) + ) + + def _get_group(self, value: str) -> LabelGroup: + if value.isdigit(): + group = LabelGroup.objects.filter(id=int(value)).first() + else: + matches = LabelGroup.objects.filter(name=value) + if matches.count() > 1: + raise CommandError( + f"Multiple label groups named '{value}' exist. Use the numeric ID." 
+ ) + group = matches.first() + + if not group: + raise CommandError(f"Label group '{value}' not found.") + + return group diff --git a/radis/labels/management/commands/labels_seed.py b/radis/labels/management/commands/labels_seed.py new file mode 100644 index 00000000..dd6f450e --- /dev/null +++ b/radis/labels/management/commands/labels_seed.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from django.core.management.base import BaseCommand, CommandError + +from ...models import LabelChoice, LabelGroup, LabelQuestion + + +class Command(BaseCommand): + help = "Seed label groups, questions, and choices from a JSON file." + + def add_arguments(self, parser): + parser.add_argument( + "--file", + dest="file", + default="resources/labels/seed.json", + help="Path to the seed JSON file.", + ) + + def handle(self, *args, **options): + seed_path = Path(options["file"]).resolve() + if not seed_path.exists(): + raise CommandError(f"Seed file not found: {seed_path}") + + payload = json.loads(seed_path.read_text()) + groups = payload.get("groups", []) + if not groups: + self.stdout.write(self.style.WARNING("No groups found in seed file.")) + return + + for group_data in groups: + group = self._upsert_group(group_data) + for question_data in group_data.get("questions", []): + question = self._upsert_question(group, question_data) + if question_data.get("choices"): + self._upsert_choice(question, {}) + + self.stdout.write(self.style.SUCCESS("Label seed import completed.")) + + def _upsert_group(self, data: dict) -> LabelGroup: + name = data.get("name") + if not name: + raise CommandError("Label group requires 'name'.") + + groups = LabelGroup.objects.filter(name=name) + if groups.count() > 1: + raise CommandError( + f"Multiple label groups named '{name}' exist. Use unique names before seeding." 
+ ) + + defaults = { + "description": data.get("description", ""), + "is_active": data.get("is_active", True), + "order": data.get("order", 0), + } + + if groups.exists(): + group = groups.first() + if group is None: + raise CommandError( + f"Label group '{name}' could not be resolved. Try seeding again." + ) + for key, value in defaults.items(): + setattr(group, key, value) + group.save(update_fields=list(defaults.keys())) + return group + + return LabelGroup.objects.create(name=name, **defaults) + + def _upsert_question(self, group: LabelGroup, data: dict) -> LabelQuestion: + label = data.get("label") + question_text = data.get("question", "") + if not label: + raise CommandError("Label question requires 'label'.") + + question, _ = LabelQuestion.objects.update_or_create( + group=group, + label=label, + defaults={ + "question": question_text, + "is_active": data.get("is_active", True), + "order": data.get("order", 0), + }, + ) + return question + + def _upsert_choice(self, question: LabelQuestion, data: dict) -> LabelChoice: + raise CommandError( + "Custom choices are not supported. Labels use fixed Yes/No/Cannot decide choices." 
+ ) diff --git a/radis/labels/migrations/0001_initial.py b/radis/labels/migrations/0001_initial.py new file mode 100644 index 00000000..cba93480 --- /dev/null +++ b/radis/labels/migrations/0001_initial.py @@ -0,0 +1,96 @@ +# Generated by Django 6.0.1 on 2026-02-06 13:36 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('reports', '0013_alter_report_options'), + ] + + operations = [ + migrations.CreateModel( + name='LabelGroup', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('slug', models.SlugField(unique=True)), + ('description', models.TextField(blank=True, default='')), + ('is_active', models.BooleanField(default=True)), + ('order', models.PositiveSmallIntegerField(default=0)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + options={ + 'ordering': ['order', 'name'], + }, + ), + migrations.CreateModel( + name='LabelQuestion', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('question', models.CharField(max_length=300)), + ('description', models.CharField(blank=True, default='', max_length=300)), + ('is_active', models.BooleanField(default=True)), + ('order', models.PositiveSmallIntegerField(default=0)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('group', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='questions', to='labels.labelgroup')), + ], + options={ + 'ordering': ['order', 'name'], + }, + ), + migrations.CreateModel( + name='LabelChoice', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, 
verbose_name='ID')), + ('value', models.CharField(max_length=50)), + ('label', models.CharField(max_length=100)), + ('is_unknown', models.BooleanField(default=False)), + ('order', models.PositiveSmallIntegerField(default=0)), + ('question', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='choices', to='labels.labelquestion')), + ], + options={ + 'ordering': ['order', 'label'], + }, + ), + migrations.CreateModel( + name='ReportLabel', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('confidence', models.FloatField(blank=True, null=True)), + ('rationale', models.TextField(blank=True, default='')), + ('verified', models.BooleanField(default=False)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('choice', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='report_labels', to='labels.labelchoice')), + ('question', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='report_labels', to='labels.labelquestion')), + ('report', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='labels', to='reports.report')), + ], + options={ + 'ordering': ['-created_at'], + }, + ), + migrations.AddConstraint( + model_name='labelquestion', + constraint=models.UniqueConstraint(fields=('group', 'name'), name='unique_label_question_name_per_group'), + ), + migrations.AddConstraint( + model_name='labelchoice', + constraint=models.UniqueConstraint(fields=('question', 'value'), name='unique_label_choice_value_per_question'), + ), + migrations.AddIndex( + model_name='reportlabel', + index=models.Index(fields=['report', 'question'], name='labels_repo_report__1b8d1d_idx'), + ), + migrations.AddConstraint( + model_name='reportlabel', + constraint=models.UniqueConstraint(fields=('report', 'question'), name='unique_report_label_per_question'), + ), + ] diff --git 
a/radis/labels/migrations/0002_remove_labelquestion_description.py b/radis/labels/migrations/0002_remove_labelquestion_description.py new file mode 100644 index 00000000..38a1882d --- /dev/null +++ b/radis/labels/migrations/0002_remove_labelquestion_description.py @@ -0,0 +1,17 @@ +# Generated by Django 6.0.1 on 2026-02-08 19:26 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('labels', '0001_initial'), + ] + + operations = [ + migrations.RemoveField( + model_name='labelquestion', + name='description', + ), + ] diff --git a/radis/labels/migrations/0003_alter_labelquestion_options_and_more.py b/radis/labels/migrations/0003_alter_labelquestion_options_and_more.py new file mode 100644 index 00000000..6ed8db58 --- /dev/null +++ b/radis/labels/migrations/0003_alter_labelquestion_options_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 6.0.1 on 2026-02-08 19:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('labels', '0002_remove_labelquestion_description'), + ] + + operations = [ + migrations.AlterModelOptions( + name='labelquestion', + options={'ordering': ['order', 'question']}, + ), + migrations.RemoveConstraint( + model_name='labelquestion', + name='unique_label_question_name_per_group', + ), + migrations.AddConstraint( + model_name='labelquestion', + constraint=models.UniqueConstraint(fields=('group', 'question'), name='unique_label_question_per_group'), + ), + migrations.RemoveField( + model_name='labelquestion', + name='name', + ), + ] diff --git a/radis/labels/migrations/0004_add_label_to_labelquestion.py b/radis/labels/migrations/0004_add_label_to_labelquestion.py new file mode 100644 index 00000000..4a93917a --- /dev/null +++ b/radis/labels/migrations/0004_add_label_to_labelquestion.py @@ -0,0 +1,40 @@ +# Generated by Codex on 2026-02-08 + +from django.db import migrations, models + + +def populate_label_from_question(apps, schema_editor): 
+ LabelQuestion = apps.get_model("labels", "LabelQuestion") + LabelQuestion.objects.filter(label="").update(label=models.F("question")) + + +class Migration(migrations.Migration): + + dependencies = [ + ("labels", "0003_alter_labelquestion_options_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="labelquestion", + name="label", + field=models.CharField(default="", max_length=200), + preserve_default=False, + ), + migrations.RunPython(populate_label_from_question, migrations.RunPython.noop), + migrations.RemoveConstraint( + model_name="labelquestion", + name="unique_label_question_per_group", + ), + migrations.AddConstraint( + model_name="labelquestion", + constraint=models.UniqueConstraint( + fields=("group", "label"), + name="unique_label_question_label_per_group", + ), + ), + migrations.AlterModelOptions( + name="labelquestion", + options={"ordering": ["order", "label"]}, + ), + ] diff --git a/radis/labels/migrations/0005_remove_labelgroup_slug.py b/radis/labels/migrations/0005_remove_labelgroup_slug.py new file mode 100644 index 00000000..773703a0 --- /dev/null +++ b/radis/labels/migrations/0005_remove_labelgroup_slug.py @@ -0,0 +1,14 @@ +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("labels", "0004_add_label_to_labelquestion"), + ] + + operations = [ + migrations.RemoveField( + model_name="labelgroup", + name="slug", + ), + ] diff --git a/radis/labels/migrations/0006_labelbackfilljob.py b/radis/labels/migrations/0006_labelbackfilljob.py new file mode 100644 index 00000000..0c1a8e80 --- /dev/null +++ b/radis/labels/migrations/0006_labelbackfilljob.py @@ -0,0 +1,58 @@ +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("labels", "0005_remove_labelgroup_slug"), + ] + + operations = [ + migrations.CreateModel( + name="LabelBackfillJob", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + 
primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "status", + models.CharField( + choices=[ + ("PE", "Pending"), + ("IP", "In Progress"), + ("CI", "Canceling"), + ("CA", "Canceled"), + ("SU", "Success"), + ("FA", "Failure"), + ], + default="PE", + max_length=2, + ), + ), + ("total_reports", models.PositiveIntegerField(default=0)), + ("processed_reports", models.PositiveIntegerField(default=0)), + ("message", models.TextField(blank=True, default="")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("started_at", models.DateTimeField(blank=True, null=True)), + ("ended_at", models.DateTimeField(blank=True, null=True)), + ( + "label_group", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="backfill_jobs", + to="labels.labelgroup", + ), + ), + ], + options={ + "ordering": ["-created_at"], + }, + ), + ] diff --git a/radis/labels/migrations/__init__.py b/radis/labels/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/models.py b/radis/labels/models.py new file mode 100644 index 00000000..04a3027e --- /dev/null +++ b/radis/labels/models.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import Callable + +from django.db import models + +from radis.reports.models import Report + + +class LabelBackfillJob(models.Model): + class Status(models.TextChoices): + PENDING = "PE", "Pending" + IN_PROGRESS = "IP", "In Progress" + CANCELING = "CI", "Canceling" + CANCELED = "CA", "Canceled" + SUCCESS = "SU", "Success" + FAILURE = "FA", "Failure" + + id: int + label_group_id: int + label_group = models.ForeignKey( + "LabelGroup", on_delete=models.CASCADE, related_name="backfill_jobs" + ) + status = models.CharField(max_length=2, choices=Status.choices, default=Status.PENDING) + get_status_display: Callable[[], str] + total_reports = models.PositiveIntegerField(default=0) + processed_reports = models.PositiveIntegerField(default=0) + message = 
models.TextField(blank=True, default="") + created_at = models.DateTimeField(auto_now_add=True) + started_at = models.DateTimeField(null=True, blank=True) + ended_at = models.DateTimeField(null=True, blank=True) + + class Meta: + ordering = ["-created_at"] + + def __str__(self) -> str: + return f"LabelBackfillJob [{self.pk}]" + + @property + def is_cancelable(self) -> bool: + return self.status in [ + self.Status.PENDING, + self.Status.IN_PROGRESS, + ] + + @property + def is_active(self) -> bool: + return self.status in [ + self.Status.PENDING, + self.Status.IN_PROGRESS, + ] + + @property + def progress_percent(self) -> int: + if self.total_reports == 0: + return 0 + return min(int((self.processed_reports / self.total_reports) * 100), 100) + + +class LabelGroup(models.Model): + id: int + name = models.CharField(max_length=100) + description = models.TextField(blank=True, default="") + is_active = models.BooleanField(default=True) + order = models.PositiveSmallIntegerField(default=0) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + questions: models.QuerySet["LabelQuestion"] + + class Meta: + ordering = ["order", "name"] + + def __str__(self) -> str: + return f"LabelGroup {self.name} [{self.pk}]" + + +class LabelQuestion(models.Model): + id: int + group_id: int + group = models.ForeignKey[LabelGroup]( + LabelGroup, on_delete=models.CASCADE, related_name="questions" + ) + label = models.CharField(max_length=200) + question = models.CharField(max_length=300, blank=True, default="") + is_active = models.BooleanField(default=True) + order = models.PositiveSmallIntegerField(default=0) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + choices: models.QuerySet["LabelChoice"] + get_name_display: Callable[[], str] + + class Meta: + ordering = ["order", "label"] + constraints = [ + models.UniqueConstraint( + fields=["group", "label"], + 
name="unique_label_question_label_per_group", + ) + ] + + def __str__(self) -> str: + return f"LabelQuestion {self.label} [{self.pk}]" + + def save(self, *args, **kwargs) -> None: + if not self.question: + self.question = self.label + super().save(*args, **kwargs) + + +class LabelChoice(models.Model): + id: int + question = models.ForeignKey[LabelQuestion]( + LabelQuestion, on_delete=models.CASCADE, related_name="choices" + ) + value = models.CharField(max_length=50) + label = models.CharField(max_length=100) + is_unknown = models.BooleanField(default=False) + order = models.PositiveSmallIntegerField(default=0) + + get_label_display: Callable[[], str] + + class Meta: + ordering = ["order", "label"] + constraints = [ + models.UniqueConstraint( + fields=["question", "value"], + name="unique_label_choice_value_per_question", + ) + ] + + def __str__(self) -> str: + return f"LabelChoice {self.label} [{self.pk}]" + + +class ReportLabel(models.Model): + report_id: int + question_id: int + report = models.ForeignKey[Report](Report, on_delete=models.CASCADE, related_name="labels") + question = models.ForeignKey[LabelQuestion]( + LabelQuestion, on_delete=models.CASCADE, related_name="report_labels" + ) + choice = models.ForeignKey[LabelChoice]( + LabelChoice, on_delete=models.PROTECT, related_name="report_labels" + ) + confidence = models.FloatField(null=True, blank=True) + rationale = models.TextField(blank=True, default="") + verified = models.BooleanField(default=False) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["-created_at"] + constraints = [ + models.UniqueConstraint( + fields=["report", "question"], + name="unique_report_label_per_question", + ) + ] + indexes = [ + models.Index(fields=["report", "question"]), + ] + + def __str__(self) -> str: + return f"ReportLabel report={self.report_id} question={self.question_id} [{self.pk}]" diff --git a/radis/labels/processors.py 
b/radis/labels/processors.py new file mode 100644 index 00000000..b60f5e8e --- /dev/null +++ b/radis/labels/processors.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import logging +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from string import Template + +from django import db +from django.conf import settings +from django.db.models import Prefetch + +from radis.chats.utils.chat_client import ChatClient +from radis.reports.models import Report + +from .models import LabelChoice, LabelGroup, LabelQuestion, ReportLabel +from .utils.processor_utils import generate_labeling_schema, generate_questions_for_prompt + +logger = logging.getLogger(__name__) + + +class LabelGroupProcessor: + def __init__(self, group: LabelGroup) -> None: + self.group = group + self.client = ChatClient() + + def process_reports(self, report_ids: list[int], overwrite_existing: bool = False) -> None: + if not report_ids: + return + + questions = list(self.group.questions.filter(is_active=True).prefetch_related("choices")) + if not questions: + logger.info("No active label questions for group %s", self.group) + return + + choice_maps, unknown_choices = _build_choice_maps(questions) + + labels_qs = ReportLabel.objects.filter(question__group=self.group) + reports = ( + Report.objects.filter(id__in=report_ids) + .prefetch_related(Prefetch("labels", queryset=labels_qs, to_attr="labels_for_group")) + .only("id", "body") + ) + + with ThreadPoolExecutor(max_workers=settings.LABELING_LLM_CONCURRENCY_LIMIT) as executor: + futures: list[Future] = [] + future_report_ids: dict[Future, int] = {} + try: + for report in reports: + future = executor.submit( + self._process_report, + report, + questions, + choice_maps, + unknown_choices, + overwrite_existing, + ) + futures.append(future) + future_report_ids[future] = report.id + + for future in as_completed(futures): + try: + future.result() + except Exception: + report_id = future_report_ids.get(future) + 
logger.exception( + "Labeling failed for report %s in group %s", report_id, self.group + ) + finally: + db.close_old_connections() + + def _process_report( + self, + report: Report, + questions: list[LabelQuestion], + choice_maps: dict[int, dict[str, LabelChoice]], + unknown_choices: dict[int, LabelChoice | None], + overwrite_existing: bool, + ) -> None: + if overwrite_existing: + missing_questions = questions + else: + labels_for_group = getattr(report, "labels_for_group", []) + existing_question_ids = {label.question_id for label in labels_for_group} + missing_questions = [ + question for question in questions if question.id not in existing_question_ids + ] + missing_questions = [ + question for question in missing_questions if choice_maps.get(question.id) + ] + if not missing_questions: + return + + schema = generate_labeling_schema(missing_questions) + prompt = Template(settings.LABELS_SYSTEM_PROMPT).substitute( + { + "report": report.body, + "questions": generate_questions_for_prompt(missing_questions), + } + ) + + result = self.client.extract_data(prompt.strip(), schema) + + for index, question in enumerate(missing_questions): + field_name = f"question_{index}" + answer = getattr(result, field_name) + choice = _resolve_choice( + answer.choice, + choice_maps[question.id], + unknown_choices.get(question.id), + ) + confidence = _normalize_confidence(answer.confidence) + rationale = (answer.rationale or "").strip() + + ReportLabel.objects.update_or_create( + report=report, + question=question, + defaults={ + "choice": choice, + "confidence": confidence, + "rationale": rationale, + "verified": False, + }, + ) + + db.close_old_connections() + + +def _build_choice_maps( + questions: list[LabelQuestion], +) -> tuple[dict[int, dict[str, LabelChoice]], dict[int, LabelChoice | None]]: + choice_maps: dict[int, dict[str, LabelChoice]] = {} + unknown_choices: dict[int, LabelChoice | None] = {} + + for question in questions: + choices = list(question.choices.all()) + if not 
choices: + logger.warning("LabelQuestion %s has no choices, skipping.", question) + choice_maps[question.id] = {} + unknown_choices[question.id] = None + continue + choice_maps[question.id] = {choice.value: choice for choice in choices} + unknown_choice = next((choice for choice in choices if choice.is_unknown), None) + unknown_choices[question.id] = unknown_choice + + return choice_maps, unknown_choices + + +def _resolve_choice( + value: str, + choices: dict[str, LabelChoice], + unknown_choice: LabelChoice | None, +) -> LabelChoice: + choice = choices.get(value) + if choice is not None: + return choice + if unknown_choice is not None: + return unknown_choice + return next(iter(choices.values())) + + +def _normalize_confidence(confidence: float | None) -> float | None: + if confidence is None: + return None + if confidence < 0: + return 0.0 + if confidence > 1: + return 1.0 + return confidence diff --git a/radis/labels/signals.py b/radis/labels/signals.py new file mode 100644 index 00000000..e10cdbdd --- /dev/null +++ b/radis/labels/signals.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import logging + +from django.conf import settings +from django.db import transaction +from django.db.models.signals import post_save +from django.dispatch import receiver + +from .constants import DEFAULT_LABEL_CHOICES +from .models import LabelBackfillJob, LabelChoice, LabelQuestion +from .tasks import enqueue_label_group_backfill + +logger = logging.getLogger(__name__) + + +@receiver(post_save, sender=LabelQuestion) +def enqueue_backfill_for_new_question( + sender, instance: LabelQuestion, created: bool, **kwargs +) -> None: + if not created: + return + if not instance.is_active: + return + if not settings.LABELS_AUTO_BACKFILL_ON_NEW_QUESTION: + return + + # Dedup: skip if there is already an active backfill for this group + active_exists = LabelBackfillJob.objects.filter( + label_group_id=instance.group_id, + status__in=[LabelBackfillJob.Status.PENDING, 
LabelBackfillJob.Status.IN_PROGRESS], + ).exists() + + if active_exists: + logger.info( + "Skipping backfill for group %s — active backfill already exists.", + instance.group_id, + ) + return + + backfill_job = LabelBackfillJob.objects.create(label_group_id=instance.group_id) + + transaction.on_commit( + lambda: enqueue_label_group_backfill.defer( + label_group_id=instance.group_id, + backfill_job_id=backfill_job.id, + ) + ) + + +@receiver(post_save, sender=LabelQuestion) +def ensure_default_choices(sender, instance: LabelQuestion, created: bool, **kwargs) -> None: + if not created: + return + if instance.choices.exists(): + return + + choices = [ + LabelChoice( + question=instance, + value=choice["value"], + label=choice["label"], + is_unknown=choice["is_unknown"], + order=choice["order"], + ) + for choice in DEFAULT_LABEL_CHOICES + ] + LabelChoice.objects.bulk_create(choices) diff --git a/radis/labels/site.py b/radis/labels/site.py new file mode 100644 index 00000000..f05f4b9c --- /dev/null +++ b/radis/labels/site.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Iterable + +from django.db import transaction + +from radis.reports.models import Report + +from .tasks import enqueue_labeling_for_reports + + +def handle_reports_created(reports: Iterable[Report]) -> None: + report_ids = [int(getattr(report, "id")) for report in reports] + if not report_ids: + return + + def on_commit() -> None: + enqueue_labeling_for_reports(report_ids) + + transaction.on_commit(on_commit) + + +def handle_reports_updated(reports: Iterable[Report]) -> None: + report_ids = [int(getattr(report, "id")) for report in reports] + if not report_ids: + return + + def on_commit() -> None: + enqueue_labeling_for_reports(report_ids, overwrite_existing=True) + + transaction.on_commit(on_commit) diff --git a/radis/labels/static/labels/labels.css b/radis/labels/static/labels/labels.css new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/tables.py 
b/radis/labels/tables.py new file mode 100644 index 00000000..e7d50f29 --- /dev/null +++ b/radis/labels/tables.py @@ -0,0 +1,19 @@ +import django_tables2 as tables +from django_tables2.utils import A + +from .models import LabelGroup + + +class LabelGroupTable(tables.Table): + name = tables.LinkColumn( + viewname="label_group_detail", + args=[A("pk")], + attrs={"td": {"class": "w-100"}}, + ) + + class Meta: + model = LabelGroup + fields = ("name", "is_active", "order") + order_by = ("order", "name") + empty_text = "No label groups found" + attrs = {"class": "table table-bordered table-hover"} diff --git a/radis/labels/tasks.py b/radis/labels/tasks.py new file mode 100644 index 00000000..e897bd35 --- /dev/null +++ b/radis/labels/tasks.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +from itertools import batched + +from django.conf import settings +from django.db.models import F +from django.utils import timezone +from procrastinate.contrib.django import app + +from radis.reports.models import Report + +from .models import LabelBackfillJob, LabelGroup +from .processors import LabelGroupProcessor + +logger = logging.getLogger(__name__) + + +@app.task(queue="llm") +def process_label_group( + label_group_id: int, + report_ids: list[int], + overwrite_existing: bool = False, + backfill_job_id: int | None = None, +) -> None: + # If this is part of a backfill, check cancellation status before doing any work + if backfill_job_id is not None: + try: + backfill_job = LabelBackfillJob.objects.get(id=backfill_job_id) + except LabelBackfillJob.DoesNotExist: + logger.warning("Backfill job %s not found, skipping batch.", backfill_job_id) + return + + if backfill_job.status in ( + LabelBackfillJob.Status.CANCELING, + LabelBackfillJob.Status.CANCELED, + ): + logger.info( + "Backfill job %s is %s, skipping batch.", + backfill_job, + backfill_job.get_status_display(), + ) + _increment_and_maybe_finalize(backfill_job_id, len(report_ids)) + return + + group = 
LabelGroup.objects.get(id=label_group_id) + processor = LabelGroupProcessor(group) + processor.process_reports(report_ids, overwrite_existing=overwrite_existing) + + # After processing, update backfill progress + if backfill_job_id is not None: + _increment_and_maybe_finalize(backfill_job_id, len(report_ids)) + + +def _increment_and_maybe_finalize(backfill_job_id: int, count: int) -> None: + """Atomically increment processed_reports and check for completion.""" + LabelBackfillJob.objects.filter(id=backfill_job_id).update( + processed_reports=F("processed_reports") + count + ) + + try: + backfill_job = LabelBackfillJob.objects.get(id=backfill_job_id) + except LabelBackfillJob.DoesNotExist: + return + + if backfill_job.processed_reports >= backfill_job.total_reports: + if backfill_job.status == LabelBackfillJob.Status.CANCELING: + backfill_job.status = LabelBackfillJob.Status.CANCELED + elif backfill_job.status == LabelBackfillJob.Status.IN_PROGRESS: + backfill_job.status = LabelBackfillJob.Status.SUCCESS + backfill_job.ended_at = timezone.now() + backfill_job.save() + + +def enqueue_labeling_for_reports( + report_ids: list[int], + groups: list[LabelGroup] | None = None, + overwrite_existing: bool = False, +) -> None: + if not report_ids: + return + + if groups is None: + active_groups = list(LabelGroup.objects.filter(is_active=True)) + else: + active_groups = [group for group in groups if group.is_active] + if not active_groups: + logger.info("No active label groups found, skipping labeling for reports %s", report_ids) + return + + batch_size = settings.LABELING_TASK_BATCH_SIZE + for group in active_groups: + for report_batch in batched(report_ids, batch_size): + process_label_group.defer( + label_group_id=group.id, + report_ids=list(report_batch), + overwrite_existing=overwrite_existing, + ) + + +@app.task +def enqueue_label_group_backfill(label_group_id: int, backfill_job_id: int) -> None: + group = LabelGroup.objects.get(id=label_group_id) + if not 
group.is_active: + logger.info("Label group %s is inactive. Skipping backfill.", group) + try: + backfill_job = LabelBackfillJob.objects.get(id=backfill_job_id) + backfill_job.status = LabelBackfillJob.Status.CANCELED + backfill_job.message = "Label group is inactive." + backfill_job.ended_at = timezone.now() + backfill_job.save() + except LabelBackfillJob.DoesNotExist: + pass + return + + try: + backfill_job = LabelBackfillJob.objects.get(id=backfill_job_id) + except LabelBackfillJob.DoesNotExist: + logger.warning("Backfill job %s not found, aborting.", backfill_job_id) + return + + # Count total reports + total_reports = Report.objects.count() + backfill_job.status = LabelBackfillJob.Status.IN_PROGRESS + backfill_job.started_at = timezone.now() + backfill_job.total_reports = total_reports + backfill_job.save() + + if total_reports == 0: + backfill_job.status = LabelBackfillJob.Status.SUCCESS + backfill_job.message = "No reports to process." + backfill_job.ended_at = timezone.now() + backfill_job.save() + return + + batch_size = settings.LABELING_TASK_BATCH_SIZE + current_batch: list[int] = [] + report_ids = ( + Report.objects.order_by("id").values_list("id", flat=True).iterator(chunk_size=batch_size) + ) + + for report_id in report_ids: + current_batch.append(report_id) + if len(current_batch) >= batch_size: + process_label_group.defer( + label_group_id=group.id, + report_ids=current_batch, + overwrite_existing=False, + backfill_job_id=backfill_job.id, + ) + current_batch = [] + + if current_batch: + process_label_group.defer( + label_group_id=group.id, + report_ids=current_batch, + overwrite_existing=False, + backfill_job_id=backfill_job.id, + ) diff --git a/radis/labels/templates/labels/label_group_confirm_delete.html b/radis/labels/templates/labels/label_group_confirm_delete.html new file mode 100644 index 00000000..046d2351 --- /dev/null +++ b/radis/labels/templates/labels/label_group_confirm_delete.html @@ -0,0 +1,18 @@ +{% extends 
"labels/labels_layout.html" %} +{% block title %} + Delete Label Group +{% endblock title %} +{% block heading %} + +{% endblock heading %} +{% block content %} +
You are about to delete the label group "{{ object.name }}".
+
+ {% csrf_token %} +
+ Cancel + +
+
+{% endblock content %} diff --git a/radis/labels/templates/labels/label_group_detail.html b/radis/labels/templates/labels/label_group_detail.html new file mode 100644 index 00000000..a59e9ad0 --- /dev/null +++ b/radis/labels/templates/labels/label_group_detail.html @@ -0,0 +1,112 @@ +{% extends "labels/labels_layout.html" %} +{% load bootstrap_icon from common_extras %} +{% load labels_extras %} +{% block title %} + {{ object.name }} +{% endblock title %} +{% block heading %} + +{% endblock heading %} +{% block content %} +
+
+
+
+
{{ object.name }}
+ {% if object.description %}

{{ object.description }}

{% endif %} +
+ +
+
+
+ {% if backfill_job %} +
+
+
Backfill Status
+
+
Status
+
+ {{ backfill_job.get_status_display }} +
+ {% if backfill_job.total_reports > 0 %} +
Progress
+
+ {{ backfill_job.processed_reports }} of {{ backfill_job.total_reports }} +
+
+
+
+ {% endif %} + {% if backfill_job.message %} +
Message
+
+ {{ backfill_job.message }} +
+ {% endif %} +
+ {% if backfill_job.is_cancelable %} +
+ {% csrf_token %} + +
+ {% endif %} +
+
+ {% endif %} + + {% if object.questions.exists %} +
+ {% for question in object.questions.all %} +
+
+
+
{{ question.label }}
+ {% if question.question %}
{{ question.question }}
{% endif %} +
+
+ Edit + Delete +
+
+ {% if question.choices.exists %} +
+
Choices
+
+ {% for choice in question.choices.all %} + + {{ choice.value }} · {{ choice.label }} + {% if choice.is_unknown %}(Unknown){% endif %} + + {% endfor %} +
+
+ {% endif %} +
+ {% endfor %} +
+ {% else %} +
No questions yet.
+ {% endif %} +{% endblock content %} diff --git a/radis/labels/templates/labels/label_group_form.html b/radis/labels/templates/labels/label_group_form.html new file mode 100644 index 00000000..27458a50 --- /dev/null +++ b/radis/labels/templates/labels/label_group_form.html @@ -0,0 +1,27 @@ +{% extends "labels/labels_layout.html" %} +{% load crispy from crispy_forms_tags %} +{% block title %} + {% if form.instance.pk %} + Update Label Group + {% else %} + Create Label Group + {% endif %} +{% endblock title %} +{% block heading %} + +{% endblock heading %} +{% block content %} +
+ {% crispy form %} +
+ Cancel + +
+
+{% endblock content %} diff --git a/radis/labels/templates/labels/label_group_list.html b/radis/labels/templates/labels/label_group_list.html new file mode 100644 index 00000000..b61595d5 --- /dev/null +++ b/radis/labels/templates/labels/label_group_list.html @@ -0,0 +1,19 @@ +{% extends "labels/labels_layout.html" %} +{% load render_table from django_tables2 %} +{% load bootstrap_icon from common_extras %} +{% block title %} + Labels +{% endblock title %} +{% block heading %} + + + + {% bootstrap_icon "plus-lg" %} + Add Group + + + +{% endblock heading %} +{% block content %} +
{% render_table table %}
+{% endblock content %} diff --git a/radis/labels/templates/labels/label_question_confirm_delete.html b/radis/labels/templates/labels/label_question_confirm_delete.html new file mode 100644 index 00000000..72ba3735 --- /dev/null +++ b/radis/labels/templates/labels/label_question_confirm_delete.html @@ -0,0 +1,18 @@ +{% extends "labels/labels_layout.html" %} +{% block title %} + Delete Question +{% endblock title %} +{% block heading %} + +{% endblock heading %} +{% block content %} +
You are about to delete the question "{{ object.label }}".
+
+ {% csrf_token %} +
+ Cancel + +
+
+{% endblock content %} diff --git a/radis/labels/templates/labels/label_question_form.html b/radis/labels/templates/labels/label_question_form.html new file mode 100644 index 00000000..e3ceab54 --- /dev/null +++ b/radis/labels/templates/labels/label_question_form.html @@ -0,0 +1,28 @@ +{% extends "labels/labels_layout.html" %} +{% load crispy from crispy_forms_tags %} +{% block title %} + {% if form.instance.pk %} + Update Question + {% else %} + Create Question + {% endif %} +{% endblock title %} +{% block heading %} + +{% endblock heading %} +{% block content %} +
+ {% crispy form %} +
+ Cancel + +
+
+{% endblock content %} diff --git a/radis/labels/templates/labels/labels_layout.html b/radis/labels/templates/labels/labels_layout.html new file mode 100644 index 00000000..7c354715 --- /dev/null +++ b/radis/labels/templates/labels/labels_layout.html @@ -0,0 +1,11 @@ +{% extends "core/core_layout.html" %} +{% load static from static %} +{% block css %} + {{ block.super }} + +{% endblock css %} +{% block script %} + {{ block.super }} +{% endblock script %} diff --git a/radis/labels/templatetags/__init__.py b/radis/labels/templatetags/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/templatetags/labels_extras.py b/radis/labels/templatetags/labels_extras.py new file mode 100644 index 00000000..c0a649f5 --- /dev/null +++ b/radis/labels/templatetags/labels_extras.py @@ -0,0 +1,18 @@ +from django.template import Library + +from ..models import LabelBackfillJob + +register = Library() + + +@register.filter +def backfill_status_css(status: str) -> str: + css_classes: dict[str, str] = { + LabelBackfillJob.Status.PENDING: "text-secondary", + LabelBackfillJob.Status.IN_PROGRESS: "text-info", + LabelBackfillJob.Status.CANCELING: "text-muted", + LabelBackfillJob.Status.CANCELED: "text-muted", + LabelBackfillJob.Status.SUCCESS: "text-success", + LabelBackfillJob.Status.FAILURE: "text-danger", + } + return css_classes.get(status, "") diff --git a/radis/labels/tests/__init__.py b/radis/labels/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/labels/tests/test_backfill.py b/radis/labels/tests/test_backfill.py new file mode 100644 index 00000000..f20510e4 --- /dev/null +++ b/radis/labels/tests/test_backfill.py @@ -0,0 +1,435 @@ +from unittest.mock import patch + +import pytest +from adit_radis_shared.accounts.factories import UserFactory +from django.test import Client + +from radis.labels.models import LabelBackfillJob, LabelGroup, LabelQuestion +from radis.labels.tasks import _increment_and_maybe_finalize + +# -- 
Model tests -- + + +@pytest.mark.django_db +class TestLabelBackfillJobModel: + def _create_job(self, **kwargs) -> LabelBackfillJob: + group = LabelGroup.objects.create(name="Findings") + defaults = {"label_group": group} + defaults.update(kwargs) + return LabelBackfillJob.objects.create(**defaults) + + def test_default_status_is_pending(self): + job = self._create_job() + assert job.status == LabelBackfillJob.Status.PENDING + + def test_str(self): + job = self._create_job() + assert str(job) == f"LabelBackfillJob [{job.pk}]" + + def test_is_cancelable_pending(self): + job = self._create_job(status=LabelBackfillJob.Status.PENDING) + assert job.is_cancelable is True + + def test_is_cancelable_in_progress(self): + job = self._create_job(status=LabelBackfillJob.Status.IN_PROGRESS) + assert job.is_cancelable is True + + def test_is_not_cancelable_success(self): + job = self._create_job(status=LabelBackfillJob.Status.SUCCESS) + assert job.is_cancelable is False + + def test_is_not_cancelable_canceled(self): + job = self._create_job(status=LabelBackfillJob.Status.CANCELED) + assert job.is_cancelable is False + + def test_is_not_cancelable_canceling(self): + job = self._create_job(status=LabelBackfillJob.Status.CANCELING) + assert job.is_cancelable is False + + def test_is_not_cancelable_failure(self): + job = self._create_job(status=LabelBackfillJob.Status.FAILURE) + assert job.is_cancelable is False + + def test_is_active_pending(self): + job = self._create_job(status=LabelBackfillJob.Status.PENDING) + assert job.is_active is True + + def test_is_active_in_progress(self): + job = self._create_job(status=LabelBackfillJob.Status.IN_PROGRESS) + assert job.is_active is True + + def test_is_not_active_terminal_states(self): + for status in [ + LabelBackfillJob.Status.SUCCESS, + LabelBackfillJob.Status.FAILURE, + LabelBackfillJob.Status.CANCELED, + LabelBackfillJob.Status.CANCELING, + ]: + job = self._create_job(status=status) + assert job.is_active is False, f"Expected 
is_active=False for status={status}" + + def test_progress_percent_zero_total(self): + job = self._create_job(total_reports=0, processed_reports=0) + assert job.progress_percent == 0 + + def test_progress_percent_partial(self): + job = self._create_job(total_reports=200, processed_reports=50) + assert job.progress_percent == 25 + + def test_progress_percent_complete(self): + job = self._create_job(total_reports=100, processed_reports=100) + assert job.progress_percent == 100 + + def test_progress_percent_capped_at_100(self): + job = self._create_job(total_reports=100, processed_reports=150) + assert job.progress_percent == 100 + + def test_ordering_by_created_at_descending(self): + group = LabelGroup.objects.create(name="TestGroup") + job1 = LabelBackfillJob.objects.create(label_group=group) + job2 = LabelBackfillJob.objects.create(label_group=group) + jobs = list(LabelBackfillJob.objects.all()) + assert jobs[0] == job2 + assert jobs[1] == job1 + + def test_cascade_delete_with_group(self): + group = LabelGroup.objects.create(name="DeleteMe") + LabelBackfillJob.objects.create(label_group=group) + assert LabelBackfillJob.objects.count() == 1 + group.delete() + assert LabelBackfillJob.objects.count() == 0 + + +# -- Signal dedup tests -- + + +@pytest.mark.django_db +class TestSignalDedup: + @patch("radis.labels.signals.enqueue_label_group_backfill") + def test_creating_question_creates_backfill_job(self, mock_task): + mock_task.defer = lambda **kw: None + group = LabelGroup.objects.create(name="Findings") + LabelQuestion.objects.create(group=group, label="PE present?") + + assert LabelBackfillJob.objects.filter(label_group=group).count() == 1 + + @patch("radis.labels.signals.enqueue_label_group_backfill") + def test_second_question_skips_backfill_when_active(self, mock_task): + mock_task.defer = lambda **kw: None + group = LabelGroup.objects.create(name="Findings") + LabelQuestion.objects.create(group=group, label="PE present?") + 
LabelQuestion.objects.create(group=group, label="Pneumonia present?") + + # Only one backfill job should exist + assert LabelBackfillJob.objects.filter(label_group=group).count() == 1 + + @patch("radis.labels.signals.enqueue_label_group_backfill") + def test_new_question_after_completed_backfill_creates_new_job(self, mock_task): + mock_task.defer = lambda **kw: None + group = LabelGroup.objects.create(name="Findings") + LabelQuestion.objects.create(group=group, label="PE present?") + + # Simulate first backfill completing + job = LabelBackfillJob.objects.get(label_group=group) + job.status = LabelBackfillJob.Status.SUCCESS + job.save() + + LabelQuestion.objects.create(group=group, label="Pneumonia present?") + assert LabelBackfillJob.objects.filter(label_group=group).count() == 2 + + @patch("radis.labels.signals.enqueue_label_group_backfill") + def test_inactive_question_does_not_trigger_backfill(self, mock_task): + mock_task.defer = lambda **kw: None + group = LabelGroup.objects.create(name="Findings") + LabelQuestion.objects.create(group=group, label="Draft Q", is_active=False) + + assert LabelBackfillJob.objects.filter(label_group=group).count() == 0 + + @patch("radis.labels.signals.enqueue_label_group_backfill") + def test_updating_question_does_not_trigger_backfill(self, mock_task): + mock_task.defer = lambda **kw: None + group = LabelGroup.objects.create(name="Findings") + question = LabelQuestion.objects.create(group=group, label="PE present?") + + # Clear the job created by the initial create + LabelBackfillJob.objects.all().delete() + + question.label = "Updated label" + question.save() + assert LabelBackfillJob.objects.filter(label_group=group).count() == 0 + + @patch("radis.labels.signals.enqueue_label_group_backfill") + @pytest.mark.django_db(transaction=True) + def test_dedup_across_different_groups(self, mock_task): + mock_task.defer = lambda **kw: None + group1 = LabelGroup.objects.create(name="Group A") + group2 = 
LabelGroup.objects.create(name="Group B") + + LabelQuestion.objects.create(group=group1, label="Q1") + LabelQuestion.objects.create(group=group2, label="Q2") + + # Each group should get its own backfill + assert LabelBackfillJob.objects.filter(label_group=group1).count() == 1 + assert LabelBackfillJob.objects.filter(label_group=group2).count() == 1 + + +# -- increment_and_maybe_finalize tests -- + + +@pytest.mark.django_db +class TestIncrementAndMaybeFinalize: + def _create_active_job(self, total=100, processed=0) -> LabelBackfillJob: + group = LabelGroup.objects.create(name="TestGroup") + return LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.IN_PROGRESS, + total_reports=total, + processed_reports=processed, + ) + + def test_increments_processed_reports(self): + job = self._create_active_job(total=100, processed=0) + _increment_and_maybe_finalize(job.id, 25) + job.refresh_from_db() + assert job.processed_reports == 25 + + def test_does_not_finalize_when_incomplete(self): + job = self._create_active_job(total=100, processed=0) + _increment_and_maybe_finalize(job.id, 25) + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.IN_PROGRESS + assert job.ended_at is None + + def test_finalizes_to_success_when_complete(self): + job = self._create_active_job(total=100, processed=75) + _increment_and_maybe_finalize(job.id, 25) + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.SUCCESS + assert job.ended_at is not None + + def test_finalizes_to_canceled_when_canceling(self): + group = LabelGroup.objects.create(name="TestGroup") + job = LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.CANCELING, + total_reports=100, + processed_reports=75, + ) + _increment_and_maybe_finalize(job.id, 25) + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.CANCELED + assert job.ended_at is not None + + def test_handles_missing_job_gracefully(self): + # Should not 
raise + _increment_and_maybe_finalize(99999, 10) + + def test_over_counting_still_finalizes(self): + job = self._create_active_job(total=100, processed=90) + _increment_and_maybe_finalize(job.id, 20) + job.refresh_from_db() + assert job.processed_reports == 110 + assert job.status == LabelBackfillJob.Status.SUCCESS + + +# -- Cancel view tests -- + + +@pytest.mark.django_db +class TestLabelBackfillCancelView: + def _create_job(self, status=LabelBackfillJob.Status.IN_PROGRESS) -> LabelBackfillJob: + group = LabelGroup.objects.create(name="TestGroup") + return LabelBackfillJob.objects.create( + label_group=group, + status=status, + total_reports=100, + ) + + def test_cancel_requires_login(self, client: Client): + job = self._create_job() + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 302 + assert "/accounts/login/" in response["Location"] + + def test_cancel_sets_canceling_status(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + job = self._create_job(status=LabelBackfillJob.Status.IN_PROGRESS) + + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 302 + + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.CANCELING + + def test_cancel_pending_job(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + job = self._create_job(status=LabelBackfillJob.Status.PENDING) + + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 302 + + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.CANCELING + + def test_cancel_rejected_for_non_staff(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=False) + client.force_login(user) + job = self._create_job(status=LabelBackfillJob.Status.IN_PROGRESS) + + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert 
response.status_code == 403 + + job.refresh_from_db() + assert job.status == LabelBackfillJob.Status.IN_PROGRESS + + def test_cancel_already_completed_returns_400(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + job = self._create_job(status=LabelBackfillJob.Status.SUCCESS) + + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 400 + + def test_cancel_nonexistent_job_returns_404(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + + response = client.post("/labels/backfill/99999/cancel/") + assert response.status_code == 404 + + def test_cancel_redirects_to_group_detail(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + job = self._create_job() + + response = client.post(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 302 + assert f"/labels/{job.label_group_id}/" in response["Location"] + + def test_get_not_allowed(self, client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + job = self._create_job() + + response = client.get(f"/labels/backfill/{job.pk}/cancel/") + assert response.status_code == 405 + + +# -- Detail view context tests -- + + +@pytest.mark.django_db +class TestLabelGroupDetailViewBackfill: + def test_detail_view_includes_backfill_job(self, client: Client): + user = UserFactory.create(is_active=True) + client.force_login(user) + + group = LabelGroup.objects.create(name="Findings") + job = LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.IN_PROGRESS, + total_reports=500, + processed_reports=200, + ) + + response = client.get(f"/labels/{group.pk}/") + assert response.status_code == 200 + assert response.context["backfill_job"] == job + + def test_detail_view_returns_most_recent_backfill(self, client: Client): + user = 
UserFactory.create(is_active=True) + client.force_login(user) + + group = LabelGroup.objects.create(name="Findings") + LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.SUCCESS, + ) + job2 = LabelBackfillJob.objects.create( + label_group=group, + status=LabelBackfillJob.Status.IN_PROGRESS, + ) + + response = client.get(f"/labels/{group.pk}/") + assert response.context["backfill_job"] == job2 + + def test_detail_view_no_backfill_job(self, client: Client): + user = UserFactory.create(is_active=True) + client.force_login(user) + + group = LabelGroup.objects.create(name="Findings") + response = client.get(f"/labels/{group.pk}/") + assert response.context["backfill_job"] is None + + +# -- Management command tests -- + + +@pytest.mark.django_db +class TestLabelsBackfillCommand: + @patch("radis.labels.tasks.process_label_group") + def test_command_creates_backfill_job(self, mock_task): + mock_task.defer = lambda **kw: None + from django.core.management import call_command + + from radis.reports.models import Language, Report + + group = LabelGroup.objects.create(name="Findings") + lang = Language.objects.create(code="en") + Report.objects.create( + document_id="doc-1", + body="Test report body", + patient_birth_date="2000-01-01", + patient_sex="M", + study_datetime="2024-01-15T10:00:00Z", + language=lang, + ) + + call_command("labels_backfill", group=str(group.id)) + + job = LabelBackfillJob.objects.get(label_group=group) + assert job.status == LabelBackfillJob.Status.IN_PROGRESS + assert job.total_reports == 1 + assert job.started_at is not None + + +# -- Templatetag tests -- + + +class TestBackfillStatusCssFilter: + def test_pending(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css(LabelBackfillJob.Status.PENDING) == "text-secondary" + + def test_in_progress(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert 
backfill_status_css(LabelBackfillJob.Status.IN_PROGRESS) == "text-info" + + def test_success(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css(LabelBackfillJob.Status.SUCCESS) == "text-success" + + def test_failure(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css(LabelBackfillJob.Status.FAILURE) == "text-danger" + + def test_canceling(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css(LabelBackfillJob.Status.CANCELING) == "text-muted" + + def test_canceled(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css(LabelBackfillJob.Status.CANCELED) == "text-muted" + + def test_unknown_status(self): + from radis.labels.templatetags.labels_extras import backfill_status_css + + assert backfill_status_css("XX") == "" diff --git a/radis/labels/tests/test_management.py b/radis/labels/tests/test_management.py new file mode 100644 index 00000000..afea6cdc --- /dev/null +++ b/radis/labels/tests/test_management.py @@ -0,0 +1,37 @@ +import json +from pathlib import Path + +import pytest +from django.core.management import call_command + +from radis.labels.models import LabelChoice, LabelGroup, LabelQuestion + + +@pytest.mark.django_db +def test_labels_seed_command_creates_objects(tmp_path: Path): + payload = { + "groups": [ + { + "name": "Embolism", + "questions": [ + { + "label": "Pulmonary embolism", + "question": "Pulmonary embolism present?", + } + ], + } + ] + } + seed_file = tmp_path / "labels_seed.json" + seed_file.write_text(json.dumps(payload)) + + call_command("labels_seed", file=str(seed_file)) + + group = LabelGroup.objects.get(name="Embolism") + question = LabelQuestion.objects.get(group=group, label="Pulmonary embolism") + choices = LabelChoice.objects.filter(question=question) + + assert group.name == "Embolism" + assert 
question.question == "Pulmonary embolism present?" + assert choices.count() == 3 + assert choices.filter(value="cannot_decide", is_unknown=True).exists() diff --git a/radis/labels/tests/test_utils.py b/radis/labels/tests/test_utils.py new file mode 100644 index 00000000..43d891ff --- /dev/null +++ b/radis/labels/tests/test_utils.py @@ -0,0 +1,59 @@ +import pytest +from pydantic import ValidationError + +from radis.labels.models import LabelGroup, LabelQuestion +from radis.labels.utils.processor_utils import ( + generate_labeling_schema, + generate_questions_for_prompt, +) + + +@pytest.mark.django_db +def test_generate_questions_for_prompt_includes_choices(): + group = LabelGroup.objects.create(name="Finding") + question = LabelQuestion.objects.create( + group=group, + label="Pulmonary embolism", + question="Pulmonary embolism present?", + ) + choices = list(question.choices.all()) + + prompt = generate_questions_for_prompt([question]) + + assert "question_0" in prompt + assert "Pulmonary embolism present?" in prompt + assert any(choice.value in prompt for choice in choices) + assert "yes (Yes)" in prompt + assert "no (No)" in prompt + assert "cannot_decide (Cannot decide)" in prompt + + +@pytest.mark.django_db +def test_label_question_auto_generates_prompt_from_label(): + group = LabelGroup.objects.create(name="Finding") + question = LabelQuestion.objects.create( + group=group, + label="Pulmonary embolism", + question="", + ) + + assert question.question == "Pulmonary embolism" + + +@pytest.mark.django_db +def test_generate_labeling_schema_enforces_choice_enum(): + group = LabelGroup.objects.create(name="Finding") + question = LabelQuestion.objects.create( + group=group, + label="Broken bones", + question="Is there a case of broken bones on this report?", + ) + + Schema = generate_labeling_schema([question]) + + # Valid values should validate. 
+ Schema.model_validate({"question_0": {"choice": "cannot_decide"}}) + + # Anything outside the configured choice values should fail validation. + with pytest.raises(ValidationError): + Schema.model_validate({"question_0": {"choice": "maybe"}}) diff --git a/radis/labels/tests/test_views.py b/radis/labels/tests/test_views.py new file mode 100644 index 00000000..707a4a4c --- /dev/null +++ b/radis/labels/tests/test_views.py @@ -0,0 +1,130 @@ +import pytest +from adit_radis_shared.accounts.factories import UserFactory +from django.test import Client + +from radis.labels.models import LabelChoice, LabelGroup, LabelQuestion + + +def create_group(): + return LabelGroup.objects.create(name="Findings") + + +# -- List view (login required, no staff check) -- + + +@pytest.mark.django_db +def test_label_group_list_requires_login(client: Client): + response = client.get("/labels/") + assert response.status_code == 302 + assert "/accounts/login/" in response["Location"] + + +@pytest.mark.django_db +def test_label_group_list_view(client: Client): + user = UserFactory.create(is_active=True) + create_group() + client.force_login(user) + response = client.get("/labels/") + assert response.status_code == 200 + + +# -- Group create (staff required) -- + + +@pytest.mark.django_db +def test_label_group_create_view(client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + client.force_login(user) + response = client.post( + "/labels/create/", + { + "name": "Protocols", + "description": "Standard protocol labels", + "is_active": True, + "order": 1, + }, + ) + assert response.status_code == 302 + assert LabelGroup.objects.filter(name="Protocols").exists() + + +@pytest.mark.django_db +def test_label_group_create_rejected_for_non_staff(client: Client): + user = UserFactory.create(is_active=True, is_staff=False) + client.force_login(user) + response = client.post( + "/labels/create/", + {"name": "Should Fail", "is_active": True, "order": 1}, + ) + assert 
response.status_code == 403 + assert not LabelGroup.objects.filter(name="Should Fail").exists() + + +# -- Question create (staff required) -- + + +@pytest.mark.django_db +def test_label_question_create_view_with_choices(client: Client): + user = UserFactory.create(is_active=True, is_staff=True) + group = create_group() + client.force_login(user) + + response = client.post( + f"/labels/{group.pk}/questions/create/", + { + "label": "Pulmonary embolism", + "question": "Pulmonary embolism present?", + "is_active": True, + "order": 1, + }, + ) + + assert response.status_code == 302 + question = LabelQuestion.objects.get(group=group, label="Pulmonary embolism") + assert LabelChoice.objects.filter(question=question).count() == 3 + + +@pytest.mark.django_db +def test_label_question_create_rejected_for_non_staff(client: Client): + user = UserFactory.create(is_active=True, is_staff=False) + group = create_group() + client.force_login(user) + + response = client.post( + f"/labels/{group.pk}/questions/create/", + {"label": "Should Fail", "question": "Nope", "is_active": True, "order": 1}, + ) + assert response.status_code == 403 + assert not LabelQuestion.objects.filter(label="Should Fail").exists() + + +# -- Group update (staff required) -- + + +@pytest.mark.django_db +def test_label_group_update_rejected_for_non_staff(client: Client): + user = UserFactory.create(is_active=True, is_staff=False) + group = create_group() + client.force_login(user) + + response = client.post( + f"/labels/{group.pk}/update/", + {"name": "Hacked", "is_active": True, "order": 1}, + ) + assert response.status_code == 403 + group.refresh_from_db() + assert group.name == "Findings" + + +# -- Group delete (staff required) -- + + +@pytest.mark.django_db +def test_label_group_delete_rejected_for_non_staff(client: Client): + user = UserFactory.create(is_active=True, is_staff=False) + group = create_group() + client.force_login(user) + + response = client.post(f"/labels/{group.pk}/delete/") + assert 
response.status_code == 403 + assert LabelGroup.objects.filter(pk=group.pk).exists() diff --git a/radis/labels/urls.py b/radis/labels/urls.py new file mode 100644 index 00000000..aec32e05 --- /dev/null +++ b/radis/labels/urls.py @@ -0,0 +1,39 @@ +from django.urls import path + +from . import views + +urlpatterns = [ + path("", views.LabelGroupListView.as_view(), name="label_group_list"), + path("create/", views.LabelGroupCreateView.as_view(), name="label_group_create"), + path("<int:pk>/", views.LabelGroupDetailView.as_view(), name="label_group_detail"), + path( + "<int:pk>/update/", + views.LabelGroupUpdateView.as_view(), + name="label_group_update", + ), + path( + "<int:pk>/delete/", + views.LabelGroupDeleteView.as_view(), + name="label_group_delete", + ), + path( + "<int:group_pk>/questions/create/", + views.LabelQuestionCreateView.as_view(), + name="label_question_create", + ), + path( + "<int:group_pk>/questions/<int:pk>/update/", + views.LabelQuestionUpdateView.as_view(), + name="label_question_update", + ), + path( + "<int:group_pk>/questions/<int:pk>/delete/", + views.LabelQuestionDeleteView.as_view(), + name="label_question_delete", + ), + path( + "backfill/<int:pk>/cancel/", + views.LabelBackfillCancelView.as_view(), + name="label_backfill_cancel", + ), +] diff --git a/radis/labels/utils/processor_utils.py b/radis/labels/utils/processor_utils.py new file mode 100644 index 00000000..798b6512 --- /dev/null +++ b/radis/labels/utils/processor_utils.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, create_model + +from ..models import LabelQuestion + + +def _generate_label_answer_schema(index: int, question: LabelQuestion) -> type[BaseModel]: + """Create a strict answer schema for one question. + + We enforce that `choice` is exactly one of the configured choice values using `Literal[...]`. + This mirrors how the `selectionOutputType` branch enforces allowed options for extraction + fields.
+ """ + + choice_values = tuple( + choice.value + for choice in question.choices.all() + if isinstance(choice.value, str) and choice.value + ) + if not choice_values: + raise ValueError( + f"LabelQuestion {question.pk or ''} has no valid choices configured." + ) + + ChoiceType = Literal[*choice_values] + return create_model( + f"LabelAnswer_{index}", + choice=(ChoiceType, ...), + confidence=(float | None, None), + rationale=(str | None, None), + ) + + +def generate_labeling_schema(questions: list[LabelQuestion]) -> type[BaseModel]: + field_definitions: dict[str, Any] = {} + for index, question in enumerate(questions): + AnswerSchema = _generate_label_answer_schema(index, question) + field_definitions[f"question_{index}"] = (AnswerSchema, ...) + + return create_model("LabelingModel", **field_definitions) + + +def generate_questions_for_prompt(questions: list[LabelQuestion]) -> str: + prompt = "" + for index, question in enumerate(questions): + choices = ", ".join( + [f"{choice.value} ({choice.label})" for choice in question.choices.all()] + ) + prompt += f"question_{index}: {question.question}\n" + prompt += f"choices (return exactly one choice value): {choices}\n" + + return prompt diff --git a/radis/labels/views.py b/radis/labels/views.py new file mode 100644 index 00000000..3f98fcb7 --- /dev/null +++ b/radis/labels/views.py @@ -0,0 +1,173 @@ +from typing import Any + +from adit_radis_shared.common.types import AuthenticatedHttpRequest +from django.contrib import messages +from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin +from django.core.exceptions import SuspiciousOperation +from django.db.models import Prefetch, QuerySet +from django.http import HttpResponse, HttpResponseRedirect +from django.shortcuts import get_object_or_404, redirect +from django.urls import reverse, reverse_lazy +from django.views import View +from django.views.generic import CreateView, DeleteView, DetailView, UpdateView +from django_tables2 import 
SingleTableView + +from .forms import LabelGroupForm, LabelQuestionForm +from .models import LabelBackfillJob, LabelGroup, LabelQuestion +from .tables import LabelGroupTable + + +class StaffRequiredMixin(LoginRequiredMixin, UserPassesTestMixin): + request: AuthenticatedHttpRequest + + def test_func(self) -> bool: + return self.request.user.is_staff + + +class LabelGroupListView(LoginRequiredMixin, SingleTableView): + model = LabelGroup + table_class = LabelGroupTable + template_name = "labels/label_group_list.html" + paginate_by = 30 + request: AuthenticatedHttpRequest + + def get_queryset(self) -> QuerySet[LabelGroup]: + return LabelGroup.objects.all().order_by("order", "name") + + +class LabelGroupDetailView(LoginRequiredMixin, DetailView): + model = LabelGroup + template_name = "labels/label_group_detail.html" + + def get_queryset(self) -> QuerySet[LabelGroup]: + return LabelGroup.objects.prefetch_related( + Prefetch( + "questions", + queryset=LabelQuestion.objects.prefetch_related("choices").order_by( + "order", "label" + ), + ) + ) + + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: + context = super().get_context_data(**kwargs) + context["backfill_job"] = ( + LabelBackfillJob.objects.filter(label_group=self.object).order_by("-created_at").first() + ) + return context + + +class LabelGroupCreateView(StaffRequiredMixin, CreateView): + template_name = "labels/label_group_form.html" + form_class = LabelGroupForm + success_url = reverse_lazy("label_group_list") + + +class LabelGroupUpdateView(StaffRequiredMixin, UpdateView): + template_name = "labels/label_group_form.html" + form_class = LabelGroupForm + model = LabelGroup + + def get_success_url(self) -> str: + return reverse("label_group_detail", kwargs={"pk": self.object.pk}) + + +class LabelGroupDeleteView(StaffRequiredMixin, DeleteView): + model = LabelGroup + success_url = reverse_lazy("label_group_list") + template_name = "labels/label_group_confirm_delete.html" + + +class 
LabelQuestionCreateView(StaffRequiredMixin, CreateView): + template_name = "labels/label_question_form.html" + form_class = LabelQuestionForm + model = LabelQuestion + request: AuthenticatedHttpRequest + + def dispatch(self, request, *args, **kwargs): + self.group = get_object_or_404(LabelGroup, pk=kwargs["group_pk"])  # unknown group -> 404 + return super().dispatch(request, *args, **kwargs) + + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: + ctx = super().get_context_data(**kwargs) + ctx["group"] = self.group + return ctx + + def get_form_kwargs(self) -> dict[str, Any]: + kwargs = super().get_form_kwargs() + kwargs["group"] = self.group + return kwargs + + def form_valid(self, form) -> HttpResponse: + form.instance.group = self.group + self.object = form.save() + return HttpResponseRedirect(self.get_success_url()) + + def get_success_url(self) -> str: + return reverse("label_group_detail", kwargs={"pk": self.group.pk}) + + +class LabelQuestionUpdateView(StaffRequiredMixin, UpdateView): + template_name = "labels/label_question_form.html" + form_class = LabelQuestionForm + model = LabelQuestion + request: AuthenticatedHttpRequest + + def dispatch(self, request, *args, **kwargs): + self.group = get_object_or_404(LabelGroup, pk=kwargs["group_pk"])  # unknown group -> 404 + return super().dispatch(request, *args, **kwargs) + + def get_queryset(self) -> QuerySet[LabelQuestion]: + return LabelQuestion.objects.filter(group=self.group).prefetch_related("choices") + + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: + ctx = super().get_context_data(**kwargs) + ctx["group"] = self.group + return ctx + + def get_form_kwargs(self) -> dict[str, Any]: + kwargs = super().get_form_kwargs() + kwargs["group"] = self.group + return kwargs + + def form_valid(self, form) -> HttpResponse: + self.object = form.save() + return HttpResponseRedirect(self.get_success_url()) + + def get_success_url(self) -> str: + return reverse("label_group_detail", kwargs={"pk": self.group.pk}) + + +class
LabelQuestionDeleteView(StaffRequiredMixin, DeleteView): + model = LabelQuestion + template_name = "labels/label_question_confirm_delete.html" + + def dispatch(self, request, *args, **kwargs): + self.group = get_object_or_404(LabelGroup, pk=kwargs["group_pk"])  # unknown group -> 404 + return super().dispatch(request, *args, **kwargs) + + def get_queryset(self) -> QuerySet[LabelQuestion]: + return LabelQuestion.objects.filter(group=self.group) + + def get_success_url(self) -> str: + return reverse("label_group_detail", kwargs={"pk": self.group.pk}) + + +class LabelBackfillCancelView(StaffRequiredMixin, View): + def post(self, request: AuthenticatedHttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: + backfill_job = get_object_or_404(LabelBackfillJob, pk=kwargs["pk"]) + + if not backfill_job.is_cancelable: + raise SuspiciousOperation( + f"Backfill job {backfill_job.pk} with status " + f"{backfill_job.get_status_display()} is not cancelable." + ) + + backfill_job.status = LabelBackfillJob.Status.CANCELING + backfill_job.save() + + messages.success( + request, + f"Backfill for {backfill_job.label_group.name} is being cancelled.", + ) + return redirect("label_group_detail", pk=backfill_job.label_group_id) diff --git a/radis/reports/models.py b/radis/reports/models.py index a56b0f54..50e44ea5 100644 --- a/radis/reports/models.py +++ b/radis/reports/models.py @@ -75,6 +75,7 @@ class Report(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) + id: int metadata: models.QuerySet["Metadata"] class Meta: diff --git a/radis/reports/templates/reports/report_detail.html b/radis/reports/templates/reports/report_detail.html index 160ace1a..a93b30eb 100644 --- a/radis/reports/templates/reports/report_detail.html +++ b/radis/reports/templates/reports/report_detail.html @@ -68,6 +68,48 @@ {% endif %} + {% if report.labels.exists %} +
+
+
Labels
+ + + {% for label in report.labels.all %} + + + + + + {% if label.rationale %} + + + + + {% endif %} + {% if label.question.question %} + + + + + {% endif %} + {% endfor %} + +
+ {{ label.question.group.name }} - {{ label.question.label }} + {{ label.choice.label }} + {% if label.confidence is not None %} + {{ label.confidence|floatformat:2 }} + {% else %} + n/a + {% endif %} +
+ {{ label.rationale }} +
+ Prompt: {{ label.question.question }} +
+
+
+ {% endif %}
{% include "reports/_report_buttons_panel.html" with hide_view_button=True %}
diff --git a/radis/reports/views.py b/radis/reports/views.py index 00ff8520..db8f0c2a 100644 --- a/radis/reports/views.py +++ b/radis/reports/views.py @@ -1,10 +1,12 @@ from adit_radis_shared.common.mixins import PageSizeSelectMixin from adit_radis_shared.common.types import AuthenticatedHttpRequest from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin -from django.db.models import QuerySet +from django.db.models import Prefetch, QuerySet from django.views.generic.detail import DetailView from django_filters.views import FilterView +from radis.labels.models import ReportLabel + from .filters import ReportFilter from .models import Report @@ -36,7 +38,15 @@ def test_func(self) -> bool | None: def get_queryset(self) -> QuerySet[Report]: active_group = self.request.user.active_group assert active_group - return super().get_queryset().filter(groups=active_group) + labels_queryset = ReportLabel.objects.select_related( + "question__group", "choice" + ).order_by("question__group__order", "question__order", "question__label") + return ( + super() + .get_queryset() + .filter(groups=active_group) + .prefetch_related(Prefetch("labels", queryset=labels_queryset)) + ) class ReportBodyView(ReportDetailView): diff --git a/radis/settings/base.py b/radis/settings/base.py index 1f9b1c6b..4ba00a78 100644 --- a/radis/settings/base.py +++ b/radis/settings/base.py @@ -81,6 +81,7 @@ "radis.search.apps.SearchConfig", "radis.extractions.apps.ExtractionsConfig", "radis.subscriptions.apps.SubscriptionsConfig", + "radis.labels.apps.LabelsConfig", "radis.collections.apps.CollectionsConfig", "radis.notes.apps.NotesConfig", "radis.chats.apps.ChatsConfig", @@ -373,6 +374,24 @@ $questions """ +# Labels +LABELS_SYSTEM_PROMPT = """ +You are an AI medical assistant with extensive knowledge in radiology and general medicine. +You have been trained on a wide range of medical literature, including the latest research +and guidelines in radiological practices. 
+Assign a single choice to each question based only on the report text. The report and questions +can be given in any language. Don't hallucinate. +For each question return: choice (one of the provided choice values), confidence (0.0 to 1.0), +and rationale (short justification grounded in the report). +If there is not enough evidence, select the choice value that represents \"Unknown\". + +Radiology Report: +$report + +Questions: +$questions +""" + # Extraction OUTPUT_FIELDS_SYSTEM_PROMPT = """ You are an AI medical assistant with extensive knowledge in radiology and general medicine. @@ -407,6 +426,13 @@ # continuous batching capability of the LLM or a combination of both should be used. EXTRACTION_LLM_CONCURRENCY_LIMIT = 6 +# Labels +LABELS_AUTO_BACKFILL_ON_NEW_QUESTION = env.bool( + "LABELS_AUTO_BACKFILL_ON_NEW_QUESTION", default=True +) +LABELING_TASK_BATCH_SIZE = 100 +LABELING_LLM_CONCURRENCY_LIMIT = 6 + START_EXTRACTION_JOB_UNVERIFIED = False # Subscription diff --git a/radis/urls.py b/radis/urls.py index a131cee5..6352a0b0 100644 --- a/radis/urls.py +++ b/radis/urls.py @@ -29,17 +29,20 @@ path("reports/", include("radis.reports.urls")), path("api/reports/", include("radis.reports.api.urls")), path("search/", include("radis.search.urls")), + path("labels/", include("radis.labels.urls")), path("extractions/", include("radis.extractions.urls")), path("collections/", include("radis.collections.urls")), path("notes/", include("radis.notes.urls")), path("subscriptions/", include("radis.subscriptions.urls")), ] -# Debug Toolbar in Debug mode only -if settings.DEBUG: +# Developer tooling URLs (only available when the corresponding apps are installed). +# Note: some test runners force DEBUG=False even when using the development settings module, +# so we gate this on INSTALLED_APPS instead of settings.DEBUG to keep URL reversing stable. 
+if "django_browser_reload" in settings.INSTALLED_APPS: + urlpatterns = [path("__reload__/", include("django_browser_reload.urls"))] + urlpatterns + +if "debug_toolbar" in settings.INSTALLED_APPS: import debug_toolbar - urlpatterns = [ - path("__reload__/", include("django_browser_reload.urls")), - path("__debug__/", include(debug_toolbar.urls)), - ] + urlpatterns + urlpatterns = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns diff --git a/resources/labels/seed.json b/resources/labels/seed.json new file mode 100644 index 00000000..bc8eb384 --- /dev/null +++ b/resources/labels/seed.json @@ -0,0 +1,38 @@ +{ + "groups": [ + { + "name": "Chest CT Triage", + "description": "Triage labels for chest CT reports", + "is_active": true, + "order": 1, + "questions": [ + { + "label": "PE", + "question": "Is there evidence of pulmonary embolism?", + "is_active": true, + "order": 1 + }, + { + "label": "PNEUMONIA", + "question": "Is pneumonia present?", + "is_active": true, + "order": 2 + } + ] + }, + { + "name": "Abdominal CT Findings", + "description": "Key abdominal findings", + "is_active": true, + "order": 2, + "questions": [ + { + "label": "APPENDICITIS", + "question": "Is acute appendicitis present?", + "is_active": true, + "order": 1 + } + ] + } + ] +}