Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(signals): improves aws sqs signal consumer plugin #5829

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 218 additions & 88 deletions src/dispatch/plugins/dispatch_aws/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import base64
import json
import logging
import signal as os_signal
import time
import zlib
from typing import TypedDict

Expand All @@ -18,11 +20,14 @@
from sqlalchemy.exc import IntegrityError, ResourceClosedError
from sqlalchemy.orm import Session

from dispatch.database.core import get_session
from dispatch.metrics import provider as metrics_provider
from dispatch.plugins.bases import SignalConsumerPlugin
from dispatch.plugins.dispatch_aws.config import AWSSQSConfiguration
from dispatch.project import service as project_service
from dispatch.project.models import Project
from dispatch.signal import service as signal_service
from dispatch.signal.exceptions import SignalNotIdentifiedException
from dispatch.signal.models import SignalInstanceCreate

from . import __version__
Expand Down Expand Up @@ -53,103 +58,228 @@ class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin):

def __init__(self):
    """Set up the plugin's initial state."""
    # Cooperative stop flag: the shutdown signal handler sets it and the
    # consume loop checks it between polls and between messages.
    self._shutdown = False
    # Schema class describing this plugin's configuration.
    self.configuration_schema = AWSSQSConfiguration

def consume(self, db_session: Session, project: Project) -> None:
client = boto3.client("sqs", region_name=self.configuration.region)
queue_url: str = client.get_queue_url(
QueueName=self.configuration.queue_name,
QueueOwnerAWSAccountId=self.configuration.queue_owner,
)["QueueUrl"]

while True:
response = client.receive_message(
QueueUrl=queue_url,
MaxNumberOfMessages=self.configuration.batch_size,
VisibilityTimeout=40,
WaitTimeSeconds=20,
def _setup_signal_handlers(self) -> None:
    """Install SIGTERM/SIGINT handlers that request a graceful shutdown.

    The handler only flips ``self._shutdown``; the consume loop is expected
    to notice the flag and stop after the work currently in flight.

    ``signal.signal()`` may only be called from the main thread of the main
    interpreter and raises ``ValueError`` otherwise. In that case the
    consumer keeps running without handlers instead of failing at startup.
    """

    def handle_shutdown(signum, frame):
        self._shutdown = True
        log.info("Received shutdown signal, finishing current batch before exiting...")

    try:
        # Handle graceful shutdown signals
        os_signal.signal(os_signal.SIGTERM, handle_shutdown)
        os_signal.signal(os_signal.SIGINT, handle_shutdown)
    except ValueError as e:
        # Raised when not running in the main thread; degrade gracefully.
        log.warning(
            "Unable to install shutdown signal handlers; "
            "graceful shutdown on SIGTERM/SIGINT is disabled: %s",
            e,
        )

def _process_message(
    self, db_session: Session, message: dict, project: Project
) -> SqsEntries | str | None:
    """Process a single SQS message and report how the caller should treat it.

    The signal instance is created inside a nested transaction (SAVEPOINT),
    so a failure while creating it rolls back only that message's work and
    leaves the rest of the batch transaction intact.

    Args:
        db_session: The SQLAlchemy session for database operations.
        message: The SQS message to process; ``Body``, ``MessageId`` and
            ``ReceiptHandle`` are read from it.
        project: The project context for the signal.

    Returns:
        - a ``{"Id", "ReceiptHandle"}`` entry when the signal instance was
          created and the message can be deleted from the queue;
        - the string ``"DELETE_MESSAGE"`` when the payload can never be
          turned into a ``SignalInstanceCreate`` and should be discarded;
        - ``None`` when the message was skipped or failed and must be left
          on the queue.
    """
    try:
        message_body = json.loads(message["Body"])
        message_body_message = message_body.get("Message")
        message_attributes = message_body.get("MessageAttributes", {})

        if message_attributes.get("compressed", {}).get("Value") == "zlib":
            # Message is compressed, decompress it
            message_body_message = decompress_json(message_body_message)

        signal_data = json.loads(message_body_message)
    except Exception as e:
        log.exception(f"Unable to extract signal data from SQS message: {e}")
        return None

    try:
        signal_instance_in = SignalInstanceCreate(
            project=project, raw=signal_data, **signal_data
        )
    except (ValidationError, TypeError) as e:
        # TypeError covers payloads that are valid JSON but not an object
        # ("**" needs a mapping) and payloads whose keys collide with the
        # explicit ``project``/``raw`` arguments. Such a message can never
        # succeed, so it is discarded like a validation failure instead of
        # escaping and aborting the whole batch.
        log.warning(
            f"Received a signal that does not conform to the SignalInstanceCreate pydantic model. Skipping creation and deleting message: {e}"
        )
        return "DELETE_MESSAGE"

    raw = signal_instance_in.raw
    signal_id = raw.get("id") if raw else None

    # if the signal has an existing uuid we check if it already exists
    if signal_id and signal_service.get_signal_instance(
        db_session=db_session, signal_instance_id=signal_id
    ):
        log.info(
            f"Received a signal that already exists in the database. Skipping signal instance creation: {signal_id}"
        )
        # NOTE(review): returning None leaves the duplicate on the queue, so
        # it will be received again - confirm this is intended.
        return None

    try:
        if not signal_instance_in.project:
            log.warning(
                f"No project provided for signal instance creation. Skipping signal instance creation: {signal_id}"
            )
            return None

        project_obj = project_service.get_by_name_or_default(
            db_session=db_session, project_in=signal_instance_in.project
        )

        if not signal_instance_in.signal:
            external_id = signal_instance_in.external_id

            # this assumes the external_ids are uuids
            if not external_id:
                msg = "A detection external id must be provided in order to get the signal definition."
                raise SignalNotIdentifiedException(msg)

            # Try to get the signal definition by external ID or variant
            signal_definition = signal_service.get_by_variant_or_external_id(
                db_session=db_session,
                project_id=project_obj.id,
                external_id=external_id,
            )

            if not signal_definition:
                # Fall back to default signal definition
                signal_definition = signal_service.get_default(
                    db_session=db_session,
                    project_id=project_obj.id,
                )

            if not signal_definition:
                log.warning(
                    f"No signal definition could be found for external_id {external_id} and no default exists."
                )
                return None

            signal_instance_in.signal = signal_definition

        with db_session.begin_nested():
            # Use create_instance directly to avoid the extra commit in create_signal_instance
            signal_instance = signal_service.create_instance(
                db_session=db_session, signal_instance_in=signal_instance_in
            )

        metrics_provider.counter(
            "aws-sqs-signal-consumer.signal.received",
            tags={
                "signalName": signal_instance.signal.name,
                "externalId": signal_instance.signal.external_id,
            },
        )

        log.debug(
            f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}"
        )

        return {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}

    except IntegrityError as e:
        if isinstance(e.orig, UniqueViolation):
            log.info(
                f"Received a signal that already exists in the database. Skipping signal instance creation: {e}"
            )
        else:
            log.exception(
                f"Encountered an integrity error when trying to create a signal instance: {e}"
            )
        return None
    except Exception as e:
        # Catch-all (includes ResourceClosedError and
        # SignalNotIdentifiedException): the message stays on the queue.
        log.exception(
            f"Encountered an error when trying to create a signal instance. Signal name/variant: {raw.get('name', '') if raw else ''} / {raw.get('variant', '') if raw else ''}. Error: {e}"
        )
        return None

def consume(self, db_session: Session, project: Project) -> None:
    """Consume messages from the configured SQS queue until shutdown.

    Long-polls the queue and processes each received batch in a session
    obtained from ``get_session()``. Messages are deleted from the queue in
    bulk only after the batch has been committed.

    Args:
        db_session: Initial SQLAlchemy session. It is closed once the queue
            URL has been resolved; batches use their own sessions.
        project: The project context for signal processing.

    Raises:
        Exception: Any error raised during setup (signal handlers, SQS
            client, queue URL lookup) is logged and re-raised. Errors while
            processing a batch are logged and the loop continues.

    Note:
        - Individual messages use nested transactions (SAVEPOINTs), see
          ``_process_message``.
        - SIGTERM/SIGINT set ``self._shutdown``; the loop stops after the
          message currently being processed.
    """
    try:
        self._setup_signal_handlers()

        client = boto3.client("sqs", region_name=self.configuration.region)
        queue_url: str = client.get_queue_url(
            QueueName=self.configuration.queue_name,
            QueueOwnerAWSAccountId=self.configuration.queue_owner,
        )["QueueUrl"]

        # Close the original session as we'll use get_session()
        db_session.close()

        while not self._shutdown:
            try:
                response = client.receive_message(
                    QueueUrl=queue_url,
                    MaxNumberOfMessages=self.configuration.batch_size,
                    # Increased visibility timeout to handle larger batches and potential delays
                    # 5 minutes should cover most processing scenarios while preventing
                    # excessive message lock time
                    VisibilityTimeout=300,  # 5 minutes
                    # Long polling wait time - reduces empty responses while allowing
                    # reasonable shutdown time
                    WaitTimeSeconds=20,
                )

                if not response.get("Messages"):
                    log.debug("No messages received from SQS.")
                    continue

                entries: list[SqsEntries] = []
                with get_session() as batch_session:
                    # Use nested transaction (SAVEPOINT) for the batch
                    with batch_session.begin_nested():
                        for message in response["Messages"]:
                            if self._shutdown:
                                log.info("Shutdown requested, stopping message processing...")
                                break
                            entry = self._process_message(batch_session, message, project)
                            if isinstance(entry, dict):
                                entries.append(entry)
                            elif entry == "DELETE_MESSAGE":
                                # Force deletion for hopelessly invalid messages
                                client.delete_message(
                                    QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
                                )
                    # Commit only after the SAVEPOINT block has exited, so
                    # the commit always targets the outer transaction.
                    batch_session.commit()

                # Now delete successfully processed messages in bulk
                if entries:
                    result = client.delete_message_batch(QueueUrl=queue_url, Entries=entries)
                    failed = result.get("Failed")
                    if failed:
                        # Per-entry failures are reported in the response
                        # rather than raised; surface them in the logs.
                        log.warning("Failed to delete messages from SQS: %s", failed)

            except Exception as e:
                log.exception("Error processing message batch: %s", e)
                if not self._shutdown:
                    time.sleep(1)  # Prevent tight error loops

    except Exception as e:
        log.exception("Fatal error in consumer: %s", e)
        raise
    finally:
        log.info("Consumer shutting down...")
        # Also covers the case where setup failed before the session was
        # closed above.
        if db_session:
            db_session.close()
Loading