diff --git a/pyqs/decorator.py b/pyqs/decorator.py index 5f9e3d6..61c4a2b 100644 --- a/pyqs/decorator.py +++ b/pyqs/decorator.py @@ -54,6 +54,11 @@ def wrapper(*args, **kwargs): class task(object): + """ Decorator that enables sqs based task execution. If the function + accepts an optional `_context` argument, an instance of TaskContext is + passed to the task function. The context allows the function to do things + like change message visibility. """ + def __init__(self, queue=None, delay_seconds=None, custom_function_path=None): self.queue_name = queue diff --git a/pyqs/utils.py b/pyqs/utils.py index 11d0a57..7f466e5 100644 --- a/pyqs/utils.py +++ b/pyqs/utils.py @@ -3,6 +3,7 @@ import pickle import boto3 +from datetime import timedelta def decode_message(message): @@ -36,3 +37,22 @@ def get_aws_region_name(): region_name = 'us-east-1' return region_name + + +class TaskContext(object): + """ Tasks may optionally accept a _context variable. If they do, an + instance of this object is passed as the context. """ + + def __init__(self, conn, queue_url, message_id, receipt_handle, approx_receive_count): + self.conn = conn + self.queue_url = queue_url + self.message_id = message_id + self.receipt_handle = receipt_handle + self.approx_receive_count = approx_receive_count + + def change_message_visibility(self, timeout=timedelta(minutes=10)): + self.conn.change_message_visibility( + QueueUrl=self.queue_url, + ReceiptHandle=self.receipt_handle, + VisibilityTimeout=int(timeout.total_seconds()) + ) diff --git a/pyqs/worker.py b/pyqs/worker.py index 0f3c584..d6411aa 100644 --- a/pyqs/worker.py +++ b/pyqs/worker.py @@ -11,14 +11,20 @@ import time from multiprocessing import Event, Process, Queue + try: from queue import Empty, Full except ImportError: from Queue import Empty, Full +try: + from inspect import getfullargspec as get_args +except ImportError: + from inspect import getargspec as get_args + import boto3 -from pyqs.utils import get_aws_region_name, decode_message +from pyqs.utils import get_aws_region_name, decode_message, TaskContext MESSAGE_DOWNLOAD_BATCH_SIZE = 10 LONG_POLLING_INTERVAL = 20 @@ -92,6 +98,7 @@ def read_message(self): QueueUrl=self.queue_url, MaxNumberOfMessages=self.batchsize, WaitTimeSeconds=LONG_POLLING_INTERVAL, + AttributeNames=["ApproximateReceiveCount"] ).get('Messages', []) logger.debug( @@ -180,6 +187,9 @@ def process_message(self): full_task_path = message_body['task'] args = message_body['args'] kwargs = message_body['kwargs'] + message_id = message['MessageId'] + receipt_handle = message['ReceiptHandle'] + approx_receive_count = int(message.get('Attributes', {}).get("ApproximateReceiveCount", 1)) task_name = full_task_path.split(".")[-1] task_path = ".".join(full_task_path.split(".")[:-1]) @@ -188,6 +198,17 @@ def process_message(self): task = getattr(task_module, task_name) + # if the task accepts the optional _context argument, pass it the TaskContext + if '_context' in get_args(task).args: + kwargs = dict(kwargs) + kwargs['_context'] = TaskContext( + conn=self.conn, + queue_url=queue_url, + message_id=message_id, + receipt_handle=receipt_handle, + approx_receive_count=approx_receive_count + ) + current_time = time.time() if int(current_time - fetch_time) >= timeout: logger.warning( @@ -214,12 +235,20 @@ def process_message(self): traceback.format_exc(), ) ) + + # since the task failed, mark it is available again quickly (10 seconds) + self.conn.change_message_visibility( + QueueUrl=queue_url, + ReceiptHandle=receipt_handle, + VisibilityTimeout=10 + ) + return True else: end_time = time.clock() self.conn.delete_message( QueueUrl=queue_url, - ReceiptHandle=message['ReceiptHandle'] + ReceiptHandle=receipt_handle ) logger.info( "Processed task {} in {:.4f} seconds with args: {} "