From 97095928b8e9e4a1fb83bfc699ef5b3c73564cdf Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 14 Jan 2025 15:18:10 +0100
Subject: [PATCH 01/13] KAFKA-18518: Add processor to handle rebalance events

This commit adds a processor named
StreamsRebalanceEventsProcessor that handles the rebalance
events sent from the background thread of the async
consumer to the stream thread when an task
assignment changes. It also adds the corresponding rebalance
events.

Additionally, this commit adds StreamsRebalanceData that
maintains the data that is exchanges for the Streams rebalance
protocol.

All of these are used by the Streams heartbeat request manager
and the Streams membership manager that will be added in a future
commit.
---
 .../StreamsGroupRebalanceCallbacks.java       |  13 +
 .../internals/StreamsRebalanceData.java       | 197 ++++++++++++++
 .../StreamsRebalanceEventsProcessor.java      | 148 +++++++++++
 .../internals/events/ApplicationEvent.java    |   3 +
 .../internals/events/BackgroundEvent.java     |   4 +-
 ...sOnAllTasksLostCallbackCompletedEvent.java |  51 ++++
 ...eamsOnAllTasksLostCallbackNeededEvent.java |  29 ++
 ...OnTasksAssignedCallbackCompletedEvent.java |  51 ++++
 ...amsOnTasksAssignedCallbackNeededEvent.java |  41 +++
 ...sOnTasksRevokedCallbackCompletedEvent.java |  51 ++++
 ...eamsOnTasksRevokedCallbackNeededEvent.java |  42 +++
 .../StreamsRebalanceEventsProcessorTest.java  | 250 ++++++++++++++++++
 12 files changed, 879 insertions(+), 1 deletion(-)
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackCompletedEvent.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackNeededEvent.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackCompletedEvent.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackNeededEvent.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackCompletedEvent.java
 create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackNeededEvent.java
 create mode 100644 clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
new file mode 100644
index 0000000000000..f6fc52fc18c93
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
@@ -0,0 +1,13 @@
+package org.apache.kafka.clients.consumer.internals;
+
+import java.util.Optional;
+import java.util.Set;
+
+public interface StreamsGroupRebalanceCallbacks {
+
+    Optional<Exception> onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks);
+
+    Optional<Exception> onTasksAssigned(final StreamsRebalanceData.Assignment assignment);
+
+    Optional<Exception> onAllTasksLost();
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
new file mode 100644
index 0000000000000..8ab47927f8c9d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -0,0 +1,197 @@
+package org.apache.kafka.clients.consumer.internals;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class StreamsRebalanceData {
+
+    public static class Assignment {
+
+        public static final Assignment EMPTY = new Assignment();
+
+        public final Set<TaskId> activeTasks = new HashSet<>();
+
+        public final Set<TaskId> standbyTasks = new HashSet<>();
+
+        public final Set<TaskId> warmupTasks = new HashSet<>();
+
+        public Assignment() {
+        }
+
+        public Assignment(final Set<TaskId> activeTasks,
+                          final Set<TaskId> standbyTasks,
+                          final Set<TaskId> warmupTasks) {
+            this.activeTasks.addAll(activeTasks);
+            this.standbyTasks.addAll(standbyTasks);
+            this.warmupTasks.addAll(warmupTasks);
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            final Assignment that = (Assignment) o;
+            return Objects.equals(activeTasks, that.activeTasks)
+                && Objects.equals(standbyTasks, that.standbyTasks)
+                && Objects.equals(warmupTasks, that.warmupTasks);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(activeTasks, standbyTasks, warmupTasks);
+        }
+
+        public Assignment copy() {
+            return new Assignment(activeTasks, standbyTasks, warmupTasks);
+        }
+
+        @Override
+        public String toString() {
+            return "Assignment{" +
+                "activeTasks=" + activeTasks +
+                ", standbyTasks=" + standbyTasks +
+                ", warmupTasks=" + warmupTasks +
+                '}';
+        }
+    }
+
+    public static class TaskId implements Comparable<TaskId> {
+
+        private final String subtopologyId;
+        private final int partitionId;
+
+        public int partitionId() {
+            return partitionId;
+        }
+
+        public String subtopologyId() {
+            return subtopologyId;
+        }
+
+        public TaskId(final String subtopologyId, final int partitionId) {
+            this.subtopologyId = subtopologyId;
+            this.partitionId = partitionId;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TaskId taskId = (TaskId) o;
+            return partitionId == taskId.partitionId && Objects.equals(subtopologyId, taskId.subtopologyId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(subtopologyId, partitionId);
+        }
+
+        @Override
+        public int compareTo(TaskId taskId) {
+            if (subtopologyId.equals(taskId.subtopologyId)) {
+                return partitionId - taskId.partitionId;
+            }
+            return subtopologyId.compareTo(taskId.subtopologyId);
+        }
+
+        @Override
+        public String toString() {
+            return "TaskId{" +
+                "subtopologyId=" + subtopologyId +
+                ", partitionId=" + partitionId +
+                '}';
+        }
+    }
+
+    public static class Subtopology {
+
+        public final Set<String> sourceTopics;
+        public final Set<String> repartitionSinkTopics;
+        public final Map<String, TopicInfo> stateChangelogTopics;
+        public final Map<String, TopicInfo> repartitionSourceTopics;
+        public final Collection<Set<String>> copartitionGroups;
+
+        public Subtopology(final Set<String> sourceTopics,
+                           final Set<String> repartitionSinkTopics,
+                           final Map<String, TopicInfo> repartitionSourceTopics,
+                           final Map<String, TopicInfo> stateChangelogTopics,
+                           final Collection<Set<String>> copartitionGroups
+        ) {
+            this.sourceTopics = sourceTopics;
+            this.repartitionSinkTopics = repartitionSinkTopics;
+            this.stateChangelogTopics = stateChangelogTopics;
+            this.repartitionSourceTopics = repartitionSourceTopics;
+            this.copartitionGroups = copartitionGroups;
+        }
+
+        @Override
+        public String toString() {
+            return "Subtopology{" +
+                "sourceTopics=" + sourceTopics +
+                ", repartitionSinkTopics=" + repartitionSinkTopics +
+                ", stateChangelogTopics=" + stateChangelogTopics +
+                ", repartitionSourceTopics=" + repartitionSourceTopics +
+                ", copartitionGroups=" + copartitionGroups +
+                '}';
+        }
+    }
+
+    public static class TopicInfo {
+
+        public final Optional<Integer> numPartitions;
+        public final Optional<Short> replicationFactor;
+        public final Map<String, String> topicConfigs;
+
+        public TopicInfo(final Optional<Integer> numPartitions,
+                         final Optional<Short> replicationFactor,
+                         final Map<String, String> topicConfigs) {
+            this.numPartitions = numPartitions;
+            this.replicationFactor = replicationFactor;
+            this.topicConfigs = topicConfigs;
+        }
+
+        @Override
+        public String toString() {
+            return "TopicInfo{" +
+                "numPartitions=" + numPartitions +
+                ", replicationFactor=" + replicationFactor +
+                ", topicConfigs=" + topicConfigs +
+                '}';
+        }
+    }
+
+    private final Map<String, Subtopology> subtopologies;
+
+    private final AtomicReference<Assignment> reconciledAssignment = new AtomicReference<>(
+        new Assignment(
+            new HashSet<>(),
+            new HashSet<>(),
+            new HashSet<>()
+        )
+    );
+
+    public StreamsRebalanceData(Map<String, Subtopology> subtopologies) {
+        this.subtopologies = subtopologies;
+    }
+
+    public Map<String, Subtopology> subtopologies() {
+        return subtopologies;
+    }
+
+    public void setReconciledAssignment(final Assignment assignment) {
+        reconciledAssignment.set(assignment);
+    }
+
+    public Assignment reconciledAssignment() {
+        return reconciledAssignment.get();
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
new file mode 100644
index 0000000000000..ae883ec12b8af
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
@@ -0,0 +1,148 @@
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.consumer.internals.events.ApplicationEventHandler;
+import org.apache.kafka.clients.consumer.internals.events.BackgroundEvent;
+import org.apache.kafka.clients.consumer.internals.events.ErrorEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnAllTasksLostCallbackCompletedEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnAllTasksLostCallbackNeededEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksAssignedCallbackCompletedEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksAssignedCallbackNeededEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksRevokedCallbackCompletedEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksRevokedCallbackNeededEvent;
+import org.apache.kafka.common.KafkaException;
+
+import java.util.LinkedList;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+
+public class StreamsRebalanceEventsProcessor {
+
+    private final BlockingQueue<BackgroundEvent> onCallbackRequests = new LinkedBlockingQueue<>();
+    private ApplicationEventHandler applicationEventHandler = null;
+    private final StreamsGroupRebalanceCallbacks rebalanceCallbacks;
+    private final StreamsRebalanceData streamsRebalanceData;
+
+    public StreamsRebalanceEventsProcessor(StreamsRebalanceData streamsRebalanceData,
+                                           StreamsGroupRebalanceCallbacks rebalanceCallbacks) {
+        this.streamsRebalanceData = streamsRebalanceData;
+        this.rebalanceCallbacks = rebalanceCallbacks;
+    }
+
+    public CompletableFuture<Void> requestOnTasksAssignedCallbackInvocation(final StreamsRebalanceData.Assignment assignment) {
+        final StreamsOnTasksAssignedCallbackNeededEvent onTasksAssignedCallbackNeededEvent = new StreamsOnTasksAssignedCallbackNeededEvent(assignment);
+        onCallbackRequests.add(onTasksAssignedCallbackNeededEvent);
+        return onTasksAssignedCallbackNeededEvent.future();
+    }
+
+    public CompletableFuture<Void> requestOnTasksRevokedCallbackInvocation(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke) {
+        final StreamsOnTasksRevokedCallbackNeededEvent onTasksRevokedCallbackNeededEvent = new StreamsOnTasksRevokedCallbackNeededEvent(activeTasksToRevoke);
+        onCallbackRequests.add(onTasksRevokedCallbackNeededEvent);
+        return onTasksRevokedCallbackNeededEvent.future();
+    }
+
+    public CompletableFuture<Void> requestOnAllTasksLostCallbackInvocation() {
+        final StreamsOnAllTasksLostCallbackNeededEvent onAllTasksLostCallbackNeededEvent = new StreamsOnAllTasksLostCallbackNeededEvent();
+        onCallbackRequests.add(onAllTasksLostCallbackNeededEvent);
+        return onAllTasksLostCallbackNeededEvent.future();
+    }
+
+    public void setApplicationEventHandler(final ApplicationEventHandler applicationEventHandler) {
+        this.applicationEventHandler = applicationEventHandler;
+    }
+
+    private void process(final BackgroundEvent event) {
+        switch (event.type()) {
+            case ERROR:
+                throw ((ErrorEvent) event).error();
+
+            case STREAMS_ON_TASKS_REVOKED_CALLBACK_NEEDED:
+                processStreamsOnTasksRevokedCallbackNeededEvent((StreamsOnTasksRevokedCallbackNeededEvent) event);
+                break;
+
+            case STREAMS_ON_TASKS_ASSIGNED_CALLBACK_NEEDED:
+                processStreamsOnTasksAssignedCallbackNeededEvent((StreamsOnTasksAssignedCallbackNeededEvent) event);
+                break;
+
+            case STREAMS_ON_ALL_TASKS_LOST_CALLBACK_NEEDED:
+                processStreamsOnAllTasksLostCallbackNeededEvent((StreamsOnAllTasksLostCallbackNeededEvent) event);
+                break;
+
+            default:
+                throw new IllegalArgumentException("Background event type " + event.type() + " was not expected");
+
+        }
+    }
+
+    private void processStreamsOnTasksRevokedCallbackNeededEvent(final StreamsOnTasksRevokedCallbackNeededEvent event) {
+        StreamsOnTasksRevokedCallbackCompletedEvent invokedEvent = invokeOnTasksRevokedCallback(event.activeTasksToRevoke(), event.future());
+        applicationEventHandler.add(invokedEvent);
+        if (invokedEvent.error().isPresent()) {
+            throw invokedEvent.error().get();
+        }
+    }
+
+    private void processStreamsOnTasksAssignedCallbackNeededEvent(final StreamsOnTasksAssignedCallbackNeededEvent event) {
+        StreamsOnTasksAssignedCallbackCompletedEvent invokedEvent = invokeOnTasksAssignedCallback(event.assignment(), event.future());
+        applicationEventHandler.add(invokedEvent);
+        if (invokedEvent.error().isPresent()) {
+            throw invokedEvent.error().get();
+        }
+    }
+
+    private void processStreamsOnAllTasksLostCallbackNeededEvent(final StreamsOnAllTasksLostCallbackNeededEvent event) {
+        StreamsOnAllTasksLostCallbackCompletedEvent invokedEvent = invokeOnAllTasksLostCallback(event.future());
+        applicationEventHandler.add(invokedEvent);
+        if (invokedEvent.error().isPresent()) {
+            throw invokedEvent.error().get();
+        }
+    }
+
+    private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke,
+                                                                                     final CompletableFuture<Void> future) {
+        final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onTasksRevoked(activeTasksToRevoke);
+        return exceptionFromCallback
+            .map(exception ->
+                new StreamsOnTasksRevokedCallbackCompletedEvent(
+                    future,
+                    Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exception, "Task revocation callback throws an error"))
+                ))
+            .orElseGet(() -> new StreamsOnTasksRevokedCallbackCompletedEvent(future, Optional.empty()));
+    }
+
+    private StreamsOnTasksAssignedCallbackCompletedEvent invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment,
+                                                                                       final CompletableFuture<Void> future) {
+        Optional<KafkaException> error = Optional.empty();
+        final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onTasksAssigned(assignment);
+        if (exceptionFromCallback.isPresent()) {
+            error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "Task assignment callback throws an error"));
+        } else {
+            streamsRebalanceData.setReconciledAssignment(assignment);
+        }
+        return new StreamsOnTasksAssignedCallbackCompletedEvent(future, error);
+    }
+
+    private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback(final CompletableFuture<Void> future) {
+        final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onAllTasksLost();
+        final Optional<KafkaException> error;
+        if (exceptionFromCallback.isPresent()) {
+            error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "All tasks lost callback throws an error"));
+        } else {
+            error = Optional.empty();
+            streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
+        }
+
+        return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error);
+    }
+
+    public void process() {
+        LinkedList<BackgroundEvent> events = new LinkedList<>();
+        onCallbackRequests.drainTo(events);
+        for (BackgroundEvent event : events) {
+            process(event);
+        }
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEvent.java
index dfb775f8947c1..6f0557772a4bf 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEvent.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEvent.java
@@ -41,6 +41,9 @@ public enum Type {
         SHARE_ACKNOWLEDGE_ON_CLOSE,
         SHARE_ACKNOWLEDGEMENT_COMMIT_CALLBACK_REGISTRATION,
         SEEK_UNVALIDATED,
+        STREAMS_ON_TASKS_ASSIGNED_CALLBACK_COMPLETED,
+        STREAMS_ON_TASKS_REVOKED_CALLBACK_COMPLETED,
+        STREAMS_ON_ALL_TASKS_LOST_CALLBACK_COMPLETED,
     }
 
     private final Type type;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
index 02fc4b4a29ba4..9c704abbabd9f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
@@ -27,7 +27,9 @@
 public abstract class BackgroundEvent {
 
     public enum Type {
-        ERROR, CONSUMER_REBALANCE_LISTENER_CALLBACK_NEEDED, SHARE_ACKNOWLEDGEMENT_COMMIT_CALLBACK
+        ERROR, CONSUMER_REBALANCE_LISTENER_CALLBACK_NEEDED, SHARE_ACKNOWLEDGEMENT_COMMIT_CALLBACK,
+        STREAMS_ON_TASKS_ASSIGNED_CALLBACK_NEEDED, STREAMS_ON_TASKS_REVOKED_CALLBACK_NEEDED,
+        STREAMS_ON_ALL_TASKS_LOST_CALLBACK_NEEDED
     }
 
     private final Type type;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackCompletedEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackCompletedEvent.java
new file mode 100644
index 0000000000000..b84e9d0c1386d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackCompletedEvent.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.common.KafkaException;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+public class StreamsOnAllTasksLostCallbackCompletedEvent extends ApplicationEvent {
+
+    private final CompletableFuture<Void> future;
+    private final Optional<KafkaException> error;
+
+    public StreamsOnAllTasksLostCallbackCompletedEvent(final CompletableFuture<Void> future,
+                                                       final Optional<KafkaException> error) {
+        super(Type.STREAMS_ON_ALL_TASKS_LOST_CALLBACK_COMPLETED);
+        this.future = Objects.requireNonNull(future);
+        this.error = Objects.requireNonNull(error);
+    }
+
+    public CompletableFuture<Void> future() {
+        return future;
+    }
+
+    public Optional<KafkaException> error() {
+        return error;
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase() +
+            ", future=" + future +
+            ", error=" + error;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackNeededEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackNeededEvent.java
new file mode 100644
index 0000000000000..29e1a94ec7239
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnAllTasksLostCallbackNeededEvent.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+public class StreamsOnAllTasksLostCallbackNeededEvent extends CompletableBackgroundEvent<Void> {
+
+    public StreamsOnAllTasksLostCallbackNeededEvent() {
+        super(Type.STREAMS_ON_ALL_TASKS_LOST_CALLBACK_NEEDED, Long.MAX_VALUE);
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase();
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackCompletedEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackCompletedEvent.java
new file mode 100644
index 0000000000000..96c2519bb2d33
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackCompletedEvent.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.common.KafkaException;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+public class StreamsOnTasksAssignedCallbackCompletedEvent extends ApplicationEvent {
+
+    private final CompletableFuture<Void> future;
+    private final Optional<KafkaException> error;
+
+    public StreamsOnTasksAssignedCallbackCompletedEvent(final CompletableFuture<Void> future,
+                                                        final Optional<KafkaException> error) {
+        super(Type.STREAMS_ON_TASKS_ASSIGNED_CALLBACK_COMPLETED);
+        this.future = Objects.requireNonNull(future);
+        this.error = Objects.requireNonNull(error);
+    }
+
+    public CompletableFuture<Void> future() {
+        return future;
+    }
+
+    public Optional<KafkaException> error() {
+        return error;
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase() +
+            ", future=" + future +
+            ", error=" + error;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackNeededEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackNeededEvent.java
new file mode 100644
index 0000000000000..565bf97c6b775
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksAssignedCallbackNeededEvent.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+
+import java.util.Objects;
+
+public class StreamsOnTasksAssignedCallbackNeededEvent extends CompletableBackgroundEvent<Void> {
+
+    private final StreamsRebalanceData.Assignment assignment;
+
+    public StreamsOnTasksAssignedCallbackNeededEvent(StreamsRebalanceData.Assignment assignment) {
+        super(Type.STREAMS_ON_TASKS_ASSIGNED_CALLBACK_NEEDED, Long.MAX_VALUE);
+        this.assignment = Objects.requireNonNull(assignment);
+    }
+
+    public StreamsRebalanceData.Assignment assignment() {
+        return assignment;
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase() +
+            ", assignment=" + assignment;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackCompletedEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackCompletedEvent.java
new file mode 100644
index 0000000000000..5717012ac4576
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackCompletedEvent.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.common.KafkaException;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+public class StreamsOnTasksRevokedCallbackCompletedEvent extends ApplicationEvent {
+
+    private final CompletableFuture<Void> future;
+    private final Optional<KafkaException> error;
+
+    public StreamsOnTasksRevokedCallbackCompletedEvent(final CompletableFuture<Void> future,
+                                                       final Optional<KafkaException> error) {
+        super(Type.STREAMS_ON_TASKS_REVOKED_CALLBACK_COMPLETED);
+        this.future = Objects.requireNonNull(future);
+        this.error = Objects.requireNonNull(error);
+    }
+
+    public CompletableFuture<Void> future() {
+        return future;
+    }
+
+    public Optional<KafkaException> error() {
+        return error;
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase() +
+            ", future=" + future +
+            ", error=" + error;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackNeededEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackNeededEvent.java
new file mode 100644
index 0000000000000..1e3e58a9f5b79
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/StreamsOnTasksRevokedCallbackNeededEvent.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+
+import java.util.Objects;
+import java.util.Set;
+
+public class StreamsOnTasksRevokedCallbackNeededEvent extends CompletableBackgroundEvent<Void> {
+
+    private final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke;
+
+    public StreamsOnTasksRevokedCallbackNeededEvent(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke) {
+        super(Type.STREAMS_ON_TASKS_REVOKED_CALLBACK_NEEDED, Long.MAX_VALUE);
+        this.activeTasksToRevoke = Objects.requireNonNull(activeTasksToRevoke);
+    }
+
+    public Set<StreamsRebalanceData.TaskId> activeTasksToRevoke() {
+        return activeTasksToRevoke;
+    }
+
+    @Override
+    protected String toStringBase() {
+        return super.toStringBase() +
+            ", active tasks to revoke=" + activeTasksToRevoke;
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
new file mode 100644
index 0000000000000..d9abae8feee9f
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
@@ -0,0 +1,250 @@
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.consumer.internals.events.ApplicationEventHandler;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnAllTasksLostCallbackCompletedEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksAssignedCallbackCompletedEvent;
+import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksRevokedCallbackCompletedEvent;
+import org.apache.kafka.common.KafkaException;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.util.Collections;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StreamsRebalanceEventsProcessorTest {
+
+    private static final String SUBTOPOLOGY_0 = "subtopology-0";
+    private static final String SUBTOPOLOGY_1 = "subtopology-1";
+
+    @Mock
+    private StreamsGroupRebalanceCallbacks rebalanceCallbacks;
+
+    @Mock
+    private ApplicationEventHandler applicationEventHandler;
+
+    @Test
+    public void shouldInvokeOnTasksAssignedCallback() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        final Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0)
+        );
+        final Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 2),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 3)
+        );
+        StreamsRebalanceData.Assignment assignment =
+            new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
+        when(rebalanceCallbacks.onTasksAssigned(assignment)).thenReturn(Optional.empty());
+
+        final CompletableFuture<Void> onTasksAssignedExecuted = rebalanceEventsProcessor.requestOnTasksAssignedCallbackInvocation(assignment);
+
+        assertFalse(onTasksAssignedExecuted.isDone());
+        rebalanceEventsProcessor.process();
+        ArgumentCaptor<StreamsOnTasksAssignedCallbackCompletedEvent> streamsOnTasksAssignedCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnTasksAssignedCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnTasksAssignedCallbackCompletedCaptor.capture());
+        StreamsOnTasksAssignedCallbackCompletedEvent streamsOnTasksAssignedCallbackCompletedEvent =
+            streamsOnTasksAssignedCallbackCompletedCaptor.getValue();
+        assertFalse(streamsOnTasksAssignedCallbackCompletedEvent.future().isDone());
+        assertTrue(streamsOnTasksAssignedCallbackCompletedEvent.error().isEmpty());
+        assertEquals(assignment, rebalanceData.reconciledAssignment());
+    }
+
+    @Test
+    public void shouldReThrowErrorFromOnTasksAssignedCallbackAndPassErrorToBackground() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        final Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0)
+        );
+        final Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 2),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 3)
+        );
+        StreamsRebalanceData.Assignment assignment =
+            new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
+        final Exception exception = new RuntimeException("Nobody expects the Spanish inquisition.");
+        when(rebalanceCallbacks.onTasksAssigned(assignment)).thenReturn(Optional.of(exception));
+
+        final CompletableFuture<Void> onTasksAssignedExecuted = rebalanceEventsProcessor.requestOnTasksAssignedCallbackInvocation(assignment);
+
+        assertFalse(onTasksAssignedExecuted.isDone());
+        final Exception actualException = assertThrows(KafkaException.class, rebalanceEventsProcessor::process);
+        assertEquals("Task assignment callback throws an error", actualException.getMessage());
+        ArgumentCaptor<StreamsOnTasksAssignedCallbackCompletedEvent> streamsOnTasksAssignedCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnTasksAssignedCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnTasksAssignedCallbackCompletedCaptor.capture());
+        StreamsOnTasksAssignedCallbackCompletedEvent streamsOnTasksAssignedCallbackCompletedEvent =
+            streamsOnTasksAssignedCallbackCompletedCaptor.getValue();
+        assertFalse(streamsOnTasksAssignedCallbackCompletedEvent.future().isDone());
+        assertTrue(streamsOnTasksAssignedCallbackCompletedEvent.error().isPresent());
+        assertEquals(exception, streamsOnTasksAssignedCallbackCompletedEvent.error().get().getCause());
+        assertEquals(exception, actualException.getCause());
+        assertEquals(StreamsRebalanceData.Assignment.EMPTY, rebalanceData.reconciledAssignment());
+    }
+
+    @Test
+    public void shouldInvokeOnTasksRevokedCallback() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        when(rebalanceCallbacks.onTasksRevoked(activeTasks)).thenReturn(Optional.empty());
+
+        final CompletableFuture<Void> onTasksRevokedExecuted = rebalanceEventsProcessor.requestOnTasksRevokedCallbackInvocation(activeTasks);
+
+        assertFalse(onTasksRevokedExecuted.isDone());
+        rebalanceEventsProcessor.process();
+        ArgumentCaptor<StreamsOnTasksRevokedCallbackCompletedEvent> streamsOnTasksRevokedCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnTasksRevokedCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnTasksRevokedCallbackCompletedCaptor.capture());
+        StreamsOnTasksRevokedCallbackCompletedEvent streamsOnTasksRevokedCallbackCompletedEvent =
+            streamsOnTasksRevokedCallbackCompletedCaptor.getValue();
+        assertFalse(streamsOnTasksRevokedCallbackCompletedEvent.future().isDone());
+        assertTrue(streamsOnTasksRevokedCallbackCompletedEvent.error().isEmpty());
+    }
+
+    @Test
+    public void shouldReThrowErrorFromOnTasksRevokedCallbackAndPassErrorToBackground() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        final Exception exception = new RuntimeException("Nobody expects the Spanish inquisition.");
+        when(rebalanceCallbacks.onTasksRevoked(activeTasks)).thenReturn(Optional.of(exception));
+
+        final CompletableFuture<Void> onTasksRevokedExecuted = rebalanceEventsProcessor.requestOnTasksRevokedCallbackInvocation(activeTasks);
+
+        assertFalse(onTasksRevokedExecuted.isDone());
+        final Exception actualException = assertThrows(KafkaException.class, rebalanceEventsProcessor::process);
+        assertEquals("Task revocation callback throws an error", actualException.getMessage());
+        ArgumentCaptor<StreamsOnTasksRevokedCallbackCompletedEvent> streamsOnTasksRevokedCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnTasksRevokedCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnTasksRevokedCallbackCompletedCaptor.capture());
+        StreamsOnTasksRevokedCallbackCompletedEvent streamsOnTasksRevokedCallbackCompletedEvent =
+            streamsOnTasksRevokedCallbackCompletedCaptor.getValue();
+        assertTrue(streamsOnTasksRevokedCallbackCompletedEvent.error().isPresent());
+        assertEquals(exception, streamsOnTasksRevokedCallbackCompletedEvent.error().get().getCause());
+        assertEquals(exception, actualException.getCause());
+    }
+
+    @Test
+    public void shouldInvokeOnAllTasksLostCallback() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        final Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0)
+        );
+        final Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 2),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 3)
+        );
+        StreamsRebalanceData.Assignment assignment =
+            new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
+        when(rebalanceCallbacks.onTasksAssigned(assignment)).thenReturn(Optional.empty());
+        rebalanceEventsProcessor.requestOnTasksAssignedCallbackInvocation(assignment);
+        rebalanceEventsProcessor.process();
+        assertEquals(assignment, rebalanceData.reconciledAssignment());
+        when(rebalanceCallbacks.onAllTasksLost()).thenReturn(Optional.empty());
+
+        final CompletableFuture<Void> onAllTasksLostExecuted = rebalanceEventsProcessor.requestOnAllTasksLostCallbackInvocation();
+
+        assertFalse(onAllTasksLostExecuted.isDone());
+        rebalanceEventsProcessor.process();
+        ArgumentCaptor<StreamsOnAllTasksLostCallbackCompletedEvent> streamsOnAllTasksLostCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnAllTasksLostCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnAllTasksLostCallbackCompletedCaptor.capture());
+        StreamsOnAllTasksLostCallbackCompletedEvent streamsOnAllTasksLostCallbackCompletedEvent =
+            streamsOnAllTasksLostCallbackCompletedCaptor.getValue();
+        assertFalse(streamsOnAllTasksLostCallbackCompletedEvent.future().isDone());
+        assertTrue(streamsOnAllTasksLostCallbackCompletedEvent.error().isEmpty());
+        assertEquals(StreamsRebalanceData.Assignment.EMPTY, rebalanceData.reconciledAssignment());
+    }
+
+    @Test
+    public void shouldReThrowErrorFromOnAllTasksLostCallbackAndPassErrorToBackground() {
+        final StreamsRebalanceData rebalanceData = new StreamsRebalanceData(Collections.emptyMap());
+        final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
+            new StreamsRebalanceEventsProcessor(rebalanceData, rebalanceCallbacks);
+        rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
+        final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1)
+        );
+        final Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 1),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0)
+        );
+        final Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of(
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_1, 2),
+            new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 3)
+        );
+        StreamsRebalanceData.Assignment assignment =
+            new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
+        when(rebalanceCallbacks.onTasksAssigned(assignment)).thenReturn(Optional.empty());
+        rebalanceEventsProcessor.requestOnTasksAssignedCallbackInvocation(assignment);
+        rebalanceEventsProcessor.process();
+        assertEquals(assignment, rebalanceData.reconciledAssignment());
+        final Exception exception = new RuntimeException("Nobody expects the Spanish inquisition.");
+        when(rebalanceCallbacks.onAllTasksLost()).thenReturn(Optional.of(exception));
+
+        final CompletableFuture<Void> onAllTasksLostExecuted = rebalanceEventsProcessor.requestOnAllTasksLostCallbackInvocation();
+
+        assertFalse(onAllTasksLostExecuted.isDone());
+        final Exception actualException = assertThrows(KafkaException.class, rebalanceEventsProcessor::process);
+        assertEquals("All tasks lost callback throws an error", actualException.getMessage());
+        ArgumentCaptor<StreamsOnAllTasksLostCallbackCompletedEvent> streamsOnAllTasksLostCallbackCompletedCaptor =
+            ArgumentCaptor.forClass(StreamsOnAllTasksLostCallbackCompletedEvent.class);
+        verify(applicationEventHandler).add(streamsOnAllTasksLostCallbackCompletedCaptor.capture());
+        StreamsOnAllTasksLostCallbackCompletedEvent streamsOnAllTasksLostCallbackCompletedEvent =
+            streamsOnAllTasksLostCallbackCompletedCaptor.getValue();
+        assertFalse(streamsOnAllTasksLostCallbackCompletedEvent.future().isDone());
+        assertTrue(streamsOnAllTasksLostCallbackCompletedEvent.error().isPresent());
+        assertEquals(exception, streamsOnAllTasksLostCallbackCompletedEvent.error().get().getCause());
+        assertEquals(exception, actualException.getCause());
+        assertEquals(assignment, rebalanceData.reconciledAssignment());
+    }
+}
\ No newline at end of file

From dc845c0302be22697a565284db162d166cff007e Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 14 Jan 2025 16:31:01 +0100
Subject: [PATCH 02/13] Add license headers

---
 .../StreamsGroupRebalanceCallbacks.java          | 16 ++++++++++++++++
 .../consumer/internals/StreamsRebalanceData.java | 16 ++++++++++++++++
 .../StreamsRebalanceEventsProcessor.java         | 16 ++++++++++++++++
 .../StreamsRebalanceEventsProcessorTest.java     | 16 ++++++++++++++++
 4 files changed, 64 insertions(+)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
index f6fc52fc18c93..dae6a15f8e481 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.kafka.clients.consumer.internals;
 
 import java.util.Optional;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 8ab47927f8c9d..101f3038cee7a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.kafka.clients.consumer.internals;
 
 import java.util.Collection;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
index ae883ec12b8af..22fae2fc82b0d 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.consumer.internals.events.ApplicationEventHandler;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
index d9abae8feee9f..fafe5acf21719 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.consumer.internals.events.ApplicationEventHandler;

From 21c4bee2c7a56fc99bf9765d5e8e6753d3ec162e Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 14 Jan 2025 16:45:03 +0100
Subject: [PATCH 03/13] Apply spotless

---
 .../consumer/internals/StreamsRebalanceEventsProcessorTest.java  | 1 +
 1 file changed, 1 insertion(+)

diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
index fafe5acf21719..f30aa2718b177 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
@@ -21,6 +21,7 @@
 import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksAssignedCallbackCompletedEvent;
 import org.apache.kafka.clients.consumer.internals.events.StreamsOnTasksRevokedCallbackCompletedEvent;
 import org.apache.kafka.common.KafkaException;
+
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.ArgumentCaptor;

From 29dd4e0c28f7b7ef498435412f05da80576d68dd Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 14 Jan 2025 17:06:08 +0100
Subject: [PATCH 04/13] Add javadocs

---
 .../StreamsGroupRebalanceCallbacks.java       | 20 +++++++++
 .../internals/StreamsRebalanceData.java       | 11 ++---
 .../StreamsRebalanceEventsProcessor.java      | 41 +++++++++++++++++++
 3 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
index dae6a15f8e481..4840e79ecebc4 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupRebalanceCallbacks.java
@@ -19,11 +19,31 @@
 import java.util.Optional;
 import java.util.Set;
 
+/**
+ * Callbacks for handling Streams group rebalance events in Kafka Streams.
+ */
 public interface StreamsGroupRebalanceCallbacks {
 
+    /**
+     * Called when tasks are revoked from a stream thread.
+     *
+     * @param tasks The tasks to be revoked.
+     * @return The exception thrown during the callback, if any.
+     */
     Optional<Exception> onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks);
 
+    /**
+     * Called when tasks are assigned from a stream thread.
+     *
+     * @param assignment The tasks assigned.
+     * @return The exception thrown during the callback, if any.
+     */
     Optional<Exception> onTasksAssigned(final StreamsRebalanceData.Assignment assignment);
 
+    /**
+     * Called when a stream thread loses all assigned tasks.
+     *
+     * @return The exception thrown during the callback, if any.
+     */
     Optional<Exception> onAllTasksLost();
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 101f3038cee7a..ccfb267c4c002 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -24,6 +24,9 @@
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 
+/**
+ * This class holds the data that is needed to participate in the Streams rebalance protocol.
+ */
 public class StreamsRebalanceData {
 
     public static class Assignment {
@@ -187,13 +190,7 @@ public String toString() {
 
     private final Map<String, Subtopology> subtopologies;
 
-    private final AtomicReference<Assignment> reconciledAssignment = new AtomicReference<>(
-        new Assignment(
-            new HashSet<>(),
-            new HashSet<>(),
-            new HashSet<>()
-        )
-    );
+    private final AtomicReference<Assignment> reconciledAssignment = new AtomicReference<>(Assignment.EMPTY);
 
     public StreamsRebalanceData(Map<String, Subtopology> subtopologies) {
         this.subtopologies = subtopologies;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
index 22fae2fc82b0d..4b411d1481c27 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
@@ -34,6 +34,14 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.LinkedBlockingQueue;
 
+/**
+ * Processes events from the Streams rebalance protocol.
+ * <p>
+ * The Streams rebalance processor receives events from the background thread of the async consumer, more precisely
+ * from the Streams membership manager and handles them.
+ * For example, events are requests for invoking the task assignment and task revocation callbacks.
+ * Results of the event handling are passed back to the background thread.
+ */
 public class StreamsRebalanceEventsProcessor {
 
     private final BlockingQueue<BackgroundEvent> onCallbackRequests = new LinkedBlockingQueue<>();
@@ -41,30 +49,60 @@ public class StreamsRebalanceEventsProcessor {
     private final StreamsGroupRebalanceCallbacks rebalanceCallbacks;
     private final StreamsRebalanceData streamsRebalanceData;
 
+    /**
+     * Constructs the Streams rebalance processor.
+     *
+     * @param streamsRebalanceData
+     * @param rebalanceCallbacks
+     */
     public StreamsRebalanceEventsProcessor(StreamsRebalanceData streamsRebalanceData,
                                            StreamsGroupRebalanceCallbacks rebalanceCallbacks) {
         this.streamsRebalanceData = streamsRebalanceData;
         this.rebalanceCallbacks = rebalanceCallbacks;
     }
 
+    /**
+     * Requests the invocation of the task assignment callback.
+     *
+     * @param assignment The tasks to be assigned to the member of the Streams group.
+     * @return A future that will be completed when the callback has been invoked.
+     */
     public CompletableFuture<Void> requestOnTasksAssignedCallbackInvocation(final StreamsRebalanceData.Assignment assignment) {
         final StreamsOnTasksAssignedCallbackNeededEvent onTasksAssignedCallbackNeededEvent = new StreamsOnTasksAssignedCallbackNeededEvent(assignment);
         onCallbackRequests.add(onTasksAssignedCallbackNeededEvent);
         return onTasksAssignedCallbackNeededEvent.future();
     }
 
+    /**
+     * Requests the invocation of the task revocation callback.
+     *
+     * @param activeTasksToRevoke The tasks to revoke from the member of the Streams group
+     * @return A future that will be completed when the callback has been invoked.
+     */
     public CompletableFuture<Void> requestOnTasksRevokedCallbackInvocation(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke) {
         final StreamsOnTasksRevokedCallbackNeededEvent onTasksRevokedCallbackNeededEvent = new StreamsOnTasksRevokedCallbackNeededEvent(activeTasksToRevoke);
         onCallbackRequests.add(onTasksRevokedCallbackNeededEvent);
         return onTasksRevokedCallbackNeededEvent.future();
     }
 
+    /**
+     * Requests the invocation of the all tasks lost callback.
+     *
+     * @return A future that will be completed when the callback has been invoked.
+     */
     public CompletableFuture<Void> requestOnAllTasksLostCallbackInvocation() {
         final StreamsOnAllTasksLostCallbackNeededEvent onAllTasksLostCallbackNeededEvent = new StreamsOnAllTasksLostCallbackNeededEvent();
         onCallbackRequests.add(onAllTasksLostCallbackNeededEvent);
         return onAllTasksLostCallbackNeededEvent.future();
     }
 
+    /**
+     * Sets the application event handler.
+     *
+     * The application handler sends the results of the callbacks to the background thread.
+     *
+     * @param applicationEventHandler The application handler.
+     */
     public void setApplicationEventHandler(final ApplicationEventHandler applicationEventHandler) {
         this.applicationEventHandler = applicationEventHandler;
     }
@@ -153,6 +191,9 @@ private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback
         return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error);
     }
 
+    /**
+     * Processes all events received from the background thread so far.
+     */
     public void process() {
         LinkedList<BackgroundEvent> events = new LinkedList<>();
         onCallbackRequests.drainTo(events);

From e7ac06ef054f5530619ae0414e26e273f278c35e Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Mon, 20 Jan 2025 17:32:52 +0100
Subject: [PATCH 05/13] Improve TaskId

---
 .../internals/StreamsRebalanceData.java       | 104 +++++++++---------
 .../internals/StreamsRebalanceDataTest.java   |  35 ++++++
 2 files changed, 87 insertions(+), 52 deletions(-)
 create mode 100644 clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index ccfb267c4c002..504f74e7e641d 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.clients.consumer.internals;
 
 import java.util.Collection;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
@@ -29,17 +30,64 @@
  */
 public class StreamsRebalanceData {
 
+    public static class TaskId implements Comparable<TaskId> {
+
+        private final String subtopologyId;
+        private final int partitionId;
+
+        public TaskId(final String subtopologyId, final int partitionId) {
+            this.subtopologyId = subtopologyId;
+            this.partitionId = partitionId;
+        }
+
+        public int partitionId() {
+            return partitionId;
+        }
+
+        public String subtopologyId() {
+            return subtopologyId;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TaskId taskId = (TaskId) o;
+            return partitionId == taskId.partitionId && Objects.equals(subtopologyId, taskId.subtopologyId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(subtopologyId, partitionId);
+        }
+
+        @Override
+        public int compareTo(TaskId taskId) {
+            Objects.requireNonNull(taskId, "taskId cannot be null");
+            return Comparator.comparing(TaskId::subtopologyId)
+                .thenComparingInt(TaskId::partitionId).compare(this, taskId);
+        }
+
+        @Override
+        public String toString() {
+            return "TaskId{" +
+                "subtopologyId=" + subtopologyId +
+                ", partitionId=" + partitionId +
+                '}';
+        }
+    }
+
     public static class Assignment {
 
         public static final Assignment EMPTY = new Assignment();
 
-        public final Set<TaskId> activeTasks = new HashSet<>();
+        private final Set<TaskId> activeTasks = new HashSet<>();
 
-        public final Set<TaskId> standbyTasks = new HashSet<>();
+        private final Set<TaskId> standbyTasks = new HashSet<>();
 
-        public final Set<TaskId> warmupTasks = new HashSet<>();
+        private final Set<TaskId> warmupTasks = new HashSet<>();
 
-        public Assignment() {
+        private Assignment() {
         }
 
         public Assignment(final Set<TaskId> activeTasks,
@@ -83,54 +131,6 @@ public String toString() {
         }
     }
 
-    public static class TaskId implements Comparable<TaskId> {
-
-        private final String subtopologyId;
-        private final int partitionId;
-
-        public int partitionId() {
-            return partitionId;
-        }
-
-        public String subtopologyId() {
-            return subtopologyId;
-        }
-
-        public TaskId(final String subtopologyId, final int partitionId) {
-            this.subtopologyId = subtopologyId;
-            this.partitionId = partitionId;
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            TaskId taskId = (TaskId) o;
-            return partitionId == taskId.partitionId && Objects.equals(subtopologyId, taskId.subtopologyId);
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(subtopologyId, partitionId);
-        }
-
-        @Override
-        public int compareTo(TaskId taskId) {
-            if (subtopologyId.equals(taskId.subtopologyId)) {
-                return partitionId - taskId.partitionId;
-            }
-            return subtopologyId.compareTo(taskId.subtopologyId);
-        }
-
-        @Override
-        public String toString() {
-            return "TaskId{" +
-                "subtopologyId=" + subtopologyId +
-                ", partitionId=" + partitionId +
-                '}';
-        }
-    }
-
     public static class Subtopology {
 
         public final Set<String> sourceTopics;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
new file mode 100644
index 0000000000000..098c01f8cd907
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -0,0 +1,35 @@
+package org.apache.kafka.clients.consumer.internals;
+
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StreamsRebalanceDataTest {
+
+    @Test
+    public void testTaskIdEqualsAndHashCode() {
+        final StreamsRebalanceData.TaskId task = new StreamsRebalanceData.TaskId("subtopologyId1", 1);
+        final StreamsRebalanceData.TaskId taskEqual = new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId());
+        final StreamsRebalanceData.TaskId taskUnequalSubtopology = new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId());
+        final StreamsRebalanceData.TaskId taskUnequalPartition = new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1);
+        assertEquals(task, taskEqual);
+        assertEquals(task.hashCode(), taskEqual.hashCode());
+        assertNotEquals(task, taskUnequalSubtopology);
+        assertNotEquals(task.hashCode(), taskUnequalSubtopology.hashCode());
+        assertNotEquals(task, taskUnequalPartition);
+        assertNotEquals(task.hashCode(), taskUnequalSubtopology.hashCode());
+    }
+
+    @Test
+    public void testTaskIdCompareTo() {
+        final StreamsRebalanceData.TaskId task = new StreamsRebalanceData.TaskId("subtopologyId1", 1);
+        assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId())) == 0);
+        assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId())) < 0);
+        assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1)) < 0);
+        assertTrue(new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId()).compareTo(task) > 0);
+        assertTrue(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1).compareTo(task) > 0);
+    }
+}
\ No newline at end of file

From c653ddd59ecf5b492788712feffcba507f14bd7a Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 09:48:02 +0100
Subject: [PATCH 06/13] Improve Assignment

---
 .../internals/StreamsRebalanceData.java       |  28 +++-
 .../internals/StreamsRebalanceDataTest.java   | 127 ++++++++++++++++++
 2 files changed, 149 insertions(+), 6 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 504f74e7e641d..cd4417df508db 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.clients.consumer.internals;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.Map;
@@ -81,21 +82,36 @@ public static class Assignment {
 
         public static final Assignment EMPTY = new Assignment();
 
-        private final Set<TaskId> activeTasks = new HashSet<>();
+        private final Set<TaskId> activeTasks;
 
-        private final Set<TaskId> standbyTasks = new HashSet<>();
+        private final Set<TaskId> standbyTasks;
 
-        private final Set<TaskId> warmupTasks = new HashSet<>();
+        private final Set<TaskId> warmupTasks;
 
         private Assignment() {
+            this.activeTasks = Set.of();
+            this.standbyTasks = Set.of();
+            this.warmupTasks = Set.of();
         }
 
         public Assignment(final Set<TaskId> activeTasks,
                           final Set<TaskId> standbyTasks,
                           final Set<TaskId> warmupTasks) {
-            this.activeTasks.addAll(activeTasks);
-            this.standbyTasks.addAll(standbyTasks);
-            this.warmupTasks.addAll(warmupTasks);
+            this.activeTasks = Collections.unmodifiableSet(Objects.requireNonNull(activeTasks, "Active tasks cannot be null"));
+            this.standbyTasks = Collections.unmodifiableSet(Objects.requireNonNull(standbyTasks, "Standby tasks cannot be null"));
+            this.warmupTasks = Collections.unmodifiableSet(Objects.requireNonNull(warmupTasks, "Warmup tasks cannot be null"));
+        }
+
+        public Set<TaskId> activeTasks() {
+            return activeTasks;
+        }
+
+        public Set<TaskId> standbyTasks() {
+            return standbyTasks;
+        }
+
+        public Set<TaskId> warmupTasks() {
+            return warmupTasks;
         }
 
         @Override
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index 098c01f8cd907..bb0432ea3ebe3 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -3,8 +3,14 @@
 
 import org.junit.jupiter.api.Test;
 
+import java.util.HashSet;
+import java.util.Set;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class StreamsRebalanceDataTest {
@@ -15,6 +21,7 @@ public void testTaskIdEqualsAndHashCode() {
         final StreamsRebalanceData.TaskId taskEqual = new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId());
         final StreamsRebalanceData.TaskId taskUnequalSubtopology = new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId());
         final StreamsRebalanceData.TaskId taskUnequalPartition = new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1);
+
         assertEquals(task, taskEqual);
         assertEquals(task.hashCode(), taskEqual.hashCode());
         assertNotEquals(task, taskUnequalSubtopology);
@@ -26,10 +33,130 @@ public void testTaskIdEqualsAndHashCode() {
     @Test
     public void testTaskIdCompareTo() {
         final StreamsRebalanceData.TaskId task = new StreamsRebalanceData.TaskId("subtopologyId1", 1);
+
         assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId())) == 0);
         assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId())) < 0);
         assertTrue(task.compareTo(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1)) < 0);
         assertTrue(new StreamsRebalanceData.TaskId(task.subtopologyId() + "1", task.partitionId()).compareTo(task) > 0);
         assertTrue(new StreamsRebalanceData.TaskId(task.subtopologyId(), task.partitionId() + 1).compareTo(task) > 0);
     }
+
+    @Test
+    public void shouldNotModifyEmptyAssignment() {
+        final StreamsRebalanceData.Assignment emptyAssignment = StreamsRebalanceData.Assignment.EMPTY;
+
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> emptyAssignment.activeTasks().add(new StreamsRebalanceData.TaskId("subtopologyId1", 1))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> emptyAssignment.standbyTasks().add(new StreamsRebalanceData.TaskId("subtopologyId1", 1))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> emptyAssignment.warmupTasks().add(new StreamsRebalanceData.TaskId("subtopologyId1", 1))
+        );
+    }
+
+    @Test
+    public void shouldNotModifyAssignment() {
+        final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 1)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 2)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 3))
+        );
+
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> assignment.activeTasks().add(new StreamsRebalanceData.TaskId("subtopologyId2", 1))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> assignment.standbyTasks().add(new StreamsRebalanceData.TaskId("subtopologyId2", 2))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> assignment.warmupTasks().add(new StreamsRebalanceData.TaskId("subtopologyId2", 3))
+        );
+    }
+
+    @Test
+    public void assignmentShouldNotAcceptNulls() {
+        final Exception exception1 = assertThrows(NullPointerException.class, () -> new StreamsRebalanceData.Assignment(null, Set.of(), Set.of()));
+        assertEquals("Active tasks cannot be null", exception1.getMessage());
+        final Exception exception2 = assertThrows(NullPointerException.class, () -> new StreamsRebalanceData.Assignment(Set.of(), null, Set.of()));
+        assertEquals("Standby tasks cannot be null", exception2.getMessage());
+        final Exception exception3 = assertThrows(NullPointerException.class, () -> new StreamsRebalanceData.Assignment(Set.of(), Set.of(), null));
+        assertEquals("Warmup tasks cannot be null", exception3.getMessage());
+    }
+
+    @Test
+    public void testAssignmentEqualsAndHashCode() {
+        final StreamsRebalanceData.TaskId additionalTask = new StreamsRebalanceData.TaskId("subtopologyId2", 1);
+        final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 1)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 2)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 3))
+        );
+        final StreamsRebalanceData.Assignment assignmentEqual = new StreamsRebalanceData.Assignment(
+            assignment.activeTasks(),
+            assignment.standbyTasks(),
+            assignment.warmupTasks()
+        );
+        Set<StreamsRebalanceData.TaskId> unequalActiveTasks = new HashSet<>(assignment.activeTasks());
+        unequalActiveTasks.add(additionalTask);
+        final StreamsRebalanceData.Assignment assignmentUnequalActiveTasks = new StreamsRebalanceData.Assignment(
+            unequalActiveTasks,
+            assignment.standbyTasks(),
+            assignment.warmupTasks()
+        );
+        Set<StreamsRebalanceData.TaskId> unequalStandbyTasks = new HashSet<>(assignment.standbyTasks());
+        unequalStandbyTasks.add(additionalTask);
+        final StreamsRebalanceData.Assignment assignmentUnequalStandbyTasks = new StreamsRebalanceData.Assignment(
+            assignment.activeTasks(),
+            unequalStandbyTasks,
+            assignment.warmupTasks()
+        );
+        Set<StreamsRebalanceData.TaskId> unequalWarmupTasks = new HashSet<>(assignment.warmupTasks());
+        unequalWarmupTasks.add(additionalTask);
+        final StreamsRebalanceData.Assignment assignmentUnequalWarmupTasks = new StreamsRebalanceData.Assignment(
+            assignment.activeTasks(),
+            assignment.standbyTasks(),
+            unequalWarmupTasks
+        );
+
+        assertEquals(assignment, assignmentEqual);
+        assertNotEquals(assignment, assignmentUnequalActiveTasks);
+        assertNotEquals(assignment, assignmentUnequalStandbyTasks);
+        assertNotEquals(assignment, assignmentUnequalWarmupTasks);
+        assertEquals(assignment.hashCode(), assignmentEqual.hashCode());
+        assertNotEquals(assignment.hashCode(), assignmentUnequalActiveTasks.hashCode());
+        assertNotEquals(assignment.hashCode(), assignmentUnequalStandbyTasks.hashCode());
+        assertNotEquals(assignment.hashCode(), assignmentUnequalWarmupTasks.hashCode());
+    }
+
+    @Test
+    public void shouldCopyAssignment() {
+        final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 1)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 2)),
+            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 3))
+        );
+
+        final StreamsRebalanceData.Assignment copy = assignment.copy();
+
+        assertEquals(assignment, copy);
+        assertNotSame(assignment, copy);
+    }
+
+    @Test
+    public void shouldCopyEmptyAssignment() {
+        final StreamsRebalanceData.Assignment emptyAssignment = StreamsRebalanceData.Assignment.EMPTY;
+
+        final StreamsRebalanceData.Assignment copy = emptyAssignment.copy();
+
+        assertEquals(emptyAssignment, copy);
+        assertNotSame(emptyAssignment, copy);
+    }
 }
\ No newline at end of file

From 88d016c5c910e823babf187a53ecaa411bd29965 Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 10:17:00 +0100
Subject: [PATCH 07/13] Add null checks to TaskId

---
 .../clients/consumer/internals/StreamsRebalanceData.java   | 2 +-
 .../consumer/internals/StreamsRebalanceDataTest.java       | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index cd4417df508db..bd84c5d55a131 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -37,7 +37,7 @@ public static class TaskId implements Comparable<TaskId> {
         private final int partitionId;
 
         public TaskId(final String subtopologyId, final int partitionId) {
-            this.subtopologyId = subtopologyId;
+            this.subtopologyId = Objects.requireNonNull(subtopologyId, "Subtopology ID cannot be null");
             this.partitionId = partitionId;
         }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index bb0432ea3ebe3..2925b67d87ad9 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -7,7 +7,6 @@
 import java.util.Set;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertNotSame;
 import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -30,6 +29,12 @@ public void testTaskIdEqualsAndHashCode() {
         assertNotEquals(task.hashCode(), taskUnequalSubtopology.hashCode());
     }
 
+    @Test
+    public void taskIdShouldNotAcceptNulls() {
+        final Exception exception = assertThrows(NullPointerException.class, () -> new StreamsRebalanceData.TaskId(null, 1));
+        assertEquals("Subtopology ID cannot be null", exception.getMessage());
+    }
+
     @Test
     public void testTaskIdCompareTo() {
         final StreamsRebalanceData.TaskId task = new StreamsRebalanceData.TaskId("subtopologyId1", 1);

From cc8e0b56fd3caa5e8e669ebb4d5aed9a651094ce Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 15:17:13 +0100
Subject: [PATCH 08/13] Improve subtopology

---
 .../internals/StreamsRebalanceData.java       | 72 ++++++++++++----
 .../internals/StreamsRebalanceDataTest.java   | 85 +++++++++++++++++--
 2 files changed, 136 insertions(+), 21 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index bd84c5d55a131..569d15d9f38b5 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -97,9 +97,9 @@ private Assignment() {
         public Assignment(final Set<TaskId> activeTasks,
                           final Set<TaskId> standbyTasks,
                           final Set<TaskId> warmupTasks) {
-            this.activeTasks = Collections.unmodifiableSet(Objects.requireNonNull(activeTasks, "Active tasks cannot be null"));
-            this.standbyTasks = Collections.unmodifiableSet(Objects.requireNonNull(standbyTasks, "Standby tasks cannot be null"));
-            this.warmupTasks = Collections.unmodifiableSet(Objects.requireNonNull(warmupTasks, "Warmup tasks cannot be null"));
+            this.activeTasks = Set.copyOf(Objects.requireNonNull(activeTasks, "Active tasks cannot be null"));
+            this.standbyTasks = Set.copyOf(Objects.requireNonNull(standbyTasks, "Standby tasks cannot be null"));
+            this.warmupTasks = Set.copyOf(Objects.requireNonNull(warmupTasks, "Warmup tasks cannot be null"));
         }
 
         public Set<TaskId> activeTasks() {
@@ -149,11 +149,11 @@ public String toString() {
 
     public static class Subtopology {
 
-        public final Set<String> sourceTopics;
-        public final Set<String> repartitionSinkTopics;
-        public final Map<String, TopicInfo> stateChangelogTopics;
-        public final Map<String, TopicInfo> repartitionSourceTopics;
-        public final Collection<Set<String>> copartitionGroups;
+        private final Set<String> sourceTopics;
+        private final Set<String> repartitionSinkTopics;
+        private final Map<String, TopicInfo> stateChangelogTopics;
+        private final Map<String, TopicInfo> repartitionSourceTopics;
+        private final Collection<Set<String>> copartitionGroups;
 
         public Subtopology(final Set<String> sourceTopics,
                            final Set<String> repartitionSinkTopics,
@@ -161,11 +161,39 @@ public Subtopology(final Set<String> sourceTopics,
                            final Map<String, TopicInfo> stateChangelogTopics,
                            final Collection<Set<String>> copartitionGroups
         ) {
-            this.sourceTopics = sourceTopics;
-            this.repartitionSinkTopics = repartitionSinkTopics;
-            this.stateChangelogTopics = stateChangelogTopics;
-            this.repartitionSourceTopics = repartitionSourceTopics;
-            this.copartitionGroups = copartitionGroups;
+            this.sourceTopics = Set.copyOf(Objects.requireNonNull(sourceTopics, "Subtopology ID cannot be null"));
+            this.repartitionSinkTopics =
+                Set.copyOf(Objects.requireNonNull(repartitionSinkTopics, "Repartition sink topics cannot be null"));
+            this.repartitionSourceTopics =
+                Map.copyOf(Objects.requireNonNull(repartitionSourceTopics, "Repartition source topics cannot be null"));
+            this.stateChangelogTopics =
+                Map.copyOf(Objects.requireNonNull(stateChangelogTopics, "State changelog topics cannot be null"));
+            this.copartitionGroups =
+                Collections.unmodifiableCollection(Objects.requireNonNull(
+                    copartitionGroups,
+                    "Co-partition groups cannot be null"
+                    )
+                );
+        }
+
+        public Set<String> sourceTopics() {
+            return sourceTopics;
+        }
+
+        public Set<String> repartitionSinkTopics() {
+            return repartitionSinkTopics;
+        }
+
+        public Map<String, TopicInfo> stateChangelogTopics() {
+            return stateChangelogTopics;
+        }
+
+        public Map<String, TopicInfo> repartitionSourceTopics() {
+            return repartitionSourceTopics;
+        }
+
+        public Collection<Set<String>> copartitionGroups() {
+            return copartitionGroups;
         }
 
         @Override
@@ -182,9 +210,9 @@ public String toString() {
 
     public static class TopicInfo {
 
-        public final Optional<Integer> numPartitions;
-        public final Optional<Short> replicationFactor;
-        public final Map<String, String> topicConfigs;
+        private final Optional<Integer> numPartitions;
+        private final Optional<Short> replicationFactor;
+        private final Map<String, String> topicConfigs;
 
         public TopicInfo(final Optional<Integer> numPartitions,
                          final Optional<Short> replicationFactor,
@@ -194,6 +222,18 @@ public TopicInfo(final Optional<Integer> numPartitions,
             this.topicConfigs = topicConfigs;
         }
 
+        public Optional<Integer> numPartitions() {
+            return numPartitions;
+        }
+
+        public Optional<Short> replicationFactor() {
+            return replicationFactor;
+        }
+
+        public Map<String, String> topicConfigs() {
+            return topicConfigs;
+        }
+
         @Override
         public String toString() {
             return "TopicInfo{" +
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index 2925b67d87ad9..0f7b2c4e80b49 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -3,8 +3,16 @@
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
@@ -47,7 +55,7 @@ public void testTaskIdCompareTo() {
     }
 
     @Test
-    public void shouldNotModifyEmptyAssignment() {
+    public void emptyAssignmentShouldNotBeModifiable() {
         final StreamsRebalanceData.Assignment emptyAssignment = StreamsRebalanceData.Assignment.EMPTY;
 
         assertThrows(
@@ -65,11 +73,11 @@ public void shouldNotModifyEmptyAssignment() {
     }
 
     @Test
-    public void shouldNotModifyAssignment() {
+    public void assignmentShouldNotBeModifiable() {
         final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
-            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 1)),
-            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 2)),
-            Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 3))
+            new HashSet<>(Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 1))),
+            new HashSet<>(Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 2))),
+            new HashSet<>(Set.of(new StreamsRebalanceData.TaskId("subtopologyId1", 3)))
         );
 
         assertThrows(
@@ -164,4 +172,71 @@ public void shouldCopyEmptyAssignment() {
         assertEquals(emptyAssignment, copy);
         assertNotSame(emptyAssignment, copy);
     }
+
+    @Test
+    public void subtopologyShouldNotAcceptNulls() {
+        final Exception exception1 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.Subtopology(null, Set.of(), Map.of(), Map.of(), List.of())
+        );
+        assertEquals("Subtopology ID cannot be null", exception1.getMessage());
+        final Exception exception2 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.Subtopology(Set.of(), null, Map.of(), Map.of(), List.of())
+        );
+        assertEquals("Repartition sink topics cannot be null", exception2.getMessage());
+        final Exception exception3 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.Subtopology(Set.of(), Set.of(), null, Map.of(), List.of())
+        );
+        assertEquals("Repartition source topics cannot be null", exception3.getMessage());
+        final Exception exception4 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.Subtopology(Set.of(), Set.of(), Map.of(), null, List.of())
+        );
+        assertEquals("State changelog topics cannot be null", exception4.getMessage());
+        final Exception exception5 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.Subtopology(Set.of(), Set.of(), Map.of(), Map.of(), null)
+        );
+        assertEquals("Co-partition groups cannot be null", exception5.getMessage());
+    }
+
+    @Test
+    public void subtopologyShouldBeModifiable() {
+        final StreamsRebalanceData.Subtopology subtopology = new StreamsRebalanceData.Subtopology(
+            new HashSet<>(Set.of("sourceTopic1")),
+            new HashSet<>(Set.of("repartitionSinkTopic1")),
+            Map.of("repartitionSourceTopic1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of()))
+                .entrySet().stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)),
+            Map.of("stateChangelogTopic1", new StreamsRebalanceData.TopicInfo(Optional.of(0), Optional.of((short) 1), Map.of()))
+                .entrySet().stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)),
+            new ArrayList<>(List.of(Set.of("sourceTopic1")))
+        );
+
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> subtopology.sourceTopics().add("sourceTopic2")
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> subtopology.repartitionSinkTopics().add("repartitionSinkTopic2")
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> subtopology.repartitionSourceTopics().put("repartitionSourceTopic2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of()))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> subtopology.stateChangelogTopics().put("stateChangelogTopic2", new StreamsRebalanceData.TopicInfo(Optional.of(0), Optional.of((short) 1), Map.of()))
+        );
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> subtopology.copartitionGroups().add(Set.of("sourceTopic2"))
+        );
+    }
+
+
 }
\ No newline at end of file

From de359897d2972b5a543c015b443e3fa3e17089e9 Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 15:27:43 +0100
Subject: [PATCH 09/13] Improve topic info

---
 .../internals/StreamsRebalanceData.java       |  7 ++--
 .../internals/StreamsRebalanceDataTest.java   | 35 ++++++++++++++++++-
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 569d15d9f38b5..66313d09d57c6 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -217,9 +217,10 @@ public static class TopicInfo {
         public TopicInfo(final Optional<Integer> numPartitions,
                          final Optional<Short> replicationFactor,
                          final Map<String, String> topicConfigs) {
-            this.numPartitions = numPartitions;
-            this.replicationFactor = replicationFactor;
-            this.topicConfigs = topicConfigs;
+            this.numPartitions = Objects.requireNonNull(numPartitions, "Number of partitions cannot be null");
+            this.replicationFactor = Objects.requireNonNull(replicationFactor, "Replication factor cannot be null");
+            this.topicConfigs =
+                Map.copyOf(Objects.requireNonNull(topicConfigs, "Additional topic configs cannot be null"));
         }
 
         public Optional<Integer> numPartitions() {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index 0f7b2c4e80b49..3c500b6647c59 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -203,7 +203,7 @@ public void subtopologyShouldNotAcceptNulls() {
     }
 
     @Test
-    public void subtopologyShouldBeModifiable() {
+    public void subtopologyShouldNotBeModifiable() {
         final StreamsRebalanceData.Subtopology subtopology = new StreamsRebalanceData.Subtopology(
             new HashSet<>(Set.of("sourceTopic1")),
             new HashSet<>(Set.of("repartitionSinkTopic1")),
@@ -238,5 +238,38 @@ public void subtopologyShouldBeModifiable() {
         );
     }
 
+    @Test
+    public void topicInfoShouldNotAcceptNulls() {
+        final Exception exception1 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.TopicInfo(null, Optional.of((short) 1), Map.of())
+        );
+        assertEquals("Number of partitions cannot be null", exception1.getMessage());
+        final Exception exception2 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.TopicInfo(Optional.of(1), null, Map.of())
+        );
+        assertEquals("Replication factor cannot be null", exception2.getMessage());
+        final Exception exception3 = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), null)
+        );
+        assertEquals("Additional topic configs cannot be null", exception3.getMessage());
+    }
+
+    @Test
+    public void topicInfoShouldNotBeModifiable() {
+        final StreamsRebalanceData.TopicInfo topicInfo = new StreamsRebalanceData.TopicInfo(
+            Optional.of(1),
+            Optional.of((short) 1),
+            Map.of("key1", "value1")
+                .entrySet().stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
+        );
 
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> topicInfo.topicConfigs().put("key2", "value2")
+        );
+    }
 }
\ No newline at end of file

From 31f170d2b69462dc0140765cff4dd08cf72a1602 Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 15:41:14 +0100
Subject: [PATCH 10/13] Improve streams rebalance data

---
 .../internals/StreamsRebalanceData.java       |  2 +-
 .../internals/StreamsRebalanceDataTest.java   | 37 ++++++++++++++-----
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 66313d09d57c6..5a23e48d68e8c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -250,7 +250,7 @@ public String toString() {
     private final AtomicReference<Assignment> reconciledAssignment = new AtomicReference<>(Assignment.EMPTY);
 
     public StreamsRebalanceData(Map<String, Subtopology> subtopologies) {
-        this.subtopologies = subtopologies;
+        this.subtopologies = Map.copyOf(Objects.requireNonNull(subtopologies, "Subtopologies cannot be null"));
     }
 
     public Map<String, Subtopology> subtopologies() {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index 3c500b6647c59..f04620b7769b0 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -5,6 +5,7 @@
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -258,18 +259,34 @@ public void topicInfoShouldNotAcceptNulls() {
     }
 
     @Test
-    public void topicInfoShouldNotBeModifiable() {
-        final StreamsRebalanceData.TopicInfo topicInfo = new StreamsRebalanceData.TopicInfo(
-            Optional.of(1),
-            Optional.of((short) 1),
-            Map.of("key1", "value1")
-                .entrySet().stream()
-                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
-        );
+    public void streamsRebalanceDataShouldNotHaveModifiableSubtopologies() {
+        final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(new HashMap<>());
 
         assertThrows(
             UnsupportedOperationException.class,
-            () -> topicInfo.topicConfigs().put("key2", "value2")
+            () -> streamsRebalanceData.subtopologies().put("subtopologyId2", new StreamsRebalanceData.Subtopology(
+                Set.of(),
+                Set.of(),
+                Map.of(),
+                Map.of(),
+                List.of()
+            ))
         );
     }
-}
\ No newline at end of file
+
+    @Test
+    public void streamsRebalanceDataShouldNotAcceptNulls() {
+        final Exception exception = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsRebalanceData(null)
+        );
+        assertEquals("Subtopologies cannot be null", exception.getMessage());
+    }
+
+    @Test
+    public void streamsRebalanceDataShouldBeConstructedWithEmptyAssignment() {
+        final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(new HashMap<>());
+
+        assertEquals(StreamsRebalanceData.Assignment.EMPTY, streamsRebalanceData.reconciledAssignment());
+    }
+}

From eadecd357d0882af571defbccabf8b91cf29f024 Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 15:42:16 +0100
Subject: [PATCH 11/13] Put types of background event on separate lines

---
 .../clients/consumer/internals/events/BackgroundEvent.java | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
index 9c704abbabd9f..b2f8a3666c499 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/BackgroundEvent.java
@@ -27,8 +27,11 @@
 public abstract class BackgroundEvent {
 
     public enum Type {
-        ERROR, CONSUMER_REBALANCE_LISTENER_CALLBACK_NEEDED, SHARE_ACKNOWLEDGEMENT_COMMIT_CALLBACK,
-        STREAMS_ON_TASKS_ASSIGNED_CALLBACK_NEEDED, STREAMS_ON_TASKS_REVOKED_CALLBACK_NEEDED,
+        ERROR,
+        CONSUMER_REBALANCE_LISTENER_CALLBACK_NEEDED,
+        SHARE_ACKNOWLEDGEMENT_COMMIT_CALLBACK,
+        STREAMS_ON_TASKS_ASSIGNED_CALLBACK_NEEDED,
+        STREAMS_ON_TASKS_REVOKED_CALLBACK_NEEDED,
         STREAMS_ON_ALL_TASKS_LOST_CALLBACK_NEEDED
     }
 

From a319b91c4dea83cf07575203aef6b639161ea0f9 Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 15:53:30 +0100
Subject: [PATCH 12/13] Harmonize implementations of callbacks invocations

---
 .../StreamsRebalanceEventsProcessor.java      | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
index 4b411d1481c27..db91e8fcece8e 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
@@ -156,38 +156,38 @@ private void processStreamsOnAllTasksLostCallbackNeededEvent(final StreamsOnAllT
 
     private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke,
                                                                                      final CompletableFuture<Void> future) {
+        final Optional<KafkaException> error;
         final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onTasksRevoked(activeTasksToRevoke);
-        return exceptionFromCallback
-            .map(exception ->
-                new StreamsOnTasksRevokedCallbackCompletedEvent(
-                    future,
-                    Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exception, "Task revocation callback throws an error"))
-                ))
-            .orElseGet(() -> new StreamsOnTasksRevokedCallbackCompletedEvent(future, Optional.empty()));
+        if (exceptionFromCallback.isPresent()) {
+            error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "Task revocation callback throws an error"));
+        } else {
+            error = Optional.empty();
+        }
+        return new StreamsOnTasksRevokedCallbackCompletedEvent(future, error);
     }
 
     private StreamsOnTasksAssignedCallbackCompletedEvent invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment,
                                                                                        final CompletableFuture<Void> future) {
-        Optional<KafkaException> error = Optional.empty();
+        final Optional<KafkaException> error;
         final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onTasksAssigned(assignment);
         if (exceptionFromCallback.isPresent()) {
             error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "Task assignment callback throws an error"));
         } else {
+            error = Optional.empty();
             streamsRebalanceData.setReconciledAssignment(assignment);
         }
         return new StreamsOnTasksAssignedCallbackCompletedEvent(future, error);
     }
 
     private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback(final CompletableFuture<Void> future) {
-        final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onAllTasksLost();
         final Optional<KafkaException> error;
+        final Optional<Exception> exceptionFromCallback = rebalanceCallbacks.onAllTasksLost();
         if (exceptionFromCallback.isPresent()) {
             error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "All tasks lost callback throws an error"));
         } else {
             error = Optional.empty();
             streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
         }
-
         return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error);
     }
 

From 2015d2490fd1061b0937443b2a3d8797699b763c Mon Sep 17 00:00:00 2001
From: Bruno Cadonna <cadonna@apache.org>
Date: Tue, 21 Jan 2025 16:38:43 +0100
Subject: [PATCH 13/13] fix checkstyle issues

---
 .../internals/StreamsRebalanceData.java       |  1 -
 .../internals/StreamsRebalanceDataTest.java   | 19 ++++++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 5a23e48d68e8c..a8670eeb1b192 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -19,7 +19,6 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
index f04620b7769b0..8a67c580f06d4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java
@@ -1,18 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.apache.kafka.clients.consumer.internals;
 
 
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
-import java.util.OptionalInt;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;