diff --git a/src/dispatch/signal/models.py b/src/dispatch/signal/models.py index f491d882d723..ad8b8dcb0e8a 100644 --- a/src/dispatch/signal/models.py +++ b/src/dispatch/signal/models.py @@ -379,6 +379,13 @@ class AdditionalMetadata(DispatchBase): important: Optional[bool] +class SignalStats(DispatchBase): + num_signal_instances_alerted: Optional[int] + num_signal_instances_snoozed: Optional[int] + num_snoozes_active: Optional[int] + num_snoozes_expired: Optional[int] + + class SignalInstanceBase(DispatchBase): project: Optional[ProjectRead] case: Optional[CaseReadMinimal] diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index 00e59ba1d5f9..fe2419b0f0c7 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -7,10 +7,11 @@ from fastapi import HTTPException, status from pydantic.error_wrappers import ErrorWrapper, ValidationError -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, or_, func, and_, select, cast from sqlalchemy.orm import Session from sqlalchemy.orm.query import Query from sqlalchemy.sql.expression import true +from sqlalchemy.dialects.postgresql import JSONB from dispatch.auth.models import DispatchUser from dispatch.case.models import Case @@ -34,6 +35,7 @@ SignalNotIdentifiedException, ) from .models import ( + assoc_signal_instance_entities, Signal, SignalCreate, SignalEngagement, @@ -48,6 +50,7 @@ SignalFilterUpdate, SignalInstance, SignalInstanceCreate, + SignalStats, SignalUpdate, assoc_signal_entity_types, ) @@ -965,3 +968,75 @@ def get_cases_for_signal_by_resolution_reason( .order_by(desc(Case.created_at)) .limit(limit) ) + + +def get_signal_stats( + *, db_session: Session, entity_value: str, entity_type_id: int, num_days: int | None +) -> Optional[SignalStats]: + """Gets a signal statistics for a given named entity and type.""" + entity_subquery = ( + db_session.query( + func.jsonb_build_array( + func.jsonb_build_object( + "or", + func.jsonb_build_array( + func.jsonb_build_object( + "model", "Entity", "field", "id", "op", "==", "value", Entity.id + ) + ), + ) + ) + ) + .filter(and_(Entity.value == entity_value, Entity.entity_type_id == entity_type_id)) + .as_scalar() + ) + + active_count = func.count().filter(SignalFilter.expiration > func.current_date()) + expired_count = func.count().filter(SignalFilter.expiration <= func.current_date()) + + query = db_session.query( + active_count.label("active_count"), expired_count.label("expired_count") + ).filter(cast(SignalFilter.expression, JSONB).op("@>")(entity_subquery)) + + snooze_result = db_session.execute(query).fetchone() + + # Calculate the date threshold based on num_days + date_threshold = datetime.utcnow() - timedelta(days=num_days) if num_days is not None else None + + count_with_snooze = func.count().filter(SignalInstance.filter_action == "snooze") + count_without_snooze = func.count().filter( + (SignalInstance.filter_action != "snooze") | (SignalInstance.filter_action.is_(None)) + ) + + query = ( + select( + [ + count_with_snooze.label("count_with_snooze"), + count_without_snooze.label("count_without_snooze"), + ] + ) + .select_from( + assoc_signal_instance_entities.join( + Entity, assoc_signal_instance_entities.c.entity_id == Entity.id + ).join( + SignalInstance, + assoc_signal_instance_entities.c.signal_instance_id == SignalInstance.id, + ) + ) + .where( + and_( + Entity.value == entity_value, + Entity.entity_type_id == entity_type_id, + SignalInstance.created_at >= date_threshold if date_threshold else True, + ) + ) + ) + + signal_result = db_session.execute(query).fetchone() + + return SignalStats( + num_signal_instances_alerted=signal_result.count_without_snooze, + num_signal_instances_snoozed=signal_result.count_with_snooze, + num_snoozes_active=snooze_result.active_count, + num_snoozes_expired=snooze_result.expired_count, + ) diff --git a/src/dispatch/signal/views.py b/src/dispatch/signal/views.py index 398e9f5fa71e..276a497630ff 100644 --- a/src/dispatch/signal/views.py +++ b/src/dispatch/signal/views.py @@ -1,7 +1,16 @@ import logging from typing import Union -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response, status +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + HTTPException, + Query, + Request, + Response, + status, +) from pydantic.error_wrappers import ErrorWrapper, ValidationError from sqlalchemy.exc import IntegrityError @@ -30,6 +39,7 @@ SignalInstanceRead, SignalPagination, SignalRead, + SignalStats, SignalUpdate, ) from .service import ( @@ -40,6 +50,7 @@ delete_signal_filter, get, get_by_primary_or_external_id, + get_signal_stats, get_signal_engagement, get_signal_filter, update, @@ -272,6 +283,23 @@ def get_signals(common: CommonParameters): return search_filter_sort_paginate(model="Signal", **common) +@router.get("/stats", response_model=SignalStats) +def return_signal_stats( + db_session: DbSession, + entity_value: str = Query(..., description="The name of the entity"), + entity_type_id: int = Query(..., description="The ID of the entity type"), + num_days: int = Query(None, description="The number of days to look back"), +): + """Gets a signal statistics given a named entity and entity type id.""" + signal_data = get_signal_stats( + db_session=db_session, + entity_value=entity_value, + entity_type_id=entity_type_id, + num_days=num_days, + ) + return signal_data + + @router.get("/{signal_id}", response_model=SignalRead) def get_signal(db_session: DbSession, signal_id: Union[str, PrimaryKey]): """Gets a signal by its id.""" diff --git a/tests/signal/test_signal_data_service.py b/tests/signal/test_signal_data_service.py new file mode 100644 index 000000000000..914594dbd782 --- /dev/null +++ b/tests/signal/test_signal_data_service.py @@ -0,0 +1,163 @@ +from datetime import datetime, timedelta, timezone + + +def test_get_signal_stats_basic(session, entity, entity_type, signal, signal_instance): + """Test the basic functionality of get_signal_stats.""" + from dispatch.signal.service import get_signal_stats + + # Setup: Associate the entity with the signal instance + entity.entity_type = entity_type + signal_instance.entities.append(entity) + signal_instance.signal = signal + session.commit() + + # Execute: Call the service function + signal_data = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=None, + ) + + # Assert: Check the result + assert signal_data is not None + assert signal_data.num_signal_instances_alerted >= 0 + assert signal_data.num_signal_instances_snoozed >= 0 + assert signal_data.num_snoozes_active >= 0 + assert signal_data.num_snoozes_expired >= 0 + + +def test_get_signal_stats_with_num_days(session, entity, entity_type, signal, signal_instance): + """Test get_signal_stats with the num_days parameter.""" + from dispatch.signal.service import get_signal_stats + + # Setup: Associate the entity with the signal instance + entity.entity_type = entity_type + signal_instance.entities.append(entity) + signal_instance.signal = signal + # Set created_at to a specific time for testing + signal_instance.created_at = datetime.utcnow() - timedelta(days=3) + session.commit() + + # Execute: Call the service function with num_days=7 (should include our instance) + signal_data_7_days = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=7, + ) + + # Execute: Call the service function with num_days=1 (should exclude our instance) + signal_data_1_day = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=1, + ) + + # Assert: Check the results + assert signal_data_7_days is not None + assert ( + signal_data_7_days.num_signal_instances_alerted + + signal_data_7_days.num_signal_instances_snoozed + > 0 + ) + + assert signal_data_1_day is not None + assert ( + signal_data_1_day.num_signal_instances_alerted + + signal_data_1_day.num_signal_instances_snoozed + == 0 + ) + + +def test_get_signal_stats_with_snooze_filter( + session, entity, entity_type, signal, signal_instance, signal_filter +): + """Test get_signal_stats with a snooze filter applied.""" + from dispatch.signal.service import get_signal_stats + from dispatch.signal.models import SignalFilterAction + + # Setup: Associate the entity with the signal instance and add a snooze filter + entity.entity_type = entity_type + signal_instance.entities.append(entity) + signal_instance.signal = signal + signal_instance.filter_action = SignalFilterAction.snooze + + # Create a snooze filter that's active + signal_filter.action = SignalFilterAction.snooze + signal_filter.expiration = datetime.now(timezone.utc) + timedelta(days=1) + signal_filter.expression = [ + {"or": [{"model": "Entity", "field": "id", "op": "==", "value": entity.id}]} + ] + signal.filters.append(signal_filter) + + session.commit() + + # Execute: Call the service function + signal_data = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=None, + ) + + # Assert: Check the result + assert signal_data is not None + assert signal_data.num_signal_instances_snoozed > 0 + assert signal_data.num_snoozes_active > 0 + + +def test_get_signal_stats_with_expired_snooze_filter( + session, entity, entity_type, signal, signal_instance, signal_filter +): + """Test get_signal_stats with an expired snooze filter.""" + from dispatch.signal.service import get_signal_stats + from dispatch.signal.models import SignalFilterAction + + # Setup: Associate the entity with the signal instance and add an expired snooze filter + entity.entity_type = entity_type + signal_instance.entities.append(entity) + signal_instance.signal = signal + + # Create a snooze filter that's expired + signal_filter.action = SignalFilterAction.snooze + signal_filter.expiration = datetime.now(timezone.utc) - timedelta(days=1) + signal_filter.expression = [ + {"or": [{"model": "Entity", "field": "id", "op": "==", "value": entity.id}]} + ] + signal.filters.append(signal_filter) + + session.commit() + + # Execute: Call the service function + signal_data = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=None, + ) + + # Assert: Check the result + assert signal_data is not None + assert signal_data.num_snoozes_expired > 0 + + +def test_get_signal_stats_not_found(session, entity_type): + """Test get_signal_stats when no signals are found.""" + from dispatch.signal.service import get_signal_stats + + # Execute: Call the service function with a non-existent entity value + signal_data = get_signal_stats( + db_session=session, + entity_value="non-existent-entity", + entity_type_id=entity_type.id, + num_days=None, + ) + + # Assert: Check the result + assert signal_data is not None + assert signal_data.num_signal_instances_alerted == 0 + assert signal_data.num_signal_instances_snoozed == 0 + assert signal_data.num_snoozes_active == 0 + assert signal_data.num_snoozes_expired == 0