diff --git a/README.md b/README.md index 19f04d74..99063f5e 100644 --- a/README.md +++ b/README.md @@ -50,8 +50,11 @@ app.start(); - The queue is polled continuously for messages using [long polling](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html). - Throwing an error (or returning a rejected promise) from the handler function will cause the message to be left on the queue. An [SQS redrive policy](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/SQSDeadLetterQueue.html) can be used to move messages that cannot be processed to a dead letter queue. -- By default messages are processed one at a time – a new message won't be received until the first one has been processed. To process messages in parallel, use the `batchSize` option [detailed here](https://bbc.github.io/sqs-consumer/interfaces/ConsumerOptions.html#batchSize). - - It's also important to await any processing that you are doing to ensure that messages are processed one at a time. +- Messages can be processed in two ways: + 1. Individual message processing (default): Messages are processed one at a time. Use the `concurrency` option to specify how many messages can be processed simultaneously. + 2. Batch processing: When a `handleMessageBatch` function is provided, the messages received in each poll (see the `batchSize` option [detailed here](https://bbc.github.io/sqs-consumer/interfaces/ConsumerOptions.html#batchSize)) are passed to it together. When using batch processing, the entire batch is processed together and concurrency controls don't apply. + - Note: When `concurrency` is not set, it defaults to `1` and is raised to match `batchSize` if that is larger. When both `batchSize` and `concurrency` are set, `concurrency` must be greater than or equal to `batchSize`, otherwise an error is thrown. +- It's important to await any processing that you are doing to ensure messages are processed correctly. - By default, messages that are sent to the `handleMessage` and `handleMessageBatch` functions will be considered as processed if they return without an error. 
- To acknowledge individual messages, please return the message that you want to acknowledge if you are using `handleMessage` or the messages for `handleMessageBatch`. - To note, returning an empty object or an empty array will be considered an acknowledgement of no message(s) and will result in no messages being deleted. If you would like to change this behaviour, please use the `alwaysAcknowledge` option [detailed here](https://bbc.github.io/sqs-consumer/interfaces/ConsumerOptions.html). diff --git a/src/consumer.ts b/src/consumer.ts index 12c27ed4..08be057d 100644 --- a/src/consumer.ts +++ b/src/consumer.ts @@ -63,11 +63,14 @@ export class Consumer extends TypedEventEmitter { private authenticationErrorTimeout: number; private pollingWaitTimeMs: number; private pollingCompleteWaitTimeMs: number; + private concurrencyWaitTimeMs: number; private heartbeatInterval: number; private isPolling = false; private stopRequestedAtTimestamp: number; public abortController: AbortController; private extendedAWSErrors: boolean; + private concurrency: number; + private inFlightMessages = 0; constructor(options: ConsumerOptions) { super(options.queueUrl); @@ -92,9 +95,11 @@ export class Consumer extends TypedEventEmitter { options.authenticationErrorTimeout ?? 10000; this.pollingWaitTimeMs = options.pollingWaitTimeMs ?? 0; this.pollingCompleteWaitTimeMs = options.pollingCompleteWaitTimeMs ?? 0; + this.concurrencyWaitTimeMs = options.concurrencyWaitTimeMs ?? 50; this.shouldDeleteMessages = options.shouldDeleteMessages ?? true; this.alwaysAcknowledge = options.alwaysAcknowledge ?? false; this.extendedAWSErrors = options.extendedAWSErrors ?? false; + this.concurrency = Math.max(options.concurrency ?? 
1, this.batchSize); this.sqs = options.sqs || new SQSClient({ @@ -328,21 +333,47 @@ export class Consumer extends TypedEventEmitter { private async handleSqsResponse( response: ReceiveMessageCommandOutput, ): Promise { - if (hasMessages(response)) { - if (this.handleMessageBatch) { - await this.processMessageBatch(response.Messages); - } else { - await Promise.all( - response.Messages.map((message: Message) => - this.processMessage(message), - ), - ); - } - - this.emit("response_processed"); - } else if (response) { + if (!hasMessages(response)) { this.emit("empty"); + return; } + + const messages = response.Messages; + + if (this.handleMessageBatch) { + await this.processMessageBatch(messages); + } else { + let waitingMessages = 0; + await Promise.all( + messages.map(async (message) => { + while ( + this.batchSize === 1 && + this.inFlightMessages >= this.concurrency + ) { + if (waitingMessages === 0) { + this.emit("concurrency_limit_reached", { + limit: this.concurrency, + waiting: messages.length - this.inFlightMessages, + }); + } + waitingMessages++; + await new Promise((resolve) => + setTimeout(resolve, this.concurrencyWaitTimeMs), + ); + if (this.stopped) return; + } + waitingMessages = Math.max(0, waitingMessages - 1); + this.inFlightMessages++; + try { + await this.processMessage(message); + } finally { + this.inFlightMessages--; + } + }), + ); + } + + this.emit("response_processed"); } /** diff --git a/src/types.ts b/src/types.ts index 7fa0fe7a..a561b862 100644 --- a/src/types.ts +++ b/src/types.ts @@ -157,6 +157,17 @@ export interface ConsumerOptions { * example to add middlewares. */ postReceiveMessageCallback?(): Promise; + /** + * The maximum number of messages that can be processed concurrently. + * If not provided, messages will be processed sequentially. + * @defaultvalue `1` + */ + concurrency?: number; + /** + * The duration (in milliseconds) to wait between concurrency checks when the concurrency limit is reached. 
+ * @defaultvalue `50` + */ + concurrencyWaitTimeMs?: number; /** * Set this to `true` if you want to receive additional information about the error * that occurred from AWS, such as the response and metadata. @@ -171,7 +182,9 @@ export type UpdatableOptions = | "visibilityTimeout" | "batchSize" | "waitTimeSeconds" - | "pollingWaitTimeMs"; + | "pollingWaitTimeMs" + | "concurrency" + | "concurrencyWaitTimeMs"; /** * The options for the stop method. @@ -257,6 +270,11 @@ export interface Events { * Fired when the Consumer has waited for polling to complete and is stopping due to a timeout. */ waiting_for_polling_to_complete_timeout_exceeded: []; + /** + * Fired when concurrency limit is hit and messages are waiting to be processed. + * Includes the current concurrency limit and number of messages waiting. + */ + concurrency_limit_reached: [{ limit: number; waiting: number }]; } /** diff --git a/src/validation.ts b/src/validation.ts index 6de283d6..bde588c4 100644 --- a/src/validation.ts +++ b/src/validation.ts @@ -50,6 +50,23 @@ function validateOption( throw new Error("pollingWaitTimeMs must be greater than 0."); } break; + case "concurrency": + if (!Number.isInteger(value) || value < 1) { + throw new Error("concurrency must be a positive integer."); + } + if (allOptions.batchSize && value < allOptions.batchSize) { + throw new Error( + "concurrency must be greater than or equal to batchSize.", + ); + } + break; + case "concurrencyWaitTimeMs": + if (!Number.isInteger(value) || value < 0) { + throw new Error( + "concurrencyWaitTimeMs must be a non-negative integer.", + ); + } + break; default: if (strict) { throw new Error(`The update ${option} cannot be updated`); @@ -78,6 +95,16 @@ function assertOptions(options: ConsumerOptions): void { if (options.heartbeatInterval) { validateOption("heartbeatInterval", options.heartbeatInterval, options); } + if (options.concurrency !== undefined) { + validateOption("concurrency", options.concurrency, options); + } + if 
(options.concurrencyWaitTimeMs !== undefined) { + validateOption( + "concurrencyWaitTimeMs", + options.concurrencyWaitTimeMs, + options, + ); + } } /** diff --git a/test/tests/consumer.test.ts b/test/tests/consumer.test.ts index 88cd21b4..a2d47268 100644 --- a/test/tests/consumer.test.ts +++ b/test/tests/consumer.test.ts @@ -20,7 +20,8 @@ const sandbox = sinon.createSandbox(); const AUTHENTICATION_ERROR_TIMEOUT = 20; const POLLING_TIMEOUT = 100; -const QUEUE_URL = "some-queue-url"; +const QUEUE_URL = + "https://sqs.some-region.amazonaws.com/123456789012/queue-name"; const REGION = "some-region"; const mockReceiveMessage = sinon.match.instanceOf(ReceiveMessageCommand); @@ -154,6 +155,102 @@ describe("Consumer", () => { }); }, "heartbeatInterval must be less than visibilityTimeout."); }); + + it("requires concurrency to be a positive integer", () => { + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrency: 0, + }); + }, "concurrency must be a positive integer."); + + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrency: -1, + }); + }, "concurrency must be a positive integer."); + + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrency: 1.5, + }); + }, "concurrency must be a positive integer."); + }); + + it("requires concurrency to be greater than or equal to batchSize", () => { + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + batchSize: 5, + concurrency: 3, + }); + }, "concurrency must be greater than or equal to batchSize."); + }); + + it("allows concurrency to be equal to batchSize", () => { + assert.doesNotThrow(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + batchSize: 5, + concurrency: 5, + }); + }); + }); + + it("allows concurrency to be greater than batchSize", () => { + assert.doesNotThrow(() => { + 
new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + batchSize: 5, + concurrency: 10, + }); + }); + }); + + it("requires concurrencyWaitTimeMs to be a non-negative integer", () => { + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrencyWaitTimeMs: -1, + }); + }, "concurrencyWaitTimeMs must be a non-negative integer."); + + assert.throws(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrencyWaitTimeMs: 1.5, + }); + }, "concurrencyWaitTimeMs must be a non-negative integer."); + }); + + it("allows concurrencyWaitTimeMs to be zero", () => { + assert.doesNotThrow(() => { + new Consumer({ + region: REGION, + queueUrl: QUEUE_URL, + handleMessage, + concurrencyWaitTimeMs: 0, + }); + }); + }); }); describe(".create", () => { @@ -1860,6 +1957,332 @@ describe("Consumer", () => { assert.equal(err.queueUrl, QUEUE_URL); assert.deepEqual(err.messageIds, ["1", "2"]); }); + + describe("concurrency", () => { + it("processes messages respecting the concurrency limit", async () => { + const message1 = { MessageId: "1", ReceiptHandle: "1", Body: "1" }; + const message2 = { MessageId: "2", ReceiptHandle: "2", Body: "2" }; + const message3 = { MessageId: "3", ReceiptHandle: "3", Body: "3" }; + + handleMessage.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send + .withArgs(mockReceiveMessage) + .resolves({ Messages: [message1, message2, message3] }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 2, + batchSize: 1, + }); + + consumer.start(); + await clock.tickAsync(0); + + // First two messages should be in flight + assert.equal(handleMessage.callCount, 2); + assert.deepEqual(handleMessage.firstCall.args[0], message1); + assert.deepEqual(handleMessage.secondCall.args[0], message2); + + // Complete first two messages + await clock.tickAsync(2000); + + // Third message 
should now be processing + assert.equal(handleMessage.callCount, 3); + assert.deepEqual(handleMessage.thirdCall.args[0], message3); + + consumer.stop(); + }); + + it("bypasses concurrency limits in batch mode", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + { MessageId: "3", ReceiptHandle: "3", Body: "3" }, + ]; + + handleMessageBatch.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessageBatch, + sqs, + batchSize: 3, + }); + + consumer.start(); + await clock.tickAsync(0); + + // All messages should be processed in one batch + assert.equal(handleMessageBatch.callCount, 1); + assert.deepEqual(handleMessageBatch.firstCall.args[0], messages); + + await clock.tickAsync(2000); + consumer.stop(); + }); + + it("handles dynamic concurrency updates", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + { MessageId: "3", ReceiptHandle: "3", Body: "3" }, + ]; + + handleMessage.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 1, + batchSize: 1, + pollingWaitTimeMs: 0, + }); + + consumer.start(); + await clock.tickAsync(0); + + // First message starts processing + assert.equal(handleMessage.callCount, 1); + assert.deepEqual(handleMessage.firstCall.args[0], messages[0]); + + // Update concurrency to 3 + consumer.updateOption("concurrency", 3); + + // Need to wait for the next polling cycle + await clock.tickAsync(POLLING_TIMEOUT); + + // Second and third messages should now be processing + 
assert.equal(handleMessage.callCount, 3); + assert.deepEqual(handleMessage.secondCall.args[0], messages[1]); + assert.deepEqual(handleMessage.thirdCall.args[0], messages[2]); + + await clock.tickAsync(2000); + consumer.stop(); + }); + + it("completes in-flight messages when stopping", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + { MessageId: "3", ReceiptHandle: "3", Body: "3" }, + ]; + + let resolveMessage1: (() => void) | undefined; + + handleMessage.callsFake((message) => { + return new Promise((resolve) => { + if (message.MessageId === "1") { + resolveMessage1 = resolve; + } else { + resolve(); + } + }); + }); + + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 2, + batchSize: 1, + }); + + consumer.start(); + await clock.tickAsync(0); + + // First two messages should be in flight + assert.equal(handleMessage.callCount, 2); + assert.deepEqual(handleMessage.firstCall.args[0], messages[0]); + assert.deepEqual(handleMessage.secondCall.args[0], messages[1]); + + consumer.stop(); + + // Complete first message + resolveMessage1?.(); + await clock.tickAsync(0); + + // No more messages should be processed after stopping + assert.equal(handleMessage.callCount, 2); + }); + + it("reports concurrency slot usage", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + { MessageId: "3", ReceiptHandle: "3", Body: "3" }, + ]; + + handleMessage.callsFake(() => Promise.resolve()); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + const concurrencyUpdateHandler = sandbox.stub(); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 2, + batchSize: 1, + }); + + 
consumer.on("concurrency_limit_reached", concurrencyUpdateHandler); + consumer.start(); + await clock.tickAsync(0); + + // First two messages should be in flight, using all concurrency slots + assert.equal(handleMessage.callCount, 2); + assert.equal(concurrencyUpdateHandler.callCount, 1); + assert.deepEqual(concurrencyUpdateHandler.firstCall.args[0], { + limit: 2, + waiting: 1, + }); + + consumer.stop(); + }); + + it("reports concurrency changes when updating limit", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + ]; + + handleMessage.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + const concurrencyUpdateHandler = sandbox.stub(); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 1, + batchSize: 1, + }); + + consumer.on("concurrency_limit_reached", concurrencyUpdateHandler); + consumer.start(); + await clock.tickAsync(0); + + // First message uses the only slot + assert.equal(handleMessage.callCount, 1); + assert.equal(concurrencyUpdateHandler.callCount, 1); + assert.deepEqual(concurrencyUpdateHandler.firstCall.args[0], { + limit: 1, + waiting: 1, + }); + + // Update concurrency limit + consumer.updateOption("concurrency", 2); + await clock.tickAsync(POLLING_TIMEOUT); + + // Second message should now be processing + assert.equal(handleMessage.callCount, 2); + + consumer.stop(); + }); + + it("uses the configured concurrencyWaitTimeMs when waiting for slots", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + ]; + + handleMessage.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + const customWaitTime = 100; + consumer 
= new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 1, + batchSize: 1, + concurrencyWaitTimeMs: customWaitTime, + }); + + consumer.start(); + await clock.tickAsync(0); + + // First message starts processing + assert.equal(handleMessage.callCount, 1); + assert.deepEqual(handleMessage.firstCall.args[0], messages[0]); + + // Wait less than the configured wait time - second message shouldn't be processed + await clock.tickAsync(customWaitTime - 1); + assert.equal(handleMessage.callCount, 1); + + // Wait for the full wait time - second message should now be checked + await clock.tickAsync(1); + assert.equal(handleMessage.callCount, 1); + + consumer.stop(); + }); + + it("uses updated concurrencyWaitTimeMs value after runtime update", async () => { + const messages = [ + { MessageId: "1", ReceiptHandle: "1", Body: "1" }, + { MessageId: "2", ReceiptHandle: "2", Body: "2" }, + ]; + + handleMessage.callsFake( + () => new Promise((resolve) => setTimeout(resolve, 2000)), + ); + sqs.send.withArgs(mockReceiveMessage).resolves({ Messages: messages }); + + consumer = new Consumer({ + queueUrl: QUEUE_URL, + region: REGION, + handleMessage, + sqs, + concurrency: 1, + batchSize: 1, + concurrencyWaitTimeMs: 50, + }); + + consumer.start(); + await clock.tickAsync(0); + + // First message starts processing + assert.equal(handleMessage.callCount, 1); + + // Update to a longer wait time + const newWaitTime = 100; + consumer.updateOption("concurrencyWaitTimeMs", newWaitTime); + + // Wait less than the new wait time - second message shouldn't be processed + await clock.tickAsync(newWaitTime - 1); + assert.equal(handleMessage.callCount, 1); + + // Wait for the full new wait time - second message should now be checked + await clock.tickAsync(1); + assert.equal(handleMessage.callCount, 1); + + consumer.stop(); + }); + }); }); describe("event listeners", () => { @@ -2264,6 +2687,47 @@ describe("Consumer", () => { 
sandbox.assert.notCalled(optionUpdatedListener); }); + it("updates the concurrencyWaitTimeMs option and emits an event", () => { + const optionUpdatedListener = sandbox.stub(); + consumer.on("option_updated", optionUpdatedListener); + + consumer.updateOption("concurrencyWaitTimeMs", 100); + + assert.equal(consumer.concurrencyWaitTimeMs, 100); + + sandbox.assert.calledWithMatch( + optionUpdatedListener, + "concurrencyWaitTimeMs", + 100, + ); + }); + + it("does not update the concurrencyWaitTimeMs if the value is negative", () => { + const optionUpdatedListener = sandbox.stub(); + consumer.on("option_updated", optionUpdatedListener); + + assert.throws(() => { + consumer.updateOption("concurrencyWaitTimeMs", -1); + }, "concurrencyWaitTimeMs must be a non-negative integer."); + + assert.equal(consumer.concurrencyWaitTimeMs, 50); + + sandbox.assert.notCalled(optionUpdatedListener); + }); + + it("does not update the concurrencyWaitTimeMs if the value is not an integer", () => { + const optionUpdatedListener = sandbox.stub(); + consumer.on("option_updated", optionUpdatedListener); + + assert.throws(() => { + consumer.updateOption("concurrencyWaitTimeMs", 1.5); + }, "concurrencyWaitTimeMs must be a non-negative integer."); + + assert.equal(consumer.concurrencyWaitTimeMs, 50); + + sandbox.assert.notCalled(optionUpdatedListener); + }); + it("throws an error for an unknown option", () => { consumer = new Consumer({ region: REGION,