From ce42e5eb90983c5f3efb2e186d637a282b1cb5d1 Mon Sep 17 00:00:00 2001 From: Marc Vilanova <39573146+mvilanova@users.noreply.github.com> Date: Fri, 14 Mar 2025 08:33:15 -0700 Subject: [PATCH 1/2] =?UTF-8?q?Revert=20"Revert=20"feat(signals):=20improv?= =?UTF-8?q?es=20aws=20sqs=20signal=20consumer=20plugin=20(#5817=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 2d4421805030434377cb388257ec21810547a59f. --- src/dispatch/plugins/dispatch_aws/plugin.py | 299 ++++++++++++++------ 1 file changed, 211 insertions(+), 88 deletions(-) diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 4f66bc358cb9..6cf1004f5f0b 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -9,6 +9,8 @@ import base64 import json import logging +import signal as os_signal +import time import zlib from typing import TypedDict @@ -18,11 +20,14 @@ from sqlalchemy.exc import IntegrityError, ResourceClosedError from sqlalchemy.orm import Session +from dispatch.database.core import get_session from dispatch.metrics import provider as metrics_provider from dispatch.plugins.bases import SignalConsumerPlugin from dispatch.plugins.dispatch_aws.config import AWSSQSConfiguration +from dispatch.project import service as project_service from dispatch.project.models import Project from dispatch.signal import service as signal_service +from dispatch.signal.exceptions import SignalNotIdentifiedException from dispatch.signal.models import SignalInstanceCreate from . import __version__ @@ -53,103 +58,221 @@ class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin): def __init__(self): self.configuration_schema = AWSSQSConfiguration + self._shutdown = False - def consume(self, db_session: Session, project: Project) -> None: - client = boto3.client("sqs", region_name=self.configuration.region) - queue_url: str = client.get_queue_url( - QueueName=self.configuration.queue_name, - QueueOwnerAWSAccountId=self.configuration.queue_owner, - )["QueueUrl"] - - while True: - response = client.receive_message( - QueueUrl=queue_url, - MaxNumberOfMessages=self.configuration.batch_size, - VisibilityTimeout=40, - WaitTimeSeconds=20, + def _setup_signal_handlers(self): + """Setup handlers for graceful shutdown.""" + + def handle_shutdown(signum, frame): + self._shutdown = True + log.info("Received shutdown signal, finishing current batch before exiting...") + + # Handle graceful shutdown signals + os_signal.signal(os_signal.SIGTERM, handle_shutdown) + os_signal.signal(os_signal.SIGINT, handle_shutdown) + + def _process_message( + self, db_session: Session, message: dict, project: Project + ) -> SqsEntries | None: + """Process a single SQS message and return entry for deletion if successful. + + Uses a nested transaction (SAVEPOINT) for message-level isolation within the batch transaction. + If the message processing fails, only its SAVEPOINT is rolled back, not affecting other messages. 
+ + Args: + db_session: The SQLAlchemy session for database operations + message: The SQS message to process + project: The project context for the signal + + Returns: + SqsEntries if message was processed successfully, None otherwise + """ + try: + message_body = json.loads(message["Body"]) + message_body_message = message_body.get("Message") + message_attributes = message_body.get("MessageAttributes", {}) + + if message_attributes.get("compressed", {}).get("Value") == "zlib": + # Message is compressed, decompress it + message_body_message = decompress_json(message_body_message) + + signal_data = json.loads(message_body_message) + except Exception as e: + log.exception(f"Unable to extract signal data from SQS message: {e}") + return None + + try: + signal_instance_in = SignalInstanceCreate( + project=project, raw=signal_data, **signal_data + ) + except ValidationError as e: + log.warning( + f"Received a signal instance that does not conform to the SignalInstanceCreate pydantic model. Skipping creation: {e}" ) - if not response.get("Messages") or len(response["Messages"]) == 0: - log.info("No messages received from SQS.") - continue + return None - entries: list[SqsEntries] = [] - for message in response["Messages"]: - try: - message_body = json.loads(message["Body"]) - message_body_message = message_body.get("Message") - message_attributes = message_body.get("MessageAttributes", {}) + # if the signal has an existing uuid we check if it already exists + if signal_instance_in.raw and signal_instance_in.raw.get("id"): + if signal_service.get_signal_instance( + db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] + ): + log.info( + f"Received a signal that already exists in the database. Skipping signal instance creation: {signal_instance_in.raw['id']}" + ) + return None - if message_attributes.get("compressed", {}).get("Value") == "zlib": - # Message is compressed, decompress it - message_body_message = decompress_json(message_body_message) + try: + # Get the signal definition first + signal_definition = None + external_id = None - signal_data = json.loads(message_body_message) - except Exception as e: - log.exception(f"Unable to extract signal data from SQS message: {e}") - continue + if not signal_instance_in.project: + log.warning( + f"No project provided for signal instance creation. Skipping signal instance creation: {signal_instance_in.raw['id']}" + ) + return None + + project_obj = project_service.get_by_name_or_default( + db_session=db_session, project_in=signal_instance_in.project + ) + + if not signal_instance_in.signal: + external_id = signal_instance_in.external_id + + # this assumes the external_ids are uuids + if not external_id: + msg = "A detection external id must be provided in order to get the signal definition." + raise SignalNotIdentifiedException(msg) + + # Try to get the signal definition by external ID or variant + signal_definition = signal_service.get_by_variant_or_external_id( + db_session=db_session, + project_id=project_obj.id, + external_id=external_id, + ) + + if not signal_definition: + # Fall back to default signal definition + signal_definition = signal_service.get_default( + db_session=db_session, + project_id=project_obj.id, + ) + + if not signal_definition: + log.warning( + f"No signal definition could be found for external_id {external_id} and no default exists." 
+ ) + return None + + signal_instance_in.signal = signal_definition + + with db_session.begin_nested(): + # Use create_instance directly to avoid the extra commit in create_signal_instance + signal_instance = signal_service.create_instance( + db_session=db_session, signal_instance_in=signal_instance_in + ) + metrics_provider.counter( + "aws-sqs-signal-consumer.signal.received", + tags={ + "signalName": signal_instance.signal.name, + "externalId": signal_instance.signal.external_id, + }, + ) + + log.debug( + f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}" + ) + + return {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + + except IntegrityError as e: + if isinstance(e.orig, UniqueViolation): + log.info( + f"Received a signal that already exists in the database. Skipping signal instance creation: {e}" + ) + else: + log.exception( + f"Encountered an integrity error when trying to create a signal instance: {e}" + ) + return None + except (ResourceClosedError, Exception) as e: + log.exception( + f"Encountered an error when trying to create a signal instance. Signal name/variant: {signal_instance_in.raw.get('name', '') if signal_instance_in.raw else ''} / {signal_instance_in.raw.get('variant', '') if signal_instance_in.raw else ''}. Error: {e}" + ) + return None + + def consume(self, db_session: Session, project: Project) -> None: + """Consume messages from SQS queue. + + Implements a long-running consumer with graceful shutdown handling. + Uses the application's session management with nested transactions for message-level isolation. + + Args: + db_session: Initial SQLAlchemy session (will be closed after setup) + project: The project context for signal processing + + Note: + - Uses dispatch's get_session() context manager for proper session lifecycle + - Individual messages use nested transactions (SAVEPOINTs) + - Handles SIGTERM/SIGINT for graceful shutdown + - Includes automatic session tracking and cleanup + """ + try: + self._setup_signal_handlers() + + client = boto3.client("sqs", region_name=self.configuration.region) + queue_url: str = client.get_queue_url( + QueueName=self.configuration.queue_name, + QueueOwnerAWSAccountId=self.configuration.queue_owner, + )["QueueUrl"] + + # Close the original session as we'll use get_session() + db_session.close() + + while not self._shutdown: try: - signal_instance_in = SignalInstanceCreate( - project=project, raw=signal_data, **signal_data - ) - except ValidationError as e: - log.warning( - f"Received a signal instance that does not conform to the SignalInstanceCreate pydantic model. Skipping creation: {e}" + response = client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=self.configuration.batch_size, + # Increased visibility timeout to handle larger batches and potential delays + # 5 minutes should cover most processing scenarios while preventing + # excessive message lock time + VisibilityTimeout=300, # 5 minutes + # Long polling wait time - reduces empty responses while allowing + # reasonable shutdown time + WaitTimeSeconds=20, ) - continue - - # if the signal has an existing uuid we check if it already exists - if signal_instance_in.raw and signal_instance_in.raw.get("id"): - if signal_service.get_signal_instance( - db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] - ): - log.info( - f"Received a signal that already exists in the database. 
Skipping signal instance creation: {signal_instance_in.raw['id']}" - ) + + if not response.get("Messages"): + log.info("No messages received from SQS.") continue - try: - with db_session.begin_nested(): - signal_instance = signal_service.create_signal_instance( - db_session=db_session, - signal_instance_in=signal_instance_in, - ) - except IntegrityError as e: - if isinstance(e.orig, UniqueViolation): - log.info( - f"Received a signal that already exists in the database. Skipping signal instance creation: {e}" - ) - else: - log.exception( - f"Encountered an integrity error when trying to create a signal instance: {e}" - ) - continue - except ResourceClosedError as e: - log.warning( - f"Encountered an error when trying to create a signal instance. The plugin will retry again as the message hasn't been deleted from the SQS queue. Signal name/variant: {signal_instance_in.raw['name'] if signal_instance_in.raw and signal_instance_in.raw['name'] else signal_instance_in.raw['variant']}. Error: {e}" - ) - db_session.rollback() - continue + entries: list[SqsEntries] = [] + with get_session() as batch_session: + # Batch transaction - commits all successful messages or none + with batch_session.begin(): + for message in response["Messages"]: + if self._shutdown: + log.info("Shutdown requested, stopping message processing...") + break + entry = self._process_message(batch_session, message, project) + if entry: + entries.append(entry) + + # Only delete messages that were successfully processed + if entries: + client.delete_message_batch(QueueUrl=queue_url, Entries=entries) + except Exception as e: - log.exception( - f"Encountered an error when trying to create a signal instance. Signal name/variant: {signal_instance_in.raw['name'] if signal_instance_in.raw and signal_instance_in.raw['name'] else signal_instance_in.raw['variant']}. Error: {e}" - ) - db_session.rollback() - continue - else: - metrics_provider.counter( - "aws-sqs-signal-consumer.signal.received", - tags={ - "signalName": signal_instance.signal.name, - "externalId": signal_instance.signal.external_id, - }, - ) - log.debug( - f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}" - ) - entries.append( - {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} - ) + log.exception("Error processing message batch: %s", e) + if not self._shutdown: + time.sleep(1) # Prevent tight error loops - if entries: - client.delete_message_batch(QueueUrl=queue_url, Entries=entries) + except Exception as e: + log.exception("Fatal error in consumer: %s", e) + raise + finally: + log.info("Consumer shutting down...") + if db_session: + db_session.close() From f5e2e6d749d15835736c44264d02869b28bdefde Mon Sep 17 00:00:00 2001 From: Marc Vilanova Date: Fri, 14 Mar 2025 11:54:14 -0700 Subject: [PATCH 2/2] improvements --- src/dispatch/plugins/dispatch_aws/plugin.py | 29 +++++++++++++-------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 6cf1004f5f0b..302d7210922d 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -107,17 +107,17 @@ def _process_message( ) except ValidationError as e: log.warning( - f"Received a signal instance that does not conform to the SignalInstanceCreate pydantic model. Skipping creation: {e}" + f"Received a signal that does not conform to the SignalInstanceCreate pydantic model. 
Skipping creation and deleting message: {e}" ) - return None + return "DELETE_MESSAGE" # if the signal has an existing uuid we check if it already exists if signal_instance_in.raw and signal_instance_in.raw.get("id"): if signal_service.get_signal_instance( - db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] + db_session=db_session, signal_instance_id=signal_instance_in.raw.get("id") ): log.info( - f"Received a signal that already exists in the database. Skipping signal instance creation: {signal_instance_in.raw['id']}" + f"Received a signal that already exists in the database. Skipping signal instance creation: {signal_instance_in.raw.get('id')}" ) return None @@ -128,7 +128,7 @@ def _process_message( if not signal_instance_in.project: log.warning( - f"No project provided for signal instance creation. Skipping signal instance creation: {signal_instance_in.raw['id']}" + f"No project provided for signal instance creation. Skipping signal instance creation: {signal_instance_in.raw.get('id')}" ) return None @@ -245,22 +245,29 @@ def consume(self, db_session: Session, project: Project) -> None: ) if not response.get("Messages"): - log.info("No messages received from SQS.") + log.debug("No messages received from SQS.") continue entries: list[SqsEntries] = [] with get_session() as batch_session: - # Batch transaction - commits all successful messages or none - with batch_session.begin(): + # Use nested transaction (SAVEPOINT) for the batch + with batch_session.begin_nested(): for message in response["Messages"]: if self._shutdown: log.info("Shutdown requested, stopping message processing...") break entry = self._process_message(batch_session, message, project) - if entry: + if isinstance(entry, dict): entries.append(entry) - - # Only delete messages that were successfully processed + elif entry == "DELETE_MESSAGE": + # Force deletion for hopelessly invalid messages + client.delete_message( + QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"] + ) + # Commit if everything else is good + batch_session.commit() + + # Now delete successfully processed messages in bulk if entries: client.delete_message_batch(QueueUrl=queue_url, Entries=entries)
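
Note: the consume() loop above combines SQS long polling, a per-message SAVEPOINT for isolation, and batch deletion of only the messages that were processed successfully. Below is a minimal, self-contained sketch of that pattern, assuming a boto3 SQS client and a SQLAlchemy sessionmaker; process_one, session_factory, and the hard-coded region are illustrative stand-ins, not the plugin's actual API.

import signal
import time

import boto3
from sqlalchemy.orm import sessionmaker

shutdown = False

def _handle_shutdown(signum, frame):
    # Flip a flag; the loop finishes the current batch before exiting.
    global shutdown
    shutdown = True

signal.signal(signal.SIGTERM, _handle_shutdown)
signal.signal(signal.SIGINT, _handle_shutdown)

def consume_queue(queue_url: str, session_factory: sessionmaker, process_one) -> None:
    """Poll SQS, isolate each message in a SAVEPOINT, batch-delete successes."""
    client = boto3.client("sqs", region_name="us-east-1")  # region is an assumption
    while not shutdown:
        try:
            response = client.receive_message(
                QueueUrl=queue_url,
                MaxNumberOfMessages=10,
                VisibilityTimeout=300,  # keep messages in flight long enough to finish the batch
                WaitTimeSeconds=20,     # long polling reduces empty responses
            )
            messages = response.get("Messages", [])
            if not messages:
                continue

            entries = []
            with session_factory() as session:
                for message in messages:
                    if shutdown:
                        break
                    try:
                        with session.begin_nested():  # SAVEPOINT per message
                            process_one(session, message)
                        entries.append(
                            {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
                        )
                    except Exception:
                        # Only this message's SAVEPOINT rolls back; the message stays on
                        # the queue and becomes visible again after the timeout.
                        continue
                session.commit()

            if entries:
                client.delete_message_batch(QueueUrl=queue_url, Entries=entries)
        except Exception:
            if not shutdown:
                time.sleep(1)  # avoid a tight error loop on repeated failures

As in the plugin, a message whose SAVEPOINT rolls back is simply not deleted, so it reappears once its visibility timeout expires, while SIGTERM/SIGINT only set a flag so the current batch can finish before the loop exits.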