diff --git a/core/src/main/java/kafka/server/share/DelayedShareFetch.java b/core/src/main/java/kafka/server/share/DelayedShareFetch.java index 9bab9818c0706..1422e08524a77 100644 --- a/core/src/main/java/kafka/server/share/DelayedShareFetch.java +++ b/core/src/main/java/kafka/server/share/DelayedShareFetch.java @@ -27,6 +27,7 @@ import org.apache.kafka.server.purgatory.DelayedOperation; import org.apache.kafka.server.share.SharePartitionKey; import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey; +import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy; import org.apache.kafka.server.share.fetch.ShareFetch; import org.apache.kafka.server.storage.log.FetchIsolation; import org.apache.kafka.server.storage.log.FetchPartitionData; @@ -60,10 +61,11 @@ public class DelayedShareFetch extends DelayedOperation { private final ShareFetch shareFetch; private final ReplicaManager replicaManager; private final BiConsumer exceptionHandler; + private final PartitionMaxBytesStrategy partitionMaxBytesStrategy; // The topic partitions that need to be completed for the share fetch request are given by sharePartitions. // sharePartitions is a subset of shareFetchData. The order of insertion/deletion of entries in sharePartitions is important. private final LinkedHashMap sharePartitions; - private LinkedHashMap partitionsAcquired; + private LinkedHashMap partitionsAcquired; private LinkedHashMap partitionsAlreadyFetched; DelayedShareFetch( @@ -71,6 +73,15 @@ public class DelayedShareFetch extends DelayedOperation { ReplicaManager replicaManager, BiConsumer exceptionHandler, LinkedHashMap sharePartitions) { + this(shareFetch, replicaManager, exceptionHandler, sharePartitions, PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)); + } + + DelayedShareFetch( + ShareFetch shareFetch, + ReplicaManager replicaManager, + BiConsumer exceptionHandler, + LinkedHashMap sharePartitions, + PartitionMaxBytesStrategy partitionMaxBytesStrategy) { super(shareFetch.fetchParams().maxWaitMs, Optional.empty()); this.shareFetch = shareFetch; this.replicaManager = replicaManager; @@ -78,6 +89,7 @@ public class DelayedShareFetch extends DelayedOperation { this.partitionsAlreadyFetched = new LinkedHashMap<>(); this.exceptionHandler = exceptionHandler; this.sharePartitions = sharePartitions; + this.partitionMaxBytesStrategy = partitionMaxBytesStrategy; } @Override @@ -99,7 +111,7 @@ public void onComplete() { partitionsAcquired.keySet()); try { - LinkedHashMap topicPartitionData; + LinkedHashMap topicPartitionData; // tryComplete did not invoke forceComplete, so we need to check if we have any partitions to fetch. if (partitionsAcquired.isEmpty()) topicPartitionData = acquirablePartitions(); @@ -121,11 +133,13 @@ public void onComplete() { } } - private void completeShareFetchRequest(LinkedHashMap topicPartitionData) { + private void completeShareFetchRequest(LinkedHashMap topicPartitionData) { try { LinkedHashMap responseData; if (partitionsAlreadyFetched.isEmpty()) - responseData = readFromLog(topicPartitionData); + responseData = readFromLog( + topicPartitionData, + partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, topicPartitionData.keySet(), topicPartitionData.size())); else // There shouldn't be a case when we have a partitionsAlreadyFetched value here and this variable is getting // updated in a different tryComplete thread. 
@@ -158,7 +172,7 @@ private void completeShareFetchRequest(LinkedHashMap topicPartitionData = acquirablePartitions(); + LinkedHashMap topicPartitionData = acquirablePartitions(); try { if (!topicPartitionData.isEmpty()) { @@ -167,7 +181,7 @@ public boolean tryComplete() { // those topic partitions. LinkedHashMap replicaManagerReadResponse = maybeReadFromLog(topicPartitionData); maybeUpdateFetchOffsetMetadata(topicPartitionData, replicaManagerReadResponse); - if (anyPartitionHasLogReadError(replicaManagerReadResponse) || isMinBytesSatisfied(topicPartitionData)) { + if (anyPartitionHasLogReadError(replicaManagerReadResponse) || isMinBytesSatisfied(topicPartitionData, partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, topicPartitionData.keySet(), topicPartitionData.size()))) { partitionsAcquired = topicPartitionData; partitionsAlreadyFetched = replicaManagerReadResponse; boolean completedByMe = forceComplete(); @@ -202,28 +216,18 @@ public boolean tryComplete() { * Prepare fetch request structure for partitions in the share fetch request for which we can acquire records. */ // Visible for testing - LinkedHashMap acquirablePartitions() { + LinkedHashMap acquirablePartitions() { // Initialize the topic partitions for which the fetch should be attempted. - LinkedHashMap topicPartitionData = new LinkedHashMap<>(); + LinkedHashMap topicPartitionData = new LinkedHashMap<>(); sharePartitions.forEach((topicIdPartition, sharePartition) -> { - int partitionMaxBytes = shareFetch.partitionMaxBytes().getOrDefault(topicIdPartition, 0); // Add the share partition to the list of partitions to be fetched only if we can // acquire the fetch lock on it. if (sharePartition.maybeAcquireFetchLock()) { try { // If the share partition is already at capacity, we should not attempt to fetch. if (sharePartition.canAcquireRecords()) { - topicPartitionData.put( - topicIdPartition, - new FetchRequest.PartitionData( - topicIdPartition.topicId(), - sharePartition.nextFetchOffset(), - 0, - partitionMaxBytes, - Optional.empty() - ) - ); + topicPartitionData.put(topicIdPartition, sharePartition.nextFetchOffset()); } else { sharePartition.releaseFetchLock(); log.trace("Record lock partition limit exceeded for SharePartition {}, " + @@ -239,24 +243,28 @@ LinkedHashMap acquirablePartitions return topicPartitionData; } - private LinkedHashMap maybeReadFromLog(LinkedHashMap topicPartitionData) { - LinkedHashMap partitionsNotMatchingFetchOffsetMetadata = new LinkedHashMap<>(); - topicPartitionData.forEach((topicIdPartition, partitionData) -> { + private LinkedHashMap maybeReadFromLog(LinkedHashMap topicPartitionData) { + LinkedHashMap partitionsNotMatchingFetchOffsetMetadata = new LinkedHashMap<>(); + topicPartitionData.forEach((topicIdPartition, fetchOffset) -> { SharePartition sharePartition = sharePartitions.get(topicIdPartition); - if (sharePartition.fetchOffsetMetadata(partitionData.fetchOffset).isEmpty()) { - partitionsNotMatchingFetchOffsetMetadata.put(topicIdPartition, partitionData); + if (sharePartition.fetchOffsetMetadata(fetchOffset).isEmpty()) { + partitionsNotMatchingFetchOffsetMetadata.put(topicIdPartition, fetchOffset); } }); if (partitionsNotMatchingFetchOffsetMetadata.isEmpty()) { return new LinkedHashMap<>(); } // We fetch data from replica manager corresponding to the topic partitions that have missing fetch offset metadata. 
- return readFromLog(partitionsNotMatchingFetchOffsetMetadata); + // Although we are fetching partition max bytes for partitionsNotMatchingFetchOffsetMetadata, + // we will take acquired partitions size = topicPartitionData.size() because we do not want the + // leftover partitions, which will be fetched later, to starve. + return readFromLog( + partitionsNotMatchingFetchOffsetMetadata, + partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, partitionsNotMatchingFetchOffsetMetadata.keySet(), topicPartitionData.size())); } - private void maybeUpdateFetchOffsetMetadata( - LinkedHashMap topicPartitionData, - LinkedHashMap replicaManagerReadResponseData) { + private void maybeUpdateFetchOffsetMetadata(LinkedHashMap topicPartitionData, + LinkedHashMap replicaManagerReadResponseData) { for (Map.Entry entry : replicaManagerReadResponseData.entrySet()) { TopicIdPartition topicIdPartition = entry.getKey(); SharePartition sharePartition = sharePartitions.get(topicIdPartition); @@ -267,17 +275,18 @@ private void maybeUpdateFetchOffsetMetadata( continue; } sharePartition.updateFetchOffsetMetadata( - topicPartitionData.get(topicIdPartition).fetchOffset, + topicPartitionData.get(topicIdPartition), replicaManagerLogReadResult.info().fetchOffsetMetadata); } } // minByes estimation currently assumes the common case where all fetched data is acquirable. - private boolean isMinBytesSatisfied(LinkedHashMap topicPartitionData) { + private boolean isMinBytesSatisfied(LinkedHashMap topicPartitionData, + LinkedHashMap partitionMaxBytes) { long accumulatedSize = 0; - for (Map.Entry entry : topicPartitionData.entrySet()) { + for (Map.Entry entry : topicPartitionData.entrySet()) { TopicIdPartition topicIdPartition = entry.getKey(); - FetchRequest.PartitionData partitionData = entry.getValue(); + long fetchOffset = entry.getValue(); LogOffsetMetadata endOffsetMetadata; try { @@ -294,7 +303,7 @@ private boolean isMinBytesSatisfied(LinkedHashMap optionalFetchOffsetMetadata = sharePartition.fetchOffsetMetadata(partitionData.fetchOffset); + Optional optionalFetchOffsetMetadata = sharePartition.fetchOffsetMetadata(fetchOffset); if (optionalFetchOffsetMetadata.isEmpty() || optionalFetchOffsetMetadata.get() == LogOffsetMetadata.UNKNOWN_OFFSET_METADATA) continue; LogOffsetMetadata fetchOffsetMetadata = optionalFetchOffsetMetadata.get(); @@ -312,7 +321,7 @@ private boolean isMinBytesSatisfied(LinkedHashMap readFromLog(LinkedHashMap topicPartitionData) { + private LinkedHashMap readFromLog(LinkedHashMap topicPartitionFetchOffsets, + LinkedHashMap partitionMaxBytes) { // Filter if there already exists any erroneous topic partition. - Set partitionsToFetch = shareFetch.filterErroneousTopicPartitions(topicPartitionData.keySet()); + Set partitionsToFetch = shareFetch.filterErroneousTopicPartitions(topicPartitionFetchOffsets.keySet()); if (partitionsToFetch.isEmpty()) { return new LinkedHashMap<>(); } + LinkedHashMap topicPartitionData = new LinkedHashMap<>(); + + topicPartitionFetchOffsets.forEach((topicIdPartition, fetchOffset) -> topicPartitionData.put(topicIdPartition, + new FetchRequest.PartitionData( + topicIdPartition.topicId(), + fetchOffset, + 0, + partitionMaxBytes.get(topicIdPartition), + Optional.empty()) + )); + Seq> responseLogResult = replicaManager.readFromLog( shareFetch.fetchParams(), CollectionConverters.asScala( @@ -390,18 +411,21 @@ private void handleFetchException( } // Visible for testing.
- LinkedHashMap combineLogReadResponse(LinkedHashMap topicPartitionData, - LinkedHashMap existingFetchedData) { - LinkedHashMap missingLogReadTopicPartitions = new LinkedHashMap<>(); - topicPartitionData.forEach((topicIdPartition, partitionData) -> { + LinkedHashMap combineLogReadResponse(LinkedHashMap topicPartitionData, + LinkedHashMap existingFetchedData) { + LinkedHashMap missingLogReadTopicPartitions = new LinkedHashMap<>(); + topicPartitionData.forEach((topicIdPartition, fetchOffset) -> { if (!existingFetchedData.containsKey(topicIdPartition)) { - missingLogReadTopicPartitions.put(topicIdPartition, partitionData); + missingLogReadTopicPartitions.put(topicIdPartition, fetchOffset); } }); if (missingLogReadTopicPartitions.isEmpty()) { return existingFetchedData; } - LinkedHashMap missingTopicPartitionsLogReadResponse = readFromLog(missingLogReadTopicPartitions); + + LinkedHashMap missingTopicPartitionsLogReadResponse = readFromLog( + missingLogReadTopicPartitions, + partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, missingLogReadTopicPartitions.keySet(), topicPartitionData.size())); missingTopicPartitionsLogReadResponse.putAll(existingFetchedData); return missingTopicPartitionsLogReadResponse; } diff --git a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java index 11d3e26eaf39a..533ffc9569352 100644 --- a/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java +++ b/core/src/test/java/kafka/server/share/DelayedShareFetchTest.java @@ -18,6 +18,7 @@ import kafka.cluster.Partition; import kafka.server.LogReadResult; +import kafka.server.QuotaFactory; import kafka.server.ReplicaManager; import kafka.server.ReplicaQuota; @@ -26,11 +27,13 @@ import org.apache.kafka.common.Uuid; import org.apache.kafka.common.message.ShareFetchResponseData; import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.Records; import org.apache.kafka.common.requests.FetchRequest; import org.apache.kafka.server.purgatory.DelayedOperationKey; import org.apache.kafka.server.purgatory.DelayedOperationPurgatory; import org.apache.kafka.server.share.SharePartitionKey; import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey; +import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy; import org.apache.kafka.server.share.fetch.ShareAcquiredRecords; import org.apache.kafka.server.share.fetch.ShareFetch; import org.apache.kafka.server.storage.log.FetchIsolation; @@ -39,6 +42,7 @@ import org.apache.kafka.server.util.timer.SystemTimer; import org.apache.kafka.server.util.timer.SystemTimerReaper; import org.apache.kafka.server.util.timer.Timer; +import org.apache.kafka.storage.internals.log.FetchDataInfo; import org.apache.kafka.storage.internals.log.LogOffsetMetadata; import org.apache.kafka.storage.internals.log.LogOffsetSnapshot; @@ -51,11 +55,17 @@ import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +import scala.Tuple2; +import scala.jdk.javaapi.CollectionConverters; import static kafka.server.share.SharePartitionManagerTest.DELAYED_SHARE_FETCH_PURGATORY_PURGE_INTERVAL; import static kafka.server.share.SharePartitionManagerTest.PARTITION_MAX_BYTES; @@ -182,6 +192,7 @@ public 
void testTryCompleteWhenMinBytesNotSatisfiedOnFirstFetch() { .withSharePartitions(sharePartitions) .withReplicaManager(replicaManager) .withExceptionHandler(exceptionHandler) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); assertFalse(delayedShareFetch.isCompleted()); @@ -286,6 +297,7 @@ public void testDelayedShareFetchTryCompleteReturnsTrue() { .withShareFetchData(shareFetch) .withSharePartitions(sharePartitions) .withReplicaManager(replicaManager) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); assertFalse(delayedShareFetch.isCompleted()); @@ -328,6 +340,7 @@ public void testEmptyFutureReturnedByDelayedShareFetchOnComplete() { .withShareFetchData(shareFetch) .withReplicaManager(replicaManager) .withSharePartitions(sharePartitions) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); assertFalse(delayedShareFetch.isCompleted()); delayedShareFetch.forceComplete(); @@ -375,6 +388,7 @@ public void testReplicaManagerFetchShouldHappenOnComplete() { .withShareFetchData(shareFetch) .withReplicaManager(replicaManager) .withSharePartitions(sharePartitions) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); assertFalse(delayedShareFetch.isCompleted()); delayedShareFetch.forceComplete(); @@ -507,6 +521,7 @@ public void testForceCompleteTriggersDelayedActionsQueue() { .withShareFetchData(shareFetch2) .withReplicaManager(replicaManager) .withSharePartitions(sharePartitions2) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); // sp1 can be acquired now @@ -557,15 +572,21 @@ public void testCombineLogReadResponse() { .withShareFetchData(shareFetch) .withReplicaManager(replicaManager) .withSharePartitions(sharePartitions) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build(); - LinkedHashMap topicPartitionData = new LinkedHashMap<>(); - topicPartitionData.put(tp0, mock(FetchRequest.PartitionData.class)); - topicPartitionData.put(tp1, mock(FetchRequest.PartitionData.class)); + LinkedHashMap topicPartitionData = new LinkedHashMap<>(); + topicPartitionData.put(tp0, 0L); + topicPartitionData.put(tp1, 0L); // Case 1 - logReadResponse contains tp0. 
LinkedHashMap logReadResponse = new LinkedHashMap<>(); - logReadResponse.put(tp0, mock(LogReadResult.class)); + LogReadResult logReadResult = mock(LogReadResult.class); + Records records = mock(Records.class); + when(records.sizeInBytes()).thenReturn(2); + FetchDataInfo fetchDataInfo = new FetchDataInfo(mock(LogOffsetMetadata.class), records); + when(logReadResult.info()).thenReturn(fetchDataInfo); + logReadResponse.put(tp0, logReadResult); doAnswer(invocation -> buildLogReadResult(Collections.singleton(tp1))).when(replicaManager).readFromLog(any(), any(), any(ReplicaQuota.class), anyBoolean()); LinkedHashMap combinedLogReadResponse = delayedShareFetch.combineLogReadResponse(topicPartitionData, logReadResponse); @@ -619,6 +640,7 @@ public void testExceptionInMinBytesCalculation() { .withSharePartitions(sharePartitions) .withReplicaManager(replicaManager) .withExceptionHandler(exceptionHandler) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) .build()); // Try complete should return false as the share partition has errored out. @@ -707,6 +729,324 @@ public void testLocksReleasedAcquireException() { delayedShareFetch.lock().unlock(); } + @Test + public void testTryCompleteWhenPartitionMaxBytesStrategyThrowsException() { + String groupId = "grp"; + TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + SharePartition sp0 = mock(SharePartition.class); + Map partitionMaxBytes = new HashMap<>(); + partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES); + + when(sp0.maybeAcquireFetchLock()).thenReturn(true); + when(sp0.canAcquireRecords()).thenReturn(true); + LinkedHashMap sharePartitions = new LinkedHashMap<>(); + sharePartitions.put(tp0, sp0); + CompletableFuture> future = new CompletableFuture<>(); + + ShareFetch shareFetch = new ShareFetch( + new FetchParams(ApiKeys.SHARE_FETCH.latestVersion(), FetchRequest.ORDINARY_CONSUMER_ID, -1, MAX_WAIT_MS, + 2, 1024 * 1024, FetchIsolation.HIGH_WATERMARK, Optional.empty()), groupId, Uuid.randomUuid().toString(), + future, partitionMaxBytes, MAX_FETCH_RECORDS); + + // partitionMaxBytesStrategy.maxBytes() function throws an exception + PartitionMaxBytesStrategy partitionMaxBytesStrategy = mock(PartitionMaxBytesStrategy.class); + when(partitionMaxBytesStrategy.maxBytes(anyInt(), any(), anyInt())).thenThrow(new IllegalArgumentException("Exception thrown")); + + DelayedShareFetch delayedShareFetch = spy(DelayedShareFetchBuilder.builder() + .withShareFetchData(shareFetch) + .withSharePartitions(sharePartitions) + .withExceptionHandler(mockExceptionHandler()) + .withPartitionMaxBytesStrategy(partitionMaxBytesStrategy) + .build()); + + assertFalse(delayedShareFetch.isCompleted()); + assertTrue(delayedShareFetch.tryComplete()); + assertTrue(delayedShareFetch.isCompleted()); + // releasePartitionLocks is called twice - first time from tryComplete and second time from onComplete + Mockito.verify(delayedShareFetch, times(2)).releasePartitionLocks(any()); + assertTrue(delayedShareFetch.lock().tryLock()); + delayedShareFetch.lock().unlock(); + + assertTrue(future.isDone()); + assertFalse(future.isCompletedExceptionally()); + Map partitionDataMap = future.join(); + assertEquals(1, partitionDataMap.size()); + assertTrue(partitionDataMap.containsKey(tp0)); + assertEquals("Exception thrown", partitionDataMap.get(tp0).errorMessage()); + } + + @Test + public void testPartitionMaxBytesFromUniformStrategyWhenAllPartitionsAreAcquirable() { + ReplicaManager replicaManager = 
mock(ReplicaManager.class); + String groupId = "grp"; + TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 1)); + TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 2)); + TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 3)); + TopicIdPartition tp4 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 4)); + + SharePartition sp0 = mock(SharePartition.class); + SharePartition sp1 = mock(SharePartition.class); + SharePartition sp2 = mock(SharePartition.class); + SharePartition sp3 = mock(SharePartition.class); + SharePartition sp4 = mock(SharePartition.class); + + Map partitionMaxBytes = new HashMap<>(); + partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp4, PARTITION_MAX_BYTES); + + when(sp0.maybeAcquireFetchLock()).thenReturn(true); + when(sp1.maybeAcquireFetchLock()).thenReturn(true); + when(sp2.maybeAcquireFetchLock()).thenReturn(true); + when(sp3.maybeAcquireFetchLock()).thenReturn(true); + when(sp4.maybeAcquireFetchLock()).thenReturn(true); + when(sp0.canAcquireRecords()).thenReturn(true); + when(sp1.canAcquireRecords()).thenReturn(true); + when(sp2.canAcquireRecords()).thenReturn(true); + when(sp3.canAcquireRecords()).thenReturn(true); + when(sp4.canAcquireRecords()).thenReturn(true); + + LinkedHashMap sharePartitions = new LinkedHashMap<>(); + sharePartitions.put(tp0, sp0); + sharePartitions.put(tp1, sp1); + sharePartitions.put(tp2, sp2); + sharePartitions.put(tp3, sp3); + sharePartitions.put(tp4, sp4); + + ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, groupId, Uuid.randomUuid().toString(), + new CompletableFuture<>(), partitionMaxBytes, MAX_FETCH_RECORDS); + + when(sp0.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + when(sp1.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + when(sp2.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + when(sp3.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + when(sp4.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + + // All 5 partitions are acquirable. 
+ doAnswer(invocation -> buildLogReadResult(sharePartitions.keySet())).when(replicaManager).readFromLog(any(), any(), any(ReplicaQuota.class), anyBoolean()); + + when(sp0.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + when(sp1.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + when(sp2.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + when(sp3.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + when(sp4.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp0, 1); + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp1, 1); + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp2, 1); + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp3, 1); + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp4, 1); + + DelayedShareFetch delayedShareFetch = spy(DelayedShareFetchBuilder.builder() + .withShareFetchData(shareFetch) + .withSharePartitions(sharePartitions) + .withReplicaManager(replicaManager) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) + .build()); + + assertTrue(delayedShareFetch.tryComplete()); + assertTrue(delayedShareFetch.isCompleted()); + + // Since all partitions are acquirable, maxbytes per partition = requestMaxBytes(i.e. 1024*1024) / acquiredTopicPartitions(i.e. 5) + int expectedPartitionMaxBytes = 1024 * 1024 / 5; + LinkedHashMap expectedReadPartitionInfo = new LinkedHashMap<>(); + sharePartitions.keySet().forEach(topicIdPartition -> expectedReadPartitionInfo.put(topicIdPartition, + new FetchRequest.PartitionData( + topicIdPartition.topicId(), + 0, + 0, + expectedPartitionMaxBytes, + Optional.empty() + ))); + + Mockito.verify(replicaManager, times(1)).readFromLog( + shareFetch.fetchParams(), + CollectionConverters.asScala( + sharePartitions.keySet().stream().map(topicIdPartition -> + new Tuple2<>(topicIdPartition, expectedReadPartitionInfo.get(topicIdPartition))).collect(Collectors.toList()) + ), + QuotaFactory.UNBOUNDED_QUOTA, + true); + } + + @Test + public void testPartitionMaxBytesFromUniformStrategyWhenFewPartitionsAreAcquirable() { + ReplicaManager replicaManager = mock(ReplicaManager.class); + String groupId = "grp"; + TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 1)); + TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 2)); + TopicIdPartition tp3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 3)); + TopicIdPartition tp4 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 4)); + + SharePartition sp0 = mock(SharePartition.class); + SharePartition sp1 = mock(SharePartition.class); + SharePartition sp2 = mock(SharePartition.class); + SharePartition sp3 = mock(SharePartition.class); + SharePartition sp4 = mock(SharePartition.class); + + Map partitionMaxBytes = new HashMap<>(); + partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp3, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp4, PARTITION_MAX_BYTES); + + 
when(sp0.maybeAcquireFetchLock()).thenReturn(true); + when(sp1.maybeAcquireFetchLock()).thenReturn(true); + when(sp2.maybeAcquireFetchLock()).thenReturn(false); + when(sp3.maybeAcquireFetchLock()).thenReturn(true); + when(sp4.maybeAcquireFetchLock()).thenReturn(false); + when(sp0.canAcquireRecords()).thenReturn(true); + when(sp1.canAcquireRecords()).thenReturn(true); + when(sp2.canAcquireRecords()).thenReturn(false); + when(sp3.canAcquireRecords()).thenReturn(false); + when(sp4.canAcquireRecords()).thenReturn(false); + + LinkedHashMap sharePartitions = new LinkedHashMap<>(); + sharePartitions.put(tp0, sp0); + sharePartitions.put(tp1, sp1); + sharePartitions.put(tp2, sp2); + sharePartitions.put(tp3, sp3); + sharePartitions.put(tp4, sp4); + + ShareFetch shareFetch = new ShareFetch(FETCH_PARAMS, groupId, Uuid.randomUuid().toString(), + new CompletableFuture<>(), partitionMaxBytes, MAX_FETCH_RECORDS); + + when(sp0.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + when(sp1.acquire(anyString(), anyInt(), any(FetchPartitionData.class))).thenReturn( + ShareAcquiredRecords.fromAcquiredRecords(new ShareFetchResponseData.AcquiredRecords().setFirstOffset(0).setLastOffset(3).setDeliveryCount((short) 1))); + + // Only 2 out of 5 partitions are acquirable. + Set acquirableTopicPartitions = new LinkedHashSet<>(); + acquirableTopicPartitions.add(tp0); + acquirableTopicPartitions.add(tp1); + doAnswer(invocation -> buildLogReadResult(acquirableTopicPartitions)).when(replicaManager).readFromLog(any(), any(), any(ReplicaQuota.class), anyBoolean()); + + when(sp0.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + when(sp1.fetchOffsetMetadata(anyLong())).thenReturn(Optional.of(new LogOffsetMetadata(0, 1, 0))); + + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp0, 1); + mockTopicIdPartitionToReturnDataEqualToMinBytes(replicaManager, tp1, 1); + + DelayedShareFetch delayedShareFetch = spy(DelayedShareFetchBuilder.builder() + .withShareFetchData(shareFetch) + .withSharePartitions(sharePartitions) + .withReplicaManager(replicaManager) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) + .build()); + + assertTrue(delayedShareFetch.tryComplete()); + assertTrue(delayedShareFetch.isCompleted()); + + // Since only 2 partitions are acquirable, maxbytes per partition = requestMaxBytes(i.e. 1024*1024) / acquiredTopicPartitions(i.e. 
2) + int expectedPartitionMaxBytes = 1024 * 1024 / 2; + LinkedHashMap expectedReadPartitionInfo = new LinkedHashMap<>(); + acquirableTopicPartitions.forEach(topicIdPartition -> expectedReadPartitionInfo.put(topicIdPartition, + new FetchRequest.PartitionData( + topicIdPartition.topicId(), + 0, + 0, + expectedPartitionMaxBytes, + Optional.empty() + ))); + + Mockito.verify(replicaManager, times(1)).readFromLog( + shareFetch.fetchParams(), + CollectionConverters.asScala( + acquirableTopicPartitions.stream().map(topicIdPartition -> + new Tuple2<>(topicIdPartition, expectedReadPartitionInfo.get(topicIdPartition))).collect(Collectors.toList()) + ), + QuotaFactory.UNBOUNDED_QUOTA, + true); + } + + @Test + public void testPartitionMaxBytesFromUniformStrategyInCombineLogReadResponse() { + String groupId = "grp"; + ReplicaManager replicaManager = mock(ReplicaManager.class); + TopicIdPartition tp0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + TopicIdPartition tp1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 1)); + TopicIdPartition tp2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 2)); + Map partitionMaxBytes = new HashMap<>(); + partitionMaxBytes.put(tp0, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp1, PARTITION_MAX_BYTES); + partitionMaxBytes.put(tp2, PARTITION_MAX_BYTES); + + SharePartition sp0 = mock(SharePartition.class); + SharePartition sp1 = mock(SharePartition.class); + SharePartition sp2 = mock(SharePartition.class); + + LinkedHashMap sharePartitions = new LinkedHashMap<>(); + sharePartitions.put(tp0, sp0); + sharePartitions.put(tp1, sp1); + sharePartitions.put(tp2, sp2); + + ShareFetch shareFetch = new ShareFetch( + new FetchParams(ApiKeys.SHARE_FETCH.latestVersion(), FetchRequest.ORDINARY_CONSUMER_ID, -1, MAX_WAIT_MS, + 1, 1024 * 1024, FetchIsolation.HIGH_WATERMARK, Optional.empty()), groupId, Uuid.randomUuid().toString(), + new CompletableFuture<>(), partitionMaxBytes, MAX_FETCH_RECORDS); + + DelayedShareFetch delayedShareFetch = DelayedShareFetchBuilder.builder() + .withShareFetchData(shareFetch) + .withReplicaManager(replicaManager) + .withSharePartitions(sharePartitions) + .withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM)) + .build(); + + LinkedHashMap topicPartitionData = new LinkedHashMap<>(); + topicPartitionData.put(tp0, 0L); + topicPartitionData.put(tp1, 0L); + topicPartitionData.put(tp2, 0L); + + // Existing fetched data already contains tp0. + LinkedHashMap logReadResponse = new LinkedHashMap<>(); + LogReadResult logReadResult = mock(LogReadResult.class); + Records records = mock(Records.class); + when(records.sizeInBytes()).thenReturn(2); + FetchDataInfo fetchDataInfo = new FetchDataInfo(mock(LogOffsetMetadata.class), records); + when(logReadResult.info()).thenReturn(fetchDataInfo); + logReadResponse.put(tp0, logReadResult); + + Set fetchableTopicPartitions = new LinkedHashSet<>(); + fetchableTopicPartitions.add(tp1); + fetchableTopicPartitions.add(tp2); + // We will be doing replica manager fetch only for tp1 and tp2. 
+ doAnswer(invocation -> buildLogReadResult(fetchableTopicPartitions)).when(replicaManager).readFromLog(any(), any(), any(ReplicaQuota.class), anyBoolean()); + LinkedHashMap combinedLogReadResponse = delayedShareFetch.combineLogReadResponse(topicPartitionData, logReadResponse); + + assertEquals(topicPartitionData.keySet(), combinedLogReadResponse.keySet()); + // Since only 2 partitions are fetchable but the third one has already been fetched, maxbytes per partition = requestMaxBytes(i.e. 1024*1024) / acquiredTopicPartitions(i.e. 3) + int expectedPartitionMaxBytes = 1024 * 1024 / 3; + LinkedHashMap expectedReadPartitionInfo = new LinkedHashMap<>(); + fetchableTopicPartitions.forEach(topicIdPartition -> expectedReadPartitionInfo.put(topicIdPartition, + new FetchRequest.PartitionData( + topicIdPartition.topicId(), + 0, + 0, + expectedPartitionMaxBytes, + Optional.empty() + ))); + + Mockito.verify(replicaManager, times(1)).readFromLog( + shareFetch.fetchParams(), + CollectionConverters.asScala( + fetchableTopicPartitions.stream().map(topicIdPartition -> + new Tuple2<>(topicIdPartition, expectedReadPartitionInfo.get(topicIdPartition))).collect(Collectors.toList()) + ), + QuotaFactory.UNBOUNDED_QUOTA, + true); + } + static void mockTopicIdPartitionToReturnDataEqualToMinBytes(ReplicaManager replicaManager, TopicIdPartition topicIdPartition, int minBytes) { LogOffsetMetadata hwmOffsetMetadata = new LogOffsetMetadata(1, 1, minBytes); LogOffsetSnapshot endOffsetSnapshot = new LogOffsetSnapshot(1, mock(LogOffsetMetadata.class), @@ -736,6 +1076,7 @@ static class DelayedShareFetchBuilder { private ReplicaManager replicaManager = mock(ReplicaManager.class); private BiConsumer exceptionHandler = mockExceptionHandler(); private LinkedHashMap sharePartitions = mock(LinkedHashMap.class); + private PartitionMaxBytesStrategy partitionMaxBytesStrategy = mock(PartitionMaxBytesStrategy.class); DelayedShareFetchBuilder withShareFetchData(ShareFetch shareFetch) { this.shareFetch = shareFetch; @@ -757,6 +1098,11 @@ DelayedShareFetchBuilder withSharePartitions(LinkedHashMap records = shareConsumer.poll(Duration.ofMillis(5000)); - assertEquals(1, records.count()); + assertEquals(2, records.count()); } } diff --git a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala index 73f9fce42e6b1..b38ec285c403c 100644 --- a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala +++ b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala @@ -1310,7 +1310,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo ), ) ) - def testShareFetchBrokerRespectsPartitionsSizeLimit(): Unit = { + def testShareFetchBrokerDoesNotRespectPartitionsSizeLimit(): Unit = { val groupId: String = "group" val memberId = Uuid.randomUuid() @@ -1350,10 +1350,10 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo .setPartitionIndex(partition) .setErrorCode(Errors.NONE.code()) .setAcknowledgeErrorCode(Errors.NONE.code()) - .setAcquiredRecords(expectedAcquiredRecords(Collections.singletonList(0), Collections.singletonList(11), Collections.singletonList(1))) + .setAcquiredRecords(expectedAcquiredRecords(Collections.singletonList(0), Collections.singletonList(12), Collections.singletonList(1))) // The first 10 records will be consumed as it is. 
For the last 3 records, each of size MAX_PARTITION_BYTES/3, - // only 2 of then will be consumed (offsets 10 and 11) because the inclusion of the third last record will exceed - // the max partition bytes limit + // all 3 of them will be consumed (offsets 10, 11 and 12) because, even though the inclusion of the third record exceeds + // the max partition bytes limit, only the request level maxBytes is treated as the hard limit. val partitionData = shareFetchResponseData.responses().get(0).partitions().get(0) compareFetchResponsePartitions(expectedPartitionData, partitionData) @@ -1412,15 +1412,15 @@ // mocking the behaviour of multiple share consumers from the same share group val metadata1: ShareRequestMetadata = new ShareRequestMetadata(memberId1, ShareRequestMetadata.INITIAL_EPOCH) val acknowledgementsMap1: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty - val shareFetchRequest1 = createShareFetchRequest(groupId, metadata1, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap1) + val shareFetchRequest1 = createShareFetchRequest(groupId, metadata1, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap1, minBytes = 100, maxBytes = 1500) val metadata2: ShareRequestMetadata = new ShareRequestMetadata(memberId2, ShareRequestMetadata.INITIAL_EPOCH) val acknowledgementsMap2: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty - val shareFetchRequest2 = createShareFetchRequest(groupId, metadata2, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap2) + val shareFetchRequest2 = createShareFetchRequest(groupId, metadata2, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap2, minBytes = 100, maxBytes = 1500) val metadata3: ShareRequestMetadata = new ShareRequestMetadata(memberId3, ShareRequestMetadata.INITIAL_EPOCH) val acknowledgementsMap3: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty - val shareFetchRequest3 = createShareFetchRequest(groupId, metadata3, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap3) + val shareFetchRequest3 = createShareFetchRequest(groupId, metadata3, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap3, minBytes = 100, maxBytes = 1500) val shareFetchResponse1 = connectAndReceive[ShareFetchResponse](shareFetchRequest1) val shareFetchResponse2 = connectAndReceive[ShareFetchResponse](shareFetchRequest2) diff --git a/share/src/main/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategy.java b/share/src/main/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategy.java new file mode 100644 index 0000000000000..e0600e02842fc --- /dev/null +++ b/share/src/main/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategy.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.share.fetch; + +import org.apache.kafka.common.TopicIdPartition; + +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Set; + +/** + * This interface helps identify the max bytes for topic partitions in a share fetch request based on different strategy types. + */ +public interface PartitionMaxBytesStrategy { + + enum StrategyType { + UNIFORM; + + @Override + public String toString() { + return super.toString().toLowerCase(Locale.ROOT); + } + } + + /** + * Returns the partition max bytes for a given partition based on the strategy type. + * + * @param requestMaxBytes - The total max bytes available for the share fetch request + * @param partitions - The topic partitions in the order for which we compute the partition max bytes. + * @param acquiredPartitionsSize - The total partitions that have been acquired. + * @return the partition max bytes for the topic partitions + */ + LinkedHashMap maxBytes(int requestMaxBytes, Set partitions, int acquiredPartitionsSize); + + static PartitionMaxBytesStrategy type(StrategyType type) { + if (type == null) + throw new IllegalArgumentException("Strategy type cannot be null"); + return switch (type) { + case UNIFORM -> PartitionMaxBytesStrategy::uniformPartitionMaxBytes; + }; + } + + + private static LinkedHashMap uniformPartitionMaxBytes(int requestMaxBytes, Set partitions, int acquiredPartitionsSize) { + checkValidArguments(requestMaxBytes, partitions, acquiredPartitionsSize); + LinkedHashMap partitionMaxBytes = new LinkedHashMap<>(); + partitions.forEach(partition -> partitionMaxBytes.put(partition, requestMaxBytes / acquiredPartitionsSize)); + return partitionMaxBytes; + } + + // Visible for testing. + static void checkValidArguments(int requestMaxBytes, Set partitions, int acquiredPartitionsSize) { + if (partitions == null || partitions.isEmpty()) { + throw new IllegalArgumentException("Partitions to generate max bytes is null or empty"); + } + if (requestMaxBytes <= 0) { + throw new IllegalArgumentException("Request max bytes must be greater than 0"); + } + if (acquiredPartitionsSize <= 0) { + throw new IllegalArgumentException("Acquired partitions size must be greater than 0"); + } + } +} diff --git a/share/src/test/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategyTest.java b/share/src/test/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategyTest.java new file mode 100644 index 0000000000000..3c6c7ad220766 --- /dev/null +++ b/share/src/test/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategyTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.share.fetch; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy.StrategyType; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class PartitionMaxBytesStrategyTest { + + @Test + public void testConstructor() { + assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.type(null)); + assertDoesNotThrow(() -> PartitionMaxBytesStrategy.type(StrategyType.UNIFORM)); + } + + @Test + public void testCheckValidArguments() { + TopicIdPartition topicIdPartition1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 0)); + TopicIdPartition topicIdPartition2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 1)); + TopicIdPartition topicIdPartition3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic2", 0)); + Set partitions = new LinkedHashSet<>(); + partitions.add(topicIdPartition1); + partitions.add(topicIdPartition2); + partitions.add(topicIdPartition3); + + // acquired partitions size is 0. + assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments( + 100, partitions, 0)); + // empty partitions set. + assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments( + 100, Collections.EMPTY_SET, 20)); + // partitions is null. + assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments( + 100, null, 20)); + // request max bytes is 0. + assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments( + 0, partitions, 20)); + + // Valid arguments. 
+ assertDoesNotThrow(() -> PartitionMaxBytesStrategy.checkValidArguments(100, partitions, 20)); + } + + @Test + public void testUniformStrategy() { + PartitionMaxBytesStrategy partitionMaxBytesStrategy = PartitionMaxBytesStrategy.type(StrategyType.UNIFORM); + TopicIdPartition topicIdPartition1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 0)); + TopicIdPartition topicIdPartition2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 1)); + TopicIdPartition topicIdPartition3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic2", 0)); + Set partitions = new LinkedHashSet<>(); + partitions.add(topicIdPartition1); + partitions.add(topicIdPartition2); + partitions.add(topicIdPartition3); + + LinkedHashMap result = partitionMaxBytesStrategy.maxBytes( + 100, partitions, 3); + assertEquals(result.values().stream().toList(), List.of(33, 33, 33)); + + result = partitionMaxBytesStrategy.maxBytes( + 100, partitions, 5); + assertEquals(result.values().stream().toList(), List.of(20, 20, 20)); + } +}
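
For reference, a minimal standalone sketch (not part of this patch) of how the new uniform strategy is consumed: the request-level max bytes is divided evenly across the partitions being read, with the acquired-partition count as the divisor, which is the value DelayedShareFetch now passes into readFromLog. The class name UniformStrategyExample is illustrative only, and the Integer value type of the returned map is inferred from the new tests above.

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

public class UniformStrategyExample {
    public static void main(String[] args) {
        // Three partitions of the same topic, in insertion order (mirrors the acquirable partitions).
        Uuid topicId = Uuid.randomUuid();
        Set<TopicIdPartition> partitions = new LinkedHashSet<>();
        partitions.add(new TopicIdPartition(topicId, new TopicPartition("foo", 0)));
        partitions.add(new TopicIdPartition(topicId, new TopicPartition("foo", 1)));
        partitions.add(new TopicIdPartition(topicId, new TopicPartition("foo", 2)));

        // Request-level max bytes for the share fetch (1 MiB, as in the tests above).
        int requestMaxBytes = 1024 * 1024;

        PartitionMaxBytesStrategy uniform =
            PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM);

        // Each partition gets requestMaxBytes / acquiredPartitionsSize = 1048576 / 3 = 349525 bytes.
        LinkedHashMap<TopicIdPartition, Integer> maxBytes =
            uniform.maxBytes(requestMaxBytes, partitions, partitions.size());

        maxBytes.forEach((tp, bytes) -> System.out.println(tp + " -> " + bytes));
    }
}

Note that maybeReadFromLog and combineLogReadResponse deliberately pass the full acquired-partition count as the divisor even when only a subset of partitions is actually read, so partitions fetched later in the same request keep their fair share of the byte budget.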