diff --git a/src/Motor.Extensions.Hosting.Kafka/KafkaMessageConsumer.cs b/src/Motor.Extensions.Hosting.Kafka/KafkaMessageConsumer.cs index 67c5e521..34a10ca2 100644 --- a/src/Motor.Extensions.Hosting.Kafka/KafkaMessageConsumer.cs +++ b/src/Motor.Extensions.Hosting.Kafka/KafkaMessageConsumer.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Text.Json; @@ -38,6 +39,11 @@ public sealed class KafkaMessageConsumer : IMessageConsumer, IDisp private IConsumer? _consumer; private readonly CancellationTokenSource _internalCts = new(); + // Per-partition processing state + private readonly ConcurrentDictionary _partitionProcessors = new(); + private readonly object _pauseLock = new(); + private readonly HashSet _pausedPartitions = new(); + public KafkaMessageConsumer( ILogger> logger, IOptions> config, @@ -71,9 +77,6 @@ CloudEventFormatter cloudEventFormatter "partition" ); - _processedMessages = Channel.CreateBounded>( - _options.MaxConcurrentMessages - ); _timer = new Timer(HandleCommitTimer); _retryPolicy = Policy @@ -97,9 +100,56 @@ public Task StartAsync(CancellationToken token = default) throw new InvalidOperationException("ConsumeCallback is null"); } + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_internalCts.Token, token); + var consumerBuilder = new ConsumerBuilder(_options) .SetLogHandler((_, logMessage) => WriteLog(logMessage)) - .SetStatisticsHandler((_, json) => WriteStatistics(json)); + .SetStatisticsHandler((_, json) => WriteStatistics(json)) + .SetPartitionsAssignedHandler( + (_, partitions) => + { + foreach (var tp in partitions) + { + var processor = new PartitionProcessor( + tp, + _options.MaxConcurrentMessagesPerPartition, + ConsumeCallbackAsync, + _applicationNameService, + _cloudEventFormatter, + _retryPolicy, + _applicationLifetime, + _logger, + linkedCts.Token + ); + _partitionProcessors[tp] = processor; + 
_logger.LogInformation(LogEvents.PartitionAssigned, "Partition assigned: {TopicPartition}", tp); + } + } + ) + .SetPartitionsRevokedHandler( + (consumer, partitions) => + { + foreach (var tp in partitions.Select(tpo => tpo.TopicPartition)) + { + if (_partitionProcessors.TryRemove(tp, out var processor)) + { + DrainPartitionProcessor(processor); + _logger.LogInformation( + LogEvents.PartitionRevoked, + "Partition revoked: {TopicPartition}", + tp + ); + } + + lock (_pauseLock) + { + _pausedPartitions.Remove(tp); + } + } + + Commit(); + } + ); _consumer = consumerBuilder.Build(); _consumer.Subscribe(_options.Topic); @@ -120,25 +170,49 @@ await Task.Run( { try { - if (!await _processedMessages.Writer.WaitToWriteAsync(cts.Token)) + ResumePartitionsWithCapacity(); + + var msg = _consumer?.Consume(TimeSpan.FromMilliseconds(100)); + if (msg is null || msg.IsPartitionEOF) { - break; + _logger.LogDebug(LogEvents.NoMessageReceived, "No messages received"); + continue; } - var msg = _consumer?.Consume(cts.Token); - if (msg is { IsPartitionEOF: false }) + var tp = msg.TopicPartition; + + if (!_partitionProcessors.TryGetValue(tp, out var processor)) { - await _processedMessages.Writer.WriteAsync( - SingleMessageHandlingAsync(msg, cts.Token), - cts.Token + // Partition not assigned (race condition during rebalance), skip + _logger.LogWarning( + "Received message for unassigned partition {TopicPartition}, skipping", + tp ); + continue; + } + + // Try to write the raw message into the partition's input channel. + // The PartitionProcessor will process it sequentially. + if (processor.HasCapacity) + { + processor.InputChannel.Writer.TryWrite(msg); } else { - _logger.LogDebug(LogEvents.NoMessageReceived, "No messages received"); + // Channel is full — pause this partition and wait for space. 
+ PausePartition(tp); + + if (await processor.InputChannel.Writer.WaitToWriteAsync(cts.Token)) + { + processor.InputChannel.Writer.TryWrite(msg); + } } } - catch (Exception e) when (e is not OperationCanceledException or ChannelClosedException) + catch (ChannelClosedException) + { + // Channel was closed due to partition revocation, skip and continue + } + catch (Exception e) when (e is not OperationCanceledException) { _logger.LogError(LogEvents.MessageReceivedFailure, e, "Failed to receive message."); } @@ -165,6 +239,93 @@ public Task StopAsync(CancellationToken token = default) return Task.CompletedTask; } + // PausePartition is implemented in the client, it tells the partition fetcher to stop fetching an otherwise active partition. + // https://github.com/confluentinc/librdkafka/issues/1849#issuecomment-397763904 + private void PausePartition(TopicPartition tp) + { + lock (_pauseLock) + { + if (_pausedPartitions.Add(tp)) + { + try + { + _consumer?.Pause(new[] { tp }); + _logger.LogDebug(LogEvents.PartitionPaused, "Paused partition {TopicPartition}", tp); + } + catch (KafkaException) + { + // Partition may have been revoked + _pausedPartitions.Remove(tp); + } + } + } + } + + // ResumePartition is implemented in the client, it tells the partition fetcher to resume fetching a previously paused partition. 
+ // https://github.com/confluentinc/librdkafka/issues/1849#issuecomment-397763904 + private void ResumePartitionsWithCapacity() + { + lock (_pauseLock) + { + if (_pausedPartitions.Count == 0) + { + return; + } + + var toResume = new List(); + foreach (var tp in _pausedPartitions) + { + if ( + _partitionProcessors.TryGetValue(tp, out var processor) + && processor.HasCapacity + ) + { + toResume.Add(tp); + } + } + + foreach (var tp in toResume) + { + try + { + _consumer?.Resume(new[] { tp }); + _pausedPartitions.Remove(tp); + _logger.LogDebug(LogEvents.PartitionResumed, "Resumed partition {TopicPartition}", tp); + } + catch (KafkaException) + { + _pausedPartitions.Remove(tp); + } + } + } + } + + private void DrainPartitionProcessor(PartitionProcessor processor) + { + // Complete the input channel so the processing loop finishes + processor.InputChannel.Writer.TryComplete(); + + // Read any already-completed results from the output channel + while (processor.OutputChannel.Reader.TryRead(out var result)) + { + // Never store the offset of a failed message or of anything behind it: committing it would skip the message. 
+ if (IsIrrecoverableFailure(result.ProcessedMessageStatus)) + { + break; + } + + try + { + _consumer?.StoreOffset(result.ConsumeResult); + } + catch (Exception e) + { + // Best effort drain: the partition is being revoked, so record the failure and carry on + _logger.LogDebug(e, "Failed to store offset while draining {TopicPartition}", processor.TopicPartition); + } + } + } + private void WriteLog(LogMessage logMessage) { switch (logMessage.Level) @@ -242,45 +403,8 @@ private void WriteStatistics(string json) } } - private async Task SingleMessageHandlingAsync( - ConsumeResult msg, - CancellationToken token - ) - { - try - { - _logger.LogDebug( - LogEvents.ReceivedMessage, - "Received message from topic '{Topic}:{Partition}' with offset: '{Offset}[{TopicPartitionOffset}]'", - msg.Topic, - msg.Partition, - msg.Offset, - msg.TopicPartitionOffset - ); - var cloudEvent = KafkaMessageToCloudEvent(msg.Message); - - var status = await _retryPolicy.ExecuteAsync( - (cancellationToken) => ConsumeCallbackAsync!.Invoke(cloudEvent, cancellationToken), - token - ); - return new ConsumeResultAndProcessedMessageStatus(msg, status); - } - catch (Exception e) - { - _logger.LogCritical( - 
LogEvents.MessageHandlingUnexpectedException, - e, - "Unexpected exception in message handling" - ); - _applicationLifetime.StopApplication(); - } - - return new ConsumeResultAndProcessedMessageStatus(msg, ProcessedMessageStatus.CriticalFailure); - } - #region Commit - private readonly Channel> _processedMessages; private readonly Timer _timer; private readonly object _commitLock = new(); private bool _pendingCommit; @@ -294,31 +418,50 @@ private async Task ExecuteCommitLoopAsync(CancellationToken cancellationToken) { try { - var result = await PeekAndAwaitProcessedMessages(cancellationToken); + // Poll all partition processors in a round-robin fashion + var didWork = false; - if (IsIrrecoverableFailure(result.ProcessedMessageStatus)) + foreach (var kvp in _partitionProcessors) { - await _internalCts.CancelAsync(); - _applicationLifetime.StopApplication(); - break; - } + var outputChannel = kvp.Value.OutputChannel; - // Remove message from channel, when Task is successfully completed - await _processedMessages.Reader.ReadAsync(cancellationToken); + // Try to read the next completed result from this partition's output channel + if (!outputChannel.Reader.TryRead(out var result)) + { + continue; + } - lock (_commitLock) - { - _consumer?.StoreOffset(result.ConsumeResult); - _pendingCommit = true; - _messagesSinceLastCommit++; + didWork = true; + + if (IsIrrecoverableFailure(result.ProcessedMessageStatus)) + { + await _internalCts.CancelAsync(); + _applicationLifetime.StopApplication(); + // Leaving via return skips the StopCommitTimer() call after the loop, so stop the timer here. + StopCommitTimer(); + return; + } + + lock (_commitLock) + { + _consumer?.StoreOffset(result.ConsumeResult); + _pendingCommit = true; + _messagesSinceLastCommit++; + } + + // Use message count since last commit instead of offset-based check. + // This works correctly across multiple partitions with non-contiguous offsets. + if (_messagesSinceLastCommit >= _options.CommitPeriod) + { + Commit(); + RestartCommitTimer(); + } } - // Use message count since last commit instead of offset-based check. 
- // This works correctly across multiple partitions with non-contiguous offsets. - if (_messagesSinceLastCommit >= _options.CommitPeriod) + if (!didWork) { - Commit(); - RestartCommitTimer(); + // No partition had a completed result; wait for any partition to produce one + await WaitForAnyPartitionCompletion(cancellationToken); } } catch (Exception e) when (e is OperationCanceledException or ChannelClosedException) @@ -330,18 +473,27 @@ private async Task ExecuteCommitLoopAsync(CancellationToken cancellationToken) StopCommitTimer(); } - private async Task PeekAndAwaitProcessedMessages( - CancellationToken cancellationToken - ) + // NOTE(review): each call registers a new WaitToReadAsync waiter on every partition's output channel but only + // the first one to finish is awaited; the others stay pending until their channel gets data or the token is + // cancelled, and partitions assigned while waiting are not observed. Consider one cached wait per partition. + private async Task WaitForAnyPartitionCompletion(CancellationToken cancellationToken) { - await _processedMessages.Reader.WaitToReadAsync(cancellationToken); + var waitTasks = new List(); - if (!_processedMessages.Reader.TryPeek(out var consumeAndProcessTask)) + foreach (var kvp in _partitionProcessors) { - throw new InvalidOperationException("Awaited channel data has been removed by another consumer"); + var outputChannel = kvp.Value.OutputChannel; + waitTasks.Add(outputChannel.Reader.WaitToReadAsync(cancellationToken).AsTask()); } - return await consumeAndProcessTask; + if (waitTasks.Count == 0) + { + // No partitions assigned yet, just wait briefly + await Task.Delay(10, cancellationToken); + return; + } + + await Task.WhenAny(waitTasks); } private void Commit() @@ -441,14 +593,24 @@ private void Dispose(bool disposing) private void CloseOrDispose() { + // Cancel the internal CTS first so that any in-flight message handlers + // are cancelled before Close() triggers partition revocation (which drains channels). 
try { - _consumer?.Close(); _internalCts.Cancel(); } catch (ObjectDisposedException) { - // thrown if the consumer is already closed + // CTS already disposed + } + + try + { + _consumer?.Close(); + } + catch (ObjectDisposedException) + { + // thrown if the consumer is already closed/disposed } finally { diff --git a/src/Motor.Extensions.Hosting.Kafka/KafkaPartitionProcessor.cs b/src/Motor.Extensions.Hosting.Kafka/KafkaPartitionProcessor.cs new file mode 100644 index 00000000..6f1445d3 --- /dev/null +++ b/src/Motor.Extensions.Hosting.Kafka/KafkaPartitionProcessor.cs @@ -0,0 +1,170 @@ +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using CloudNative.CloudEvents; +using Confluent.Kafka; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Motor.Extensions.Hosting.Abstractions; +using Motor.Extensions.Hosting.CloudEvents; +using Polly; + +namespace Motor.Extensions.Hosting.Kafka; + +/// <summary> +/// Represents the processing state for a single Kafka partition. +/// Each partition gets its own bounded input channel and processes +/// messages sequentially — one at a time. Results are written to +/// an output channel for the commit loop to consume. +/// </summary> +public sealed class PartitionProcessor +{ + public TopicPartition TopicPartition { get; } + + /// <summary> + /// Inbound channel: raw consumed messages waiting to be processed. + /// </summary> + public Channel> InputChannel { get; } + + /// <summary> + /// Outbound channel: processed results ready for offset commit. + /// </summary> + public Channel OutputChannel { get; } + + public int Capacity { get; } + + /// <summary> + /// Returns true if the input channel has space for at least one more item. + /// </summary> + public bool HasCapacity => InputChannel.Reader.Count < Capacity; + + /// <summary> + /// The background task running the sequential processing loop. 
+ /// </summary> + public Task ProcessingTask { get; } + + public PartitionProcessor( + TopicPartition topicPartition, + int capacity, + Func, CancellationToken, Task> consumeCallbackAsync, + IApplicationNameService applicationNameService, + CloudEventFormatter cloudEventFormatter, + AsyncPolicy retryPolicy, + IHostApplicationLifetime applicationLifetime, + ILogger logger, + CancellationToken cancellationToken + ) + { + TopicPartition = topicPartition; + Capacity = capacity; + InputChannel = Channel.CreateBounded>(capacity); + OutputChannel = Channel.CreateBounded(capacity); + + ProcessingTask = Task.Run( + () => + ProcessMessagesAsync( + consumeCallbackAsync, + applicationNameService, + cloudEventFormatter, + retryPolicy, + applicationLifetime, + logger, + cancellationToken + ), + cancellationToken + ); + } + + private async Task ProcessMessagesAsync( + Func, CancellationToken, Task> consumeCallbackAsync, + IApplicationNameService applicationNameService, + CloudEventFormatter cloudEventFormatter, + AsyncPolicy retryPolicy, + IHostApplicationLifetime applicationLifetime, + ILogger logger, + CancellationToken cancellationToken + ) + { + try + { + await foreach (var msg in InputChannel.Reader.ReadAllAsync(cancellationToken)) + { + var result = await HandleSingleMessageAsync( + msg, + consumeCallbackAsync, + applicationNameService, + cloudEventFormatter, + retryPolicy, + applicationLifetime, + logger, + cancellationToken + ); + + // Write the result to the output channel for the commit loop. + // This will block if the output channel is full, providing backpressure. 
+ await OutputChannel.Writer.WriteAsync(result, cancellationToken); + } + } + catch (OperationCanceledException) + { + // Processing was cancelled + } + catch (ChannelClosedException) + { + // Input channel was completed (partition revoked) + } + finally + { + OutputChannel.Writer.TryComplete(); + } + } + + private static async Task HandleSingleMessageAsync( + ConsumeResult msg, + Func, CancellationToken, Task> consumeCallbackAsync, + IApplicationNameService applicationNameService, + CloudEventFormatter cloudEventFormatter, + AsyncPolicy retryPolicy, + IHostApplicationLifetime applicationLifetime, + ILogger logger, + CancellationToken token + ) + { + try + { + logger.LogDebug( + LogEvents.ReceivedMessage, + "Received message from topic '{Topic}:{Partition}' with offset: '{Offset}[{TopicPartitionOffset}]'", + msg.Topic, + msg.Partition, + msg.Offset, + msg.TopicPartitionOffset + ); + var cloudEvent = msg.Message.ToMotorCloudEvent(applicationNameService, cloudEventFormatter); + + var status = await retryPolicy.ExecuteAsync( + cancellationToken => consumeCallbackAsync.Invoke(cloudEvent, cancellationToken), + token + ); + return new ConsumeResultAndProcessedMessageStatus(msg, status); + } + catch (OperationCanceledException) when (token.IsCancellationRequested) + { + // Cancellation (shutdown) is not a message failure: rethrow so the processing loop ends quietly + // instead of logging a critical error and stopping the application. 
+ throw; + } + catch (Exception e) + { + logger.LogCritical( + LogEvents.MessageHandlingUnexpectedException, + e, + "Unexpected exception in message handling" + ); + applicationLifetime.StopApplication(); + } + + return new ConsumeResultAndProcessedMessageStatus(msg, ProcessedMessageStatus.CriticalFailure); + } +} diff --git a/src/Motor.Extensions.Hosting.Kafka/LogEvents.cs b/src/Motor.Extensions.Hosting.Kafka/LogEvents.cs index afbff61c..e0c20757 100644 --- a/src/Motor.Extensions.Hosting.Kafka/LogEvents.cs +++ b/src/Motor.Extensions.Hosting.Kafka/LogEvents.cs @@ -17,4 +17,9 @@ public static class LogEvents 8, nameof(MessageHandlingUnexpectedException) ); + + public static readonly EventId PartitionAssigned = new(9, nameof(PartitionAssigned)); + public static readonly EventId PartitionRevoked = 
new(10, nameof(PartitionRevoked)); + public static readonly EventId PartitionPaused = new(11, nameof(PartitionPaused)); + public static readonly EventId PartitionResumed = new(12, nameof(PartitionResumed)); } diff --git a/src/Motor.Extensions.Hosting.Kafka/Options/KafkaConsumerOptions.cs b/src/Motor.Extensions.Hosting.Kafka/Options/KafkaConsumerOptions.cs index 5b2cd2c6..59a91318 100644 --- a/src/Motor.Extensions.Hosting.Kafka/Options/KafkaConsumerOptions.cs +++ b/src/Motor.Extensions.Hosting.Kafka/Options/KafkaConsumerOptions.cs @@ -13,7 +13,7 @@ public KafkaConsumerOptions() public string? Topic { get; set; } public int CommitPeriod { get; set; } = 1000; - public int MaxConcurrentMessages { get; set; } = 1000; + public int MaxConcurrentMessagesPerPartition { get; set; } = 1; public int RetriesOnTemporaryFailure { get; set; } = 10; public TimeSpan RetryBasePeriod { get; set; } = TimeSpan.FromSeconds(1); } diff --git a/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaExtensionTests.cs b/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaExtensionTests.cs index 20a4a8e2..04f25a3e 100644 --- a/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaExtensionTests.cs +++ b/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaExtensionTests.cs @@ -1,10 +1,6 @@ -using System; -using System.Collections.Generic; -using System.Linq; +using System.Collections.Concurrent; using System.Text; -using System.Threading; using System.Threading.Channels; -using System.Threading.Tasks; using CloudNative.CloudEvents.SystemTextJson; using Confluent.Kafka; using Microsoft.Extensions.Hosting; @@ -185,7 +181,7 @@ ProcessedMessageStatus returnStatus var taskCompletionSource = new TaskCompletionSource(); await PublishMessage(topic, "someKey", "1"); await PublishMessage(topic, "someKey", "2"); - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1); + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 1); var consumer = GetConsumer(topic, 
config); var distinctHandledMessages = new HashSet(); consumer.ConsumeCallbackAsync = async (data, _) => @@ -225,7 +221,7 @@ ProcessedMessageStatus processedMessageStatus await PublishMessage(topic, "someKey", message); } - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1); + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 1); config.CommitPeriod = 1; config.AutoCommitIntervalMs = null; var consumer = GetConsumer(topic, config); @@ -336,7 +332,11 @@ public async Task Consume_AfterProcessingAMessage_CommitsEveryCommitPeriod() { var topic = NewTopic(); var fakeLifetimeMock = new Mock(); - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1, retriesOnTemporaryFailure: 1); + var config = GetConsumerConfig( + topic, + maxConcurrentMessagesPerPartition: 1, + retriesOnTemporaryFailure: 1 + ); config.CommitPeriod = 2; config.AutoCommitIntervalMs = null; using var consumer = GetConsumer(topic, config, fakeLifetimeMock.Object); @@ -396,7 +396,11 @@ public async Task Consume_AfterProcessingAMessage_CommitsOnlyEveryCommitPeriod() { var topic = NewTopic(); var fakeLifetimeMock = new Mock(); - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1, retriesOnTemporaryFailure: 1); + var config = GetConsumerConfig( + topic, + maxConcurrentMessagesPerPartition: 1, + retriesOnTemporaryFailure: 1 + ); config.CommitPeriod = 10; config.AutoCommitIntervalMs = null; using var consumer = GetConsumer(topic, config, fakeLifetimeMock.Object); @@ -419,7 +423,11 @@ public async Task Consume_WhenHavingUncommittedMessages_CommitsEveryAutoCommitIn { var topic = NewTopic(); var fakeLifetimeMock = new Mock(); - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1, retriesOnTemporaryFailure: 1); + var config = GetConsumerConfig( + topic, + maxConcurrentMessagesPerPartition: 1, + retriesOnTemporaryFailure: 1 + ); config.CommitPeriod = 1000; // default config.AutoCommitIntervalMs = 1; using var consumer = GetConsumer(topic, config, 
fakeLifetimeMock.Object); @@ -442,7 +450,11 @@ public async Task Consume_OnShutdown_Commits() { var topic = NewTopic(); var fakeLifetimeMock = new Mock(); - var config = GetConsumerConfig(topic, maxConcurrentMessages: 1, retriesOnTemporaryFailure: 1); + var config = GetConsumerConfig( + topic, + maxConcurrentMessagesPerPartition: 1, + retriesOnTemporaryFailure: 1 + ); config.AutoCommitIntervalMs = null; using var consumer = GetConsumer(topic, config, fakeLifetimeMock.Object); consumer.ConsumeCallbackAsync = CreateConsumeCallback(ProcessedMessageStatus.Success, _consumedChannel); @@ -497,6 +509,25 @@ await producer.ProduceAsync( producer.Flush(); } + private async Task PublishMessageToPartition(string topic, int partition, string value) + { + using var producer = new ProducerBuilder( + new ProducerConfig { BootstrapServers = fixture.BootstrapServers } + ).Build(); + await producer.ProduceAsync( + new TopicPartition(topic, new Partition(partition)), + new Message { Value = Encoding.UTF8.GetBytes(value) } + ); + producer.Flush(); + } + + private async Task CreateMultiPartitionTopic(int numPartitions) + { + var topic = NewTopic(); + await fixture.CreateTopicAsync(topic, numPartitions); + return topic; + } + private KafkaMessageConsumer GetConsumer( string topic, KafkaConsumerOptions config = null, @@ -542,7 +573,7 @@ private static IApplicationNameService GetApplicationNameService(string source = private KafkaConsumerOptions GetConsumerConfig( string topic, - int maxConcurrentMessages = 1000, + int maxConcurrentMessagesPerPartition = 1000, string groupId = "group_id", int retriesOnTemporaryFailure = 10, TimeSpan? 
retryBasePeriod = null @@ -558,9 +589,383 @@ private KafkaConsumerOptions GetConsumerConfig( StatisticsIntervalMs = 5000, SessionTimeoutMs = 6000, AutoOffsetReset = AutoOffsetReset.Earliest, - MaxConcurrentMessages = maxConcurrentMessages, + MaxConcurrentMessagesPerPartition = maxConcurrentMessagesPerPartition, RetriesOnTemporaryFailure = retriesOnTemporaryFailure, RetryBasePeriod = retryBasePeriod ?? TimeSpan.FromSeconds(1), }; } + + #region Per-partition fair processing tests + + [Fact(Timeout = 50000)] + public async Task Consume_MultiPartitionTopic_AllPartitionsProcessed() + { + const int numPartitions = 3; + const int messagesPerPartition = 5; + var topic = await CreateMultiPartitionTopic(numPartitions); + + for (var p = 0; p < numPartitions; p++) + { + for (var m = 0; m < messagesPerPartition; m++) + { + await PublishMessageToPartition(topic, p, $"p{p}-m{m}"); + } + } + + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: messagesPerPartition); + config.CommitPeriod = 1; + config.AutoCommitIntervalMs = 1; + var consumer = GetConsumer(topic, config); + var processedMessages = new ConcurrentBag(); + var allProcessed = new TaskCompletionSource(); + const int totalMessages = numPartitions * messagesPerPartition; + consumer.ConsumeCallbackAsync = async (data, _) => + { + processedMessages.Add(Encoding.UTF8.GetString(data.TypedData)); + if (processedMessages.Count >= totalMessages) + { + allProcessed.TrySetResult(); + } + + return await Task.FromResult(ProcessedMessageStatus.Success); + }; + + await consumer.StartAsync(); + var executionTask = consumer.ExecuteAsync(); + await Task.WhenAny(allProcessed.Task, Task.Delay(TimeSpan.FromSeconds(30))); + await consumer.StopAsync(); + await executionTask; + + Assert.Equal(totalMessages, processedMessages.Count); + // Verify all partitions were represented + for (var p = 0; p < numPartitions; p++) + { + for (var m = 0; m < messagesPerPartition; m++) + { + Assert.Contains($"p{p}-m{m}", 
processedMessages); + } + } + } + + [Fact(Timeout = 50000)] + public async Task Consume_SlowPartitionDoesNotBlockOtherPartitions_OtherPartitionsStillProcessed() + { + const int numPartitions = 2; + var topic = await CreateMultiPartitionTopic(numPartitions); + + // Publish 1 message to partition 0 (will be slow) and multiple to partition 1 + await PublishMessageToPartition(topic, 0, "slow"); + for (var i = 0; i < 5; i++) + { + await PublishMessageToPartition(topic, 1, $"fast-{i}"); + } + + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 5); + config.CommitPeriod = 1; + config.AutoCommitIntervalMs = 1; + var consumer = GetConsumer(topic, config); + + var partition0Blocked = new TaskCompletionSource(); + var partition1Messages = new ConcurrentBag(); + var allFastProcessed = new TaskCompletionSource(); + + consumer.ConsumeCallbackAsync = async (data, cancellationToken) => + { + var msg = Encoding.UTF8.GetString(data.TypedData); + if (msg == "slow") + { + partition0Blocked.TrySetResult(); + // Block this message indefinitely to simulate a slow partition + await Task.Delay(-1, cancellationToken); + return ProcessedMessageStatus.Success; + } + + partition1Messages.Add(msg); + if (partition1Messages.Count >= 5) + { + allFastProcessed.TrySetResult(); + } + + return await Task.FromResult(ProcessedMessageStatus.Success); + }; + + await consumer.StartAsync(); + var executionTask = consumer.ExecuteAsync(); + + // Wait for the slow message to start processing + await Task.WhenAny(partition0Blocked.Task, Task.Delay(TimeSpan.FromSeconds(15))); + Assert.True(partition0Blocked.Task.IsCompleted, "Slow message on partition 0 should have started processing"); + + // Wait for all fast messages to complete — they should not be blocked by partition 0 + await Task.WhenAny(allFastProcessed.Task, Task.Delay(TimeSpan.FromSeconds(15))); + Assert.True( + allFastProcessed.Task.IsCompleted, + "Fast messages on partition 1 should have completed despite slow partition 0" + 
); + + await consumer.StopAsync(); + await executionTask; + + Assert.Equal(5, partition1Messages.Count); + } + + [Fact(Timeout = 50000)] + public async Task Consume_PerPartitionConcurrencyLimit_LimitsEachPartitionIndependently() + { + const int numPartitions = 2; + const int perPartitionLimit = 3; + var topic = await CreateMultiPartitionTopic(numPartitions); + + // Publish more messages than the per-partition limit to each partition + for (var p = 0; p < numPartitions; p++) + { + for (var i = 0; i < perPartitionLimit * 2; i++) + { + await PublishMessageToPartition(topic, p, $"p{p}-{i}"); + } + } + + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: perPartitionLimit); + var consumer = GetConsumer(topic, config); + + var perPartitionCounts = new ConcurrentDictionary(); + var firstMessageProcessed = new TaskCompletionSource(); + + consumer.ConsumeCallbackAsync = async (data, cancellationToken) => + { + var msg = Encoding.UTF8.GetString(data.TypedData); + var partition = msg.Split('-')[0]; // "p0" or "p1" + perPartitionCounts.AddOrUpdate(partition, 1, (_, count) => count + 1); + firstMessageProcessed.TrySetResult(); + + // Block indefinitely to keep messages in-flight + await Task.Delay(-1, cancellationToken); + return ProcessedMessageStatus.Success; + }; + + await consumer.StartAsync(); + var executionTask = consumer.ExecuteAsync(); + + // Wait for processing to start + await Task.WhenAny(firstMessageProcessed.Task, Task.Delay(TimeSpan.FromSeconds(15))); + // Give time for the consumer to fill up the channels + await Task.Delay(TimeSpan.FromSeconds(2)); + + await consumer.StopAsync(); + await executionTask; + + // Each partition should have at most perPartitionLimit in-flight messages + foreach (var kvp in perPartitionCounts) + { + output.WriteLine($"Partition {kvp.Key}: {kvp.Value} messages started"); + Assert.True( + kvp.Value <= perPartitionLimit, + $"Partition {kvp.Key} had {kvp.Value} concurrent messages, expected at most 
{perPartitionLimit}" + ); + } + } + + [Fact(Timeout = 50000)] + public async Task Consume_MultiPartition_CommitsAllPartitions() + { + const int numPartitions = 3; + const int messagesPerPartition = 3; + var topic = await CreateMultiPartitionTopic(numPartitions); + + for (var p = 0; p < numPartitions; p++) + { + for (var m = 0; m < messagesPerPartition; m++) + { + await PublishMessageToPartition(topic, p, $"p{p}-m{m}"); + } + } + + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 5); + config.CommitPeriod = 1; + config.AutoCommitIntervalMs = 1; + var consumer = GetConsumer(topic, config); + var processedCount = 0; + var allProcessed = new TaskCompletionSource(); + const int totalMessages = numPartitions * messagesPerPartition; + var lockObject = new object(); + consumer.ConsumeCallbackAsync = async (_, _) => + { + lock (lockObject) + { + processedCount++; + output.WriteLine($"Processed message {processedCount}/{totalMessages}"); + if (processedCount >= totalMessages) + { + allProcessed.TrySetResult(); + } + } + + return await Task.FromResult(ProcessedMessageStatus.Success); + }; + + await consumer.StartAsync(); + var executionTask = consumer.ExecuteAsync(); + + // Wait until all messages are processed + await Task.WhenAny(allProcessed.Task, Task.Delay(TimeSpan.FromSeconds(30))); + Assert.True( + allProcessed.Task.IsCompleted, + $"Expected all {totalMessages} messages to be processed, but only {processedCount} were processed" + ); + + // Give time for the commit loop to commit all offsets + await Task.Delay(TimeSpan.FromSeconds(2)); + + // Verify all committed offsets sum to total messages (must be done before StopAsync disposes the consumer) + var offsets = consumer.Committed(); + var committedTotal = offsets.Where(tpo => tpo.Offset != Offset.Unset).Sum(tpo => (long)tpo.Offset); + + await consumer.StopAsync(); + await executionTask; + + Assert.Equal(totalMessages, committedTotal); + } + + [Fact(Timeout = 50000)] + public async Task 
Consume_MultiPartition_IrrecoverableFailureStopsApplication() + { + const int numPartitions = 2; + var topic = await CreateMultiPartitionTopic(numPartitions); + var fakeLifetimeMock = new Mock(); + + await PublishMessageToPartition(topic, 0, "good"); + await PublishMessageToPartition(topic, 1, "bad"); + + var config = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 5); + var consumer = GetConsumer(topic, config, fakeLifetimeMock.Object); + consumer.ConsumeCallbackAsync = async (data, _) => + { + var msg = Encoding.UTF8.GetString(data.TypedData); + return await Task.FromResult( + msg == "bad" ? ProcessedMessageStatus.CriticalFailure : ProcessedMessageStatus.Success + ); + }; + + await consumer.StartAsync(); + var executionTask = consumer.ExecuteAsync(); + + WaitUntil(() => fakeLifetimeMock.Verify(mock => mock.StopApplication())); + await consumer.StopAsync(); + await executionTask; + } + + #endregion + + #region Rebalance / repartitioning tests + + [Fact(Timeout = 60000)] + public async Task Consume_SecondConsumerJoinsGroup_RebalanceDoesNotLoseMessages() + { + const int numPartitions = 10; + var topic = await CreateMultiPartitionTopic(numPartitions); + var groupId = $"rebalance-join-{Guid.NewGuid():N}"; + + var processedMessagesA = new ConcurrentBag(); + var processedMessagesB = new ConcurrentBag(); + + // Start consumer A — initially owns all partitions + var configA = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 10, groupId: groupId); + configA.CommitPeriod = 1; + configA.AutoCommitIntervalMs = 1; + configA.SessionTimeoutMs = 6000; + var consumerA = GetConsumer(topic, configA); + consumerA.ConsumeCallbackAsync = async (data, _) => + { + processedMessagesA.Add(Encoding.UTF8.GetString(data.TypedData)); + return await Task.FromResult(ProcessedMessageStatus.Success); + }; + await consumerA.StartAsync(); + var executionA = consumerA.ExecuteAsync(); + + // Start a background publisher that continuously sends messages across all partitions. 
+ // The stream runs before, during, and after both rebalances. + var publishCts = new CancellationTokenSource(); + var publishedMessages = new ConcurrentBag(); + var publisherTask = Task.Run( + async () => + { + var messageIndex = 0; + while (!publishCts.Token.IsCancellationRequested) + { + for (var p = 0; p < numPartitions; p++) + { + var msg = $"msg-{messageIndex++}-p{p}"; + try + { + await PublishMessageToPartition(topic, p, msg); + publishedMessages.Add(msg); + } + catch (Exception) when (publishCts.Token.IsCancellationRequested) + { + return; + } + } + await Task.Delay(10, publishCts.Token); + } + }, + publishCts.Token + ); + + // Wait for consumer A to consume some messages before the first rebalance + await Task.Delay(TimeSpan.FromSeconds(3)); + + // --- First rebalance: consumer B joins --- + var configB = GetConsumerConfig(topic, maxConcurrentMessagesPerPartition: 10, groupId: groupId); + configB.CommitPeriod = 1; + configB.AutoCommitIntervalMs = 1; + configB.SessionTimeoutMs = 6000; + var consumerB = GetConsumer(topic, configB); + consumerB.ConsumeCallbackAsync = async (data, _) => + { + processedMessagesB.Add(Encoding.UTF8.GetString(data.TypedData)); + return await Task.FromResult(ProcessedMessageStatus.Success); + }; + await consumerB.StartAsync(); + var executionB = consumerB.ExecuteAsync(); + + // Let both consumers process messages together for a while + await Task.Delay(TimeSpan.FromSeconds(5)); + + // --- Second rebalance: consumer B leaves --- + await consumerB.StopAsync(); + await executionB; + + // Let consumer A process alone after the second rebalance + await Task.Delay(TimeSpan.FromSeconds(3)); + + // Stop the message stream + await publishCts.CancelAsync(); + try + { + await publisherTask; + } + catch (OperationCanceledException) { } + + // Give consumer A time to process remaining messages + await Task.Delay(TimeSpan.FromSeconds(3)); + + await consumerA.StopAsync(); + await executionA; + + // Both consumers must have processed at least one 
message + output.WriteLine($"Consumer A processed: {processedMessagesA.Count}"); + output.WriteLine($"Consumer B processed: {processedMessagesB.Count}"); + output.WriteLine($"Total published: {publishedMessages.Count}"); + Assert.NotEmpty(processedMessagesA); + Assert.NotEmpty(processedMessagesB); + + // All published messages must have been processed by one of the consumers + var allProcessed = processedMessagesA.Concat(processedMessagesB).ToHashSet(); + foreach (var msg in publishedMessages) + { + Assert.Contains(msg, allProcessed); + } + } + + #endregion } diff --git a/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaFixture.cs b/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaFixture.cs index 9863cf83..ca5c2701 100644 --- a/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaFixture.cs +++ b/test/Motor.Extensions.Hosting.Kafka_IntegrationTest/KafkaFixture.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using Confluent.Kafka; +using Confluent.Kafka.Admin; using Testcontainers.Kafka; using Xunit; @@ -15,6 +16,24 @@ public Task InitializeAsync() return _container.StartAsync(); } + public async Task CreateTopicAsync(string topicName, int numPartitions) + { + using var adminClient = new AdminClientBuilder( + new AdminClientConfig { BootstrapServers = _container.GetBootstrapAddress() } + ).Build(); + + var topicExists = adminClient.GetMetadata(TimeSpan.FromMinutes(1)).Topics.Any(t => t.Topic == topicName); + if (!topicExists) + { + await adminClient.CreateTopicsAsync( + new[] + { + new TopicSpecification { Name = topicName, NumPartitions = numPartitions }, + } + ); + } + } + public Task DisposeAsync() { return _container.DisposeAsync().AsTask();