diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
index bf67b953dad7f..08aa6b6927f92 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
@@ -49,6 +49,8 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import static org.apache.kafka.clients.consumer.internals.NetworkClientDelegate.PollResult.EMPTY;
+
/**
*
Manages the request creation and response handling for the Streams group heartbeat. The class creates a
* heartbeat request using the state stored in the membership manager. The requests can be retrieved
@@ -374,6 +376,55 @@ public NetworkClientDelegate.PollResult poll(long currentTimeMs) {
}
}
+ /**
+ * Generate a heartbeat request to leave the group if the state is still LEAVING when this is
+ * called to close the consumer.
+ *
+ * Note that when closing the consumer, even though an event to Unsubscribe is generated
+ * (triggers callbacks and sends leave group), it could be the case that the Unsubscribe event
+ * processing does not complete in time and moves on to close the managers (ex. calls to
+ * close with zero timeout). So we could end up on this pollOnClose with the member in
+ * {@link MemberState#PREPARE_LEAVING} (ex. app thread did not have the time to process the
+ * event to execute callbacks), or {@link MemberState#LEAVING} (ex. the leave request could
+ * not be sent due to coordinator not available at that time). In all cases, the pollOnClose
+ * will be triggered right before sending the final requests, so we ensure that we generate
+ * the request to leave if needed.
+ *
+ * @param currentTimeMs The current system time in milliseconds at which the method was called
+ * @return PollResult containing the request to send
+ */
+ @Override
+ public NetworkClientDelegate.PollResult pollOnClose(long currentTimeMs) {
+ if (membershipManager.isLeavingGroup()) {
+ NetworkClientDelegate.UnsentRequest request = makeHeartbeatRequestAndLogResponse(currentTimeMs);
+ return new NetworkClientDelegate.PollResult(heartbeatRequestState.heartbeatIntervalMs(), List.of(request));
+ }
+ return EMPTY;
+ }
+
+ /**
+ * Returns the delay for which the application thread can safely wait before it should be responsive
+ * to results from the request managers. For example, the subscription state can change when heartbeats
+ * are sent, so blocking for longer than the heartbeat interval might mean the application thread is not
+ * responsive to changes.
+ *
+ * Similarly, we may have to unblock the application thread to send a `PollApplicationEvent` to make sure
+ * our poll timer will not expire while we are polling.
+ *
+ *
In the event that heartbeats are currently being skipped, this still returns the next heartbeat
+ * delay rather than {@code Long.MAX_VALUE} so that the application thread remains responsive.
+ */
+ @Override
+ public long maximumTimeToWait(long currentTimeMs) {
+ pollTimer.update(currentTimeMs);
+ if (pollTimer.isExpired() ||
+ membershipManager.shouldNotWaitForHeartbeatInterval() && !heartbeatRequestState.requestInFlight()) {
+
+ return 0L;
+ }
+ return Math.min(pollTimer.remainingMs() / 2, heartbeatRequestState.timeToNextHeartbeatMs(currentTimeMs));
+ }
+
/**
* A heartbeat should be sent without waiting for the heartbeat interval to expire if:
* - the member is leaving the group
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java
index 126be01e1f504..dae6958035bc5 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java
@@ -45,6 +45,7 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
@@ -1382,6 +1383,132 @@ private static Stream provideOtherErrors() {
.map(Arguments::of);
}
+ @Test
+ public void testPollOnCloseWhenIsNotLeaving() {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+
+ NetworkClientDelegate.PollResult result = heartbeatRequestManager.pollOnClose(time.milliseconds());
+
+ assertEquals(NetworkClientDelegate.PollResult.EMPTY, result);
+ }
+
+ @Test
+ public void testPollOnCloseWhenIsLeaving() {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+ when(membershipManager.isLeavingGroup()).thenReturn(true);
+ when(membershipManager.groupId()).thenReturn(GROUP_ID);
+ when(membershipManager.memberId()).thenReturn(MEMBER_ID);
+ when(membershipManager.memberEpoch()).thenReturn(LEAVE_GROUP_MEMBER_EPOCH);
+
+ NetworkClientDelegate.PollResult result = heartbeatRequestManager.pollOnClose(time.milliseconds());
+
+ assertEquals(1, result.unsentRequests.size());
+ final NetworkClientDelegate.UnsentRequest networkRequest = result.unsentRequests.get(0);
+ StreamsGroupHeartbeatRequest streamsRequest = (StreamsGroupHeartbeatRequest) networkRequest.requestBuilder().build();
+ assertEquals(GROUP_ID, streamsRequest.data().groupId());
+ assertEquals(MEMBER_ID, streamsRequest.data().memberId());
+ assertEquals(LEAVE_GROUP_MEMBER_EPOCH, streamsRequest.data().memberEpoch());
+ }
+
+ @Test
+ public void testMaximumTimeToWaitPollTimerExpired() {
+ try (
+ final MockedConstruction timerMockedConstruction = mockConstruction(Timer.class, (mock, context) -> {
+ when(mock.isExpired()).thenReturn(true);
+ });
+ final MockedConstruction heartbeatRequestStateMockedConstruction = mockConstruction(
+ HeartbeatRequestState.class,
+ (mock, context) -> {
+ when(mock.requestInFlight()).thenReturn(false);
+ })
+ ) {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+ final Timer pollTimer = timerMockedConstruction.constructed().get(0);
+ time.sleep(1234);
+
+ final long maximumTimeToWait = heartbeatRequestManager.maximumTimeToWait(time.milliseconds());
+
+ assertEquals(0, maximumTimeToWait);
+ verify(pollTimer).update(time.milliseconds());
+ }
+ }
+
+ @Test
+ public void testMaximumTimeToWaitWhenHeartbeatShouldBeSentImmediately() {
+ try (
+ final MockedConstruction timerMockedConstruction = mockConstruction(Timer.class);
+ final MockedConstruction heartbeatRequestStateMockedConstruction = mockConstruction(
+ HeartbeatRequestState.class,
+ (mock, context) -> {
+ when(mock.requestInFlight()).thenReturn(false);
+ })
+ ) {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+ final Timer pollTimer = timerMockedConstruction.constructed().get(0);
+ when(membershipManager.shouldNotWaitForHeartbeatInterval()).thenReturn(true);
+ time.sleep(1234);
+
+ final long maximumTimeToWait = heartbeatRequestManager.maximumTimeToWait(time.milliseconds());
+
+ assertEquals(0, maximumTimeToWait);
+ verify(pollTimer).update(time.milliseconds());
+ }
+ }
+
+ @ParameterizedTest
+ @CsvSource({"true, false", "false, false", "true, true"})
+ public void testMaximumTimeToWaitWhenHeartbeatShouldBeNotSentImmediately(final boolean isRequestInFlight,
+ final boolean shouldNotWaitForHeartbeatInterval) {
+ final long remainingMs = 12L;
+ final long timeToNextHeartbeatMs = 6L;
+ try (
+ final MockedConstruction timerMockedConstruction = mockConstruction(Timer.class, (mock, context) -> {
+ when(mock.remainingMs()).thenReturn(remainingMs);
+ });
+ final MockedConstruction heartbeatRequestStateMockedConstruction = mockConstruction(
+ HeartbeatRequestState.class,
+ (mock, context) -> {
+ when(mock.requestInFlight()).thenReturn(isRequestInFlight);
+ when(mock.timeToNextHeartbeatMs(anyLong())).thenReturn(timeToNextHeartbeatMs);
+ })
+ ) {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+ final Timer pollTimer = timerMockedConstruction.constructed().get(0);
+ when(membershipManager.shouldNotWaitForHeartbeatInterval()).thenReturn(shouldNotWaitForHeartbeatInterval);
+ time.sleep(1234);
+
+ final long maximumTimeToWait = heartbeatRequestManager.maximumTimeToWait(time.milliseconds());
+
+ assertEquals(timeToNextHeartbeatMs, maximumTimeToWait);
+ verify(pollTimer).update(time.milliseconds());
+ }
+ }
+
+ @ParameterizedTest
+ @CsvSource({"12, 5", "10, 6"})
+ public void testMaximumTimeToWaitSelectingMinimumWaitTime(final long remainingMs,
+ final long timeToNextHeartbeatMs) {
+ try (
+ final MockedConstruction timerMockedConstruction = mockConstruction(Timer.class, (mock, context) -> {
+ when(mock.remainingMs()).thenReturn(remainingMs);
+ });
+ final MockedConstruction heartbeatRequestStateMockedConstruction = mockConstruction(
+ HeartbeatRequestState.class,
+ (mock, context) -> {
+ when(mock.timeToNextHeartbeatMs(anyLong())).thenReturn(timeToNextHeartbeatMs);
+ })
+ ) {
+ final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
+ final Timer pollTimer = timerMockedConstruction.constructed().get(0);
+ time.sleep(1234);
+
+ final long maximumTimeToWait = heartbeatRequestManager.maximumTimeToWait(time.milliseconds());
+
+ assertEquals(5, maximumTimeToWait);
+ verify(pollTimer).update(time.milliseconds());
+ }
+ }
+
private static ConsumerConfig config() {
Properties prop = new Properties();
prop.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);