KAFKA-18404: Remove partitionMaxBytes usage from DelayedShareFetch #17870

Merged

108 changes: 66 additions & 42 deletions core/src/main/java/kafka/server/share/DelayedShareFetch.java
@@ -27,6 +27,7 @@
import org.apache.kafka.server.purgatory.DelayedOperation;
import org.apache.kafka.server.share.SharePartitionKey;
import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy;
import org.apache.kafka.server.share.fetch.ShareFetch;
import org.apache.kafka.server.storage.log.FetchIsolation;
import org.apache.kafka.server.storage.log.FetchPartitionData;
@@ -60,24 +61,35 @@ public class DelayedShareFetch extends DelayedOperation {
private final ShareFetch shareFetch;
private final ReplicaManager replicaManager;
private final BiConsumer<SharePartitionKey, Throwable> exceptionHandler;
private final PartitionMaxBytesStrategy partitionMaxBytesStrategy;
// The topic partitions that need to be completed for the share fetch request are given by sharePartitions.
// sharePartitions is a subset of shareFetchData. The order of insertion/deletion of entries in sharePartitions is important.
private final LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions;
private LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> partitionsAcquired;
private LinkedHashMap<TopicIdPartition, Long> partitionsAcquired;
private LinkedHashMap<TopicIdPartition, LogReadResult> partitionsAlreadyFetched;

DelayedShareFetch(
ShareFetch shareFetch,
ReplicaManager replicaManager,
BiConsumer<SharePartitionKey, Throwable> exceptionHandler,
LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions) {
this(shareFetch, replicaManager, exceptionHandler, sharePartitions, PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM));
}

DelayedShareFetch(
ShareFetch shareFetch,
ReplicaManager replicaManager,
BiConsumer<SharePartitionKey, Throwable> exceptionHandler,
LinkedHashMap<TopicIdPartition, SharePartition> sharePartitions,
PartitionMaxBytesStrategy partitionMaxBytesStrategy) {
super(shareFetch.fetchParams().maxWaitMs, Optional.empty());
this.shareFetch = shareFetch;
this.replicaManager = replicaManager;
this.partitionsAcquired = new LinkedHashMap<>();
this.partitionsAlreadyFetched = new LinkedHashMap<>();
this.exceptionHandler = exceptionHandler;
this.sharePartitions = sharePartitions;
this.partitionMaxBytesStrategy = partitionMaxBytesStrategy;
}

@Override
@@ -99,7 +111,7 @@ public void onComplete() {
partitionsAcquired.keySet());

try {
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData;
LinkedHashMap<TopicIdPartition, Long> topicPartitionData;
// tryComplete did not invoke forceComplete, so we need to check if we have any partitions to fetch.
if (partitionsAcquired.isEmpty())
topicPartitionData = acquirablePartitions();
@@ -121,11 +133,13 @@ public void onComplete() {
}
}

private void completeShareFetchRequest(LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData) {
private void completeShareFetchRequest(LinkedHashMap<TopicIdPartition, Long> topicPartitionData) {
try {
LinkedHashMap<TopicIdPartition, LogReadResult> responseData;
if (partitionsAlreadyFetched.isEmpty())
responseData = readFromLog(topicPartitionData);
responseData = readFromLog(
topicPartitionData,
partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, topicPartitionData.keySet(), topicPartitionData.size()));
else
// There shouldn't be a case when we have a partitionsAlreadyFetched value here and this variable is getting
// updated in a different tryComplete thread.
@@ -158,7 +172,7 @@ private void completeShareFetchRequest(LinkedHashMap<TopicIdPartition, FetchRequ
*/
@Override
public boolean tryComplete() {
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData = acquirablePartitions();
LinkedHashMap<TopicIdPartition, Long> topicPartitionData = acquirablePartitions();

try {
if (!topicPartitionData.isEmpty()) {
@@ -167,7 +181,7 @@ public boolean tryComplete() {
// those topic partitions.
LinkedHashMap<TopicIdPartition, LogReadResult> replicaManagerReadResponse = maybeReadFromLog(topicPartitionData);
maybeUpdateFetchOffsetMetadata(topicPartitionData, replicaManagerReadResponse);
if (anyPartitionHasLogReadError(replicaManagerReadResponse) || isMinBytesSatisfied(topicPartitionData)) {
if (anyPartitionHasLogReadError(replicaManagerReadResponse) || isMinBytesSatisfied(topicPartitionData, partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, topicPartitionData.keySet(), topicPartitionData.size()))) {
partitionsAcquired = topicPartitionData;
partitionsAlreadyFetched = replicaManagerReadResponse;
boolean completedByMe = forceComplete();
@@ -202,28 +216,18 @@ public boolean tryComplete() {
* Prepare the fetch offsets for partitions in the share fetch request for which we can acquire records.
*/
// Visible for testing
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> acquirablePartitions() {
LinkedHashMap<TopicIdPartition, Long> acquirablePartitions() {
// Initialize the topic partitions for which the fetch should be attempted.
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData = new LinkedHashMap<>();
LinkedHashMap<TopicIdPartition, Long> topicPartitionData = new LinkedHashMap<>();

sharePartitions.forEach((topicIdPartition, sharePartition) -> {
int partitionMaxBytes = shareFetch.partitionMaxBytes().getOrDefault(topicIdPartition, 0);
// Add the share partition to the list of partitions to be fetched only if we can
// acquire the fetch lock on it.
if (sharePartition.maybeAcquireFetchLock()) {
try {
// If the share partition is already at capacity, we should not attempt to fetch.
if (sharePartition.canAcquireRecords()) {
topicPartitionData.put(
topicIdPartition,
new FetchRequest.PartitionData(
topicIdPartition.topicId(),
sharePartition.nextFetchOffset(),
0,
partitionMaxBytes,
Optional.empty()
)
);
topicPartitionData.put(topicIdPartition, sharePartition.nextFetchOffset());
} else {
sharePartition.releaseFetchLock();
log.trace("Record lock partition limit exceeded for SharePartition {}, " +
@@ -239,24 +243,28 @@ LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> acquirablePartitions
return topicPartitionData;
}

private LinkedHashMap<TopicIdPartition, LogReadResult> maybeReadFromLog(LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData) {
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> partitionsNotMatchingFetchOffsetMetadata = new LinkedHashMap<>();
topicPartitionData.forEach((topicIdPartition, partitionData) -> {
private LinkedHashMap<TopicIdPartition, LogReadResult> maybeReadFromLog(LinkedHashMap<TopicIdPartition, Long> topicPartitionData) {
LinkedHashMap<TopicIdPartition, Long> partitionsNotMatchingFetchOffsetMetadata = new LinkedHashMap<>();
topicPartitionData.forEach((topicIdPartition, fetchOffset) -> {
SharePartition sharePartition = sharePartitions.get(topicIdPartition);
if (sharePartition.fetchOffsetMetadata(partitionData.fetchOffset).isEmpty()) {
partitionsNotMatchingFetchOffsetMetadata.put(topicIdPartition, partitionData);
if (sharePartition.fetchOffsetMetadata(fetchOffset).isEmpty()) {
partitionsNotMatchingFetchOffsetMetadata.put(topicIdPartition, fetchOffset);
}
});
if (partitionsNotMatchingFetchOffsetMetadata.isEmpty()) {
return new LinkedHashMap<>();
}
// We fetch data from replica manager corresponding to the topic partitions that have missing fetch offset metadata.
return readFromLog(partitionsNotMatchingFetchOffsetMetadata);
// Although we compute partition max bytes only for partitionsNotMatchingFetchOffsetMetadata,
// we use acquired partitions size = topicPartitionData.size() so that the leftover partitions,
// which will be fetched later, do not starve.
return readFromLog(
partitionsNotMatchingFetchOffsetMetadata,
partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, partitionsNotMatchingFetchOffsetMetadata.keySet(), topicPartitionData.size()));
}

private void maybeUpdateFetchOffsetMetadata(
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData,
LinkedHashMap<TopicIdPartition, LogReadResult> replicaManagerReadResponseData) {
private void maybeUpdateFetchOffsetMetadata(LinkedHashMap<TopicIdPartition, Long> topicPartitionData,
LinkedHashMap<TopicIdPartition, LogReadResult> replicaManagerReadResponseData) {
for (Map.Entry<TopicIdPartition, LogReadResult> entry : replicaManagerReadResponseData.entrySet()) {
TopicIdPartition topicIdPartition = entry.getKey();
SharePartition sharePartition = sharePartitions.get(topicIdPartition);
@@ -267,17 +275,18 @@ private void maybeUpdateFetchOffsetMetadata(
continue;
}
sharePartition.updateFetchOffsetMetadata(
topicPartitionData.get(topicIdPartition).fetchOffset,
topicPartitionData.get(topicIdPartition),
replicaManagerLogReadResult.info().fetchOffsetMetadata);
}
}

// minBytes estimation currently assumes the common case where all fetched data is acquirable.
private boolean isMinBytesSatisfied(LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData) {
private boolean isMinBytesSatisfied(LinkedHashMap<TopicIdPartition, Long> topicPartitionData,
LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes) {
long accumulatedSize = 0;
for (Map.Entry<TopicIdPartition, FetchRequest.PartitionData> entry : topicPartitionData.entrySet()) {
for (Map.Entry<TopicIdPartition, Long> entry : topicPartitionData.entrySet()) {
TopicIdPartition topicIdPartition = entry.getKey();
FetchRequest.PartitionData partitionData = entry.getValue();
long fetchOffset = entry.getValue();

LogOffsetMetadata endOffsetMetadata;
try {
@@ -294,7 +303,7 @@ private boolean isMinBytesSatisfied(LinkedHashMap<TopicIdPartition, FetchRequest

SharePartition sharePartition = sharePartitions.get(topicIdPartition);

Optional<LogOffsetMetadata> optionalFetchOffsetMetadata = sharePartition.fetchOffsetMetadata(partitionData.fetchOffset);
Optional<LogOffsetMetadata> optionalFetchOffsetMetadata = sharePartition.fetchOffsetMetadata(fetchOffset);
if (optionalFetchOffsetMetadata.isEmpty() || optionalFetchOffsetMetadata.get() == LogOffsetMetadata.UNKNOWN_OFFSET_METADATA)
continue;
LogOffsetMetadata fetchOffsetMetadata = optionalFetchOffsetMetadata.get();
@@ -312,7 +321,7 @@ private boolean isMinBytesSatisfied(LinkedHashMap<TopicIdPartition, FetchRequest
return true;
} else if (fetchOffsetMetadata.onSameSegment(endOffsetMetadata)) {
// we take the partition fetch size as upper bound when accumulating the bytes.
long bytesAvailable = Math.min(endOffsetMetadata.positionDiff(fetchOffsetMetadata), partitionData.maxBytes);
long bytesAvailable = Math.min(endOffsetMetadata.positionDiff(fetchOffsetMetadata), partitionMaxBytes.get(topicIdPartition));
accumulatedSize += bytesAvailable;
}
}
@@ -335,13 +344,25 @@ else if (isolationType == FetchIsolation.HIGH_WATERMARK)

}

private LinkedHashMap<TopicIdPartition, LogReadResult> readFromLog(LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData) {
private LinkedHashMap<TopicIdPartition, LogReadResult> readFromLog(LinkedHashMap<TopicIdPartition, Long> topicPartitionFetchOffsets,
LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes) {
// Filter if there already exists any erroneous topic partition.
Set<TopicIdPartition> partitionsToFetch = shareFetch.filterErroneousTopicPartitions(topicPartitionData.keySet());
Set<TopicIdPartition> partitionsToFetch = shareFetch.filterErroneousTopicPartitions(topicPartitionFetchOffsets.keySet());
if (partitionsToFetch.isEmpty()) {
return new LinkedHashMap<>();
}

LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData = new LinkedHashMap<>();

topicPartitionFetchOffsets.forEach((topicIdPartition, fetchOffset) -> topicPartitionData.put(topicIdPartition,
new FetchRequest.PartitionData(
topicIdPartition.topicId(),
fetchOffset,
0,
partitionMaxBytes.get(topicIdPartition),
Optional.empty())
));

Seq<Tuple2<TopicIdPartition, LogReadResult>> responseLogResult = replicaManager.readFromLog(
shareFetch.fetchParams(),
CollectionConverters.asScala(
@@ -390,18 +411,21 @@ private void handleFetchException(
}

// Visible for testing.
LinkedHashMap<TopicIdPartition, LogReadResult> combineLogReadResponse(LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> topicPartitionData,
LinkedHashMap<TopicIdPartition, LogReadResult> existingFetchedData) {
LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> missingLogReadTopicPartitions = new LinkedHashMap<>();
topicPartitionData.forEach((topicIdPartition, partitionData) -> {
LinkedHashMap<TopicIdPartition, LogReadResult> combineLogReadResponse(LinkedHashMap<TopicIdPartition, Long> topicPartitionData,
LinkedHashMap<TopicIdPartition, LogReadResult> existingFetchedData) {
LinkedHashMap<TopicIdPartition, Long> missingLogReadTopicPartitions = new LinkedHashMap<>();
topicPartitionData.forEach((topicIdPartition, fetchOffset) -> {
if (!existingFetchedData.containsKey(topicIdPartition)) {
missingLogReadTopicPartitions.put(topicIdPartition, partitionData);
missingLogReadTopicPartitions.put(topicIdPartition, fetchOffset);
}
});
if (missingLogReadTopicPartitions.isEmpty()) {
return existingFetchedData;
}
LinkedHashMap<TopicIdPartition, LogReadResult> missingTopicPartitionsLogReadResponse = readFromLog(missingLogReadTopicPartitions);

LinkedHashMap<TopicIdPartition, LogReadResult> missingTopicPartitionsLogReadResponse = readFromLog(
missingLogReadTopicPartitions,
partitionMaxBytesStrategy.maxBytes(shareFetch.fetchParams().maxBytes, missingLogReadTopicPartitions.keySet(), topicPartitionData.size()));
missingTopicPartitionsLogReadResponse.putAll(existingFetchedData);
return missingTopicPartitionsLogReadResponse;
}
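
To make the divisor choice in maybeReadFromLog concrete, here is a minimal, illustrative sketch; the class name, topic names, and byte values are invented for the example, and only the PartitionMaxBytesStrategy API comes from this PR. With a request budget of 120 bytes and 3 acquired partitions of which only 2 are missing fetch offset metadata, each of those 2 still receives 120 / 3 = 40 bytes, leaving budget for the leftover partition that will be fetched later.

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

public class UniformDivisorSketch {
    public static void main(String[] args) {
        // Two partitions whose fetch offset metadata is missing; a third acquired
        // partition (not shown) will be fetched later.
        Set<TopicIdPartition> missingMetadata = new LinkedHashSet<>();
        missingMetadata.add(new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic", 0)));
        missingMetadata.add(new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic", 1)));

        PartitionMaxBytesStrategy uniform =
            PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM);

        // Pass acquiredPartitionsSize = 3, mirroring maybeReadFromLog's use of
        // topicPartitionData.size() rather than the size of the subset fetched now.
        LinkedHashMap<TopicIdPartition, Integer> maxBytes = uniform.maxBytes(120, missingMetadata, 3);
        maxBytes.values().forEach(System.out::println); // prints 40 twice
    }
}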
357 changes: 352 additions & 5 deletions core/src/test/java/kafka/server/share/DelayedShareFetchTest.java

Large diffs are not rendered by default.

@@ -64,6 +64,7 @@
import org.apache.kafka.server.share.context.ShareSessionContext;
import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey;
import org.apache.kafka.server.share.fetch.DelayedShareFetchKey;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy;
import org.apache.kafka.server.share.fetch.ShareAcquiredRecords;
import org.apache.kafka.server.share.fetch.ShareFetch;
import org.apache.kafka.server.share.persister.NoOpShareStatePersister;
@@ -1711,6 +1712,7 @@ public void testAcknowledgeCompletesDelayedShareFetchRequest() {
.withShareFetchData(shareFetch)
.withReplicaManager(mockReplicaManager)
.withSharePartitions(sharePartitions)
.withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM))
.build();

delayedShareFetchPurgatory.tryCompleteElseWatch(delayedShareFetch, delayedShareFetchWatchKeys);
@@ -1912,6 +1914,7 @@ public void testReleaseSessionCompletesDelayedShareFetchRequest() {
.withShareFetchData(shareFetch)
.withReplicaManager(mockReplicaManager)
.withSharePartitions(sharePartitions)
.withPartitionMaxBytesStrategy(PartitionMaxBytesStrategy.type(PartitionMaxBytesStrategy.StrategyType.UNIFORM))
.build();

delayedShareFetchPurgatory.tryCompleteElseWatch(delayedShareFetch, delayedShareFetchWatchKeys);
2 changes: 1 addition & 1 deletion core/src/test/java/kafka/test/api/ShareConsumerTest.java
@@ -902,7 +902,7 @@ public void testFetchRecordLargerThanMaxPartitionFetchBytes(String persister) th
shareConsumer.subscribe(Collections.singleton(tp.topic()));

ConsumerRecords<byte[], byte[]> records = shareConsumer.poll(Duration.ofMillis(5000));
assertEquals(1, records.count());
assertEquals(2, records.count());
}
}
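
A note on the assertion change above: with per-partition max bytes no longer enforced by DelayedShareFetch, the broker caps a share fetch response only by the request-level maxBytes, so a poll that previously stopped after one record larger than max.partition.fetch.bytes now returns both records, which is why the expected count rises from 1 to 2.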

@@ -1310,7 +1310,7 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo
),
)
)
def testShareFetchBrokerRespectsPartitionsSizeLimit(): Unit = {
def testShareFetchBrokerDoesNotRespectPartitionsSizeLimit(): Unit = {
val groupId: String = "group"
val memberId = Uuid.randomUuid()

@@ -1350,10 +1350,10 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo
.setPartitionIndex(partition)
.setErrorCode(Errors.NONE.code())
.setAcknowledgeErrorCode(Errors.NONE.code())
.setAcquiredRecords(expectedAcquiredRecords(Collections.singletonList(0), Collections.singletonList(11), Collections.singletonList(1)))
.setAcquiredRecords(expectedAcquiredRecords(Collections.singletonList(0), Collections.singletonList(12), Collections.singletonList(1)))
// The first 10 records will be consumed as they are. For the last 3 records, each of size MAX_PARTITION_BYTES/3,
// only 2 of them will be consumed (offsets 10 and 11) because including the third record would exceed
// the max partition bytes limit
// all 3 of them will be consumed (offsets 10, 11 and 12); even though including the third record exceeds
// the max partition bytes limit, only the request-level maxBytes counts as the hard limit.

val partitionData = shareFetchResponseData.responses().get(0).partitions().get(0)
compareFetchResponsePartitions(expectedPartitionData, partitionData)
@@ -1412,15 +1412,15 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo
// mocking the behaviour of multiple share consumers from the same share group
val metadata1: ShareRequestMetadata = new ShareRequestMetadata(memberId1, ShareRequestMetadata.INITIAL_EPOCH)
val acknowledgementsMap1: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
val shareFetchRequest1 = createShareFetchRequest(groupId, metadata1, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap1)
val shareFetchRequest1 = createShareFetchRequest(groupId, metadata1, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap1, minBytes = 100, maxBytes = 1500)

val metadata2: ShareRequestMetadata = new ShareRequestMetadata(memberId2, ShareRequestMetadata.INITIAL_EPOCH)
val acknowledgementsMap2: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
val shareFetchRequest2 = createShareFetchRequest(groupId, metadata2, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap2)
val shareFetchRequest2 = createShareFetchRequest(groupId, metadata2, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap2, minBytes = 100, maxBytes = 1500)

val metadata3: ShareRequestMetadata = new ShareRequestMetadata(memberId3, ShareRequestMetadata.INITIAL_EPOCH)
val acknowledgementsMap3: Map[TopicIdPartition, util.List[ShareFetchRequestData.AcknowledgementBatch]] = Map.empty
val shareFetchRequest3 = createShareFetchRequest(groupId, metadata3, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap3)
val shareFetchRequest3 = createShareFetchRequest(groupId, metadata3, MAX_PARTITION_BYTES, send, Seq.empty, acknowledgementsMap3, minBytes = 100, maxBytes = 1500)

val shareFetchResponse1 = connectAndReceive[ShareFetchResponse](shareFetchRequest1)
val shareFetchResponse2 = connectAndReceive[ShareFetchResponse](shareFetchRequest2)
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.server.share.fetch;

import org.apache.kafka.common.TopicIdPartition;

import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Set;

/**
* This interface determines the max bytes for topic partitions in a share fetch request based on different strategy types.
*/
public interface PartitionMaxBytesStrategy {

enum StrategyType {
UNIFORM;

@Override
public String toString() {
return super.toString().toLowerCase(Locale.ROOT);
}
}

/**
* Returns the partition max bytes for the given partitions based on the strategy type.
*
* @param requestMaxBytes - The total max bytes available for the share fetch request.
* @param partitions - The topic partitions, in order, for which the partition max bytes are computed.
* @param acquiredPartitionsSize - The total number of partitions that have been acquired.
* @return the partition max bytes for the topic partitions
*/
LinkedHashMap<TopicIdPartition, Integer> maxBytes(int requestMaxBytes, Set<TopicIdPartition> partitions, int acquiredPartitionsSize);

static PartitionMaxBytesStrategy type(StrategyType type) {
if (type == null)
throw new IllegalArgumentException("Strategy type cannot be null");
return switch (type) {
case UNIFORM -> PartitionMaxBytesStrategy::uniformPartitionMaxBytes;
};
}


private static LinkedHashMap<TopicIdPartition, Integer> uniformPartitionMaxBytes(int requestMaxBytes, Set<TopicIdPartition> partitions, int acquiredPartitionsSize) {
checkValidArguments(requestMaxBytes, partitions, acquiredPartitionsSize);
LinkedHashMap<TopicIdPartition, Integer> partitionMaxBytes = new LinkedHashMap<>();
partitions.forEach(partition -> partitionMaxBytes.put(partition, requestMaxBytes / acquiredPartitionsSize));
return partitionMaxBytes;
}

// Visible for testing.
static void checkValidArguments(int requestMaxBytes, Set<TopicIdPartition> partitions, int acquiredPartitionsSize) {
if (partitions == null || partitions.isEmpty()) {
throw new IllegalArgumentException("Partitions to generate max bytes is null or empty");
}
if (requestMaxBytes <= 0) {
throw new IllegalArgumentException("Request max bytes must be greater than 0");
}
if (acquiredPartitionsSize <= 0) {
throw new IllegalArgumentException("Acquired partitions size must be greater than 0");
}
}
}
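
Because maxBytes is the interface's only abstract method, type(StrategyType.UNIFORM) can return a plain method reference, and any conforming lambda is itself a valid strategy. A hedged sketch of that design property follows; the full-budget policy and class name below are purely illustrative and not shipped by this PR.

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

public class LambdaStrategySketch {
    public static void main(String[] args) {
        // Illustrative policy only: hand the entire request budget to every listed
        // partition, ignoring acquiredPartitionsSize.
        PartitionMaxBytesStrategy fullBudget = (requestMaxBytes, partitions, acquiredPartitionsSize) -> {
            LinkedHashMap<TopicIdPartition, Integer> result = new LinkedHashMap<>();
            partitions.forEach(tp -> result.put(tp, requestMaxBytes));
            return result;
        };

        Set<TopicIdPartition> partitions = new LinkedHashSet<>();
        partitions.add(new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic", 0)));
        // Prints the single entry with the full 1 MiB budget.
        System.out.println(fullBudget.maxBytes(1_048_576, partitions, 1));
    }
}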
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.server.share.fetch;

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.share.fetch.PartitionMaxBytesStrategy.StrategyType;

import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class PartitionMaxBytesStrategyTest {

@Test
public void testConstructor() {
assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.type(null));
assertDoesNotThrow(() -> PartitionMaxBytesStrategy.type(StrategyType.UNIFORM));
}

@Test
public void testCheckValidArguments() {
TopicIdPartition topicIdPartition1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 0));
TopicIdPartition topicIdPartition2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 1));
TopicIdPartition topicIdPartition3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic2", 0));
Set<TopicIdPartition> partitions = new LinkedHashSet<>();
partitions.add(topicIdPartition1);
partitions.add(topicIdPartition2);
partitions.add(topicIdPartition3);

// acquired partitions size is 0.
assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments(
100, partitions, 0));
// empty partitions set.
assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments(
100, Collections.EMPTY_SET, 20));
Review comment (Member, @chia7712):

I apologize for the delayed review. Could you please use Set.of() to address the following warnings?

> Task :share:compileTestJava
Note: /home/jenkins/kafka/share/src/test/java/org/apache/kafka/server/share/fetch/PartitionMaxBytesStrategyTest.java uses unchecked or unsafe operations.
Note: Recompile with -Xlint:unchecked for details.

Reply (Collaborator):

@chia7712 I missed it in review, thanks for pointing out. @adixitconfluent is busy so I just raised this minor PR: #18541

// partitions is null.
assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments(
100, null, 20));
// request max bytes is 0.
assertThrows(IllegalArgumentException.class, () -> PartitionMaxBytesStrategy.checkValidArguments(
0, partitions, 20));

// Valid arguments.
assertDoesNotThrow(() -> PartitionMaxBytesStrategy.checkValidArguments(100, partitions, 20));
}

@Test
public void testUniformStrategy() {
PartitionMaxBytesStrategy partitionMaxBytesStrategy = PartitionMaxBytesStrategy.type(StrategyType.UNIFORM);
TopicIdPartition topicIdPartition1 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 0));
TopicIdPartition topicIdPartition2 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic1", 1));
TopicIdPartition topicIdPartition3 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("topic2", 0));
Set<TopicIdPartition> partitions = new LinkedHashSet<>();
partitions.add(topicIdPartition1);
partitions.add(topicIdPartition2);
partitions.add(topicIdPartition3);

LinkedHashMap<TopicIdPartition, Integer> result = partitionMaxBytesStrategy.maxBytes(
100, partitions, 3);
assertEquals(List.of(33, 33, 33), result.values().stream().toList());

result = partitionMaxBytesStrategy.maxBytes(
100, partitions, 5);
assertEquals(List.of(20, 20, 20), result.values().stream().toList());
}
}
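
A closing note on testUniformStrategy: uniformPartitionMaxBytes uses integer division (requestMaxBytes / acquiredPartitionsSize), so a 100-byte budget split across 3 acquired partitions yields 33 per partition with 1 byte left unassigned, and split across 5 yields 20 per partition even though only 3 partitions are listed; the assertions above pin down exactly this truncating behavior.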