kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [kafka] branch 2.4 updated: KAFKA-8179: Part 7, cooperative rebalancing in Streams (#7386)
Date Mon, 07 Oct 2019 16:30:24 GMT
This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch 2.4
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.4 by this push:
     new 133c33f  KAFKA-8179: Part 7, cooperative rebalancing in Streams (#7386)
133c33f is described below

commit 133c33fde1c4d3b79196b72522f83a75cb6b0e65
Author: A. Sophie Blee-Goldman <sophie@confluent.io>
AuthorDate: Mon Oct 7 09:27:09 2019 -0700

    KAFKA-8179: Part 7, cooperative rebalancing in Streams (#7386)
    
    Key improvements with this PR:
    
    * tasks will remain available for IQ during a rebalance (but not during restore)
    * continue restoring and processing standby tasks during a rebalance
    * continue processing active tasks during rebalance until the RecordQueue is empty*
    * only revoked tasks must be suspended/closed
    * StreamsPartitionAssignor tries to return tasks to their previous consumers within a client
    * but do not try to commit, for now (pending KAFKA-7312)
    
    
    Reviewers: John Roesler <john@confluent.io>, Boyang Chen <boyang@confluent.io>, Guozhang Wang <wangguoz@gmail.com>
---
 checkstyle/suppressions.xml                        |   3 +
 .../consumer/internals/ConsumerCoordinator.java    |  20 +-
 .../consumer/internals/SubscriptionState.java      |  15 +-
 .../clients/consumer/internals/FetcherTest.java    |  31 +-
 .../processor/internals/AssignedStreamsTasks.java  |  15 +-
 .../streams/processor/internals/StandbyTask.java   |  19 -
 .../processor/internals/StoreChangelogReader.java  |  16 +-
 .../streams/processor/internals/StreamTask.java    |   8 +-
 .../streams/processor/internals/StreamThread.java  |  27 +-
 .../internals/StreamsPartitionAssignor.java        | 612 +++++++++++++++------
 .../kafka/streams/processor/internals/Task.java    |   8 +-
 .../streams/processor/internals/TaskManager.java   |  34 +-
 .../assignment/AssignorConfiguration.java          |  12 +-
 .../internals/assignment/ClientState.java          |  56 +-
 .../internals/assignment/StickyTaskAssignor.java   |   8 +-
 .../processor/internals/AbstractTaskTest.java      |   6 -
 .../processor/internals/StandbyTaskTest.java       |  23 -
 .../internals/StreamsPartitionAssignorTest.java    | 481 ++++++++++++++--
 .../processor/internals/TaskManagerTest.java       |  28 +-
 .../internals/assignment/ClientStateTest.java      |   8 +-
 .../assignment/StickyTaskAssignorTest.java         |  34 +-
 .../kafka/streams/tests/SmokeTestDriver.java       |   2 +-
 .../kafka/streams/tests/StreamsUpgradeTest.java    |  26 +-
 tests/kafkatest/tests/streams/streams_eos_test.py  |   2 +-
 24 files changed, 1058 insertions(+), 436 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 8927849..2f21309 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -197,6 +197,9 @@
     <suppress checks="MethodLength"
               files="RocksDBWindowStoreTest.java"/>
 
+    <suppress checks="MemberName"
+              files="StreamsPartitionAssignorTest.java"/>
+
     <suppress checks="ClassDataAbstractionCoupling"
               files=".*[/\\]streams[/\\].*test[/\\].*.java"/>
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b5b5ce2..6b39acb 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -355,26 +355,29 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         Set<TopicPartition> addedPartitions = new HashSet<>(assignedPartitions);
         addedPartitions.removeAll(ownedPartitions);
 
-        // Invoke user's revocation callback before changing assignment or updating state
         if (protocol == RebalanceProtocol.COOPERATIVE) {
             Set<TopicPartition> revokedPartitions = new HashSet<>(ownedPartitions);
             revokedPartitions.removeAll(assignedPartitions);
 
-            log.info("Updating with newly assigned partitions: {}, compare with already owned partitions: {}, " +
-                    "newly added partitions: {}, revoking partitions: {}",
+            log.info("Updating assignment with\n" +
+                    "now assigned partitions: {}\n" +
+                    "compare with previously owned partitions: {}\n" +
+                    "newly added partitions: {}\n" +
+                    "revoked partitions: {}\n",
                 Utils.join(assignedPartitions, ", "),
                 Utils.join(ownedPartitions, ", "),
                 Utils.join(addedPartitions, ", "),
-                Utils.join(revokedPartitions, ", "));
-
+                Utils.join(revokedPartitions, ", ")
+            );
 
             if (!revokedPartitions.isEmpty()) {
-                // revoke partitions that was previously owned but no longer assigned;
-                // note that we should only change the assignment AFTER we've triggered
-                // the revoke callback
+                // revoke partitions that were previously owned but no longer assigned;
+                // note that we should only change the assignment (or update the assignor's state)
+                // AFTER we've triggered the revoke callback
                 firstException.compareAndSet(null, invokePartitionsRevoked(revokedPartitions));
 
                 // if revoked any partitions, need to re-join the group afterwards
+                log.debug("Need to revoke partitions {} and re-join the group", revokedPartitions);
                 requestRejoin();
             }
         }
@@ -679,7 +682,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             }
         }
 
-
         isLeader = false;
         subscriptions.resetGroupSubscription();
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 4641e5c..953505f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -270,8 +270,14 @@ public class SubscriptionState {
         if (!this.partitionsAutoAssigned())
             throw new IllegalArgumentException("Attempt to dynamically assign partitions while manual assignment in use");
 
+        Map<TopicPartition, TopicPartitionState> assignedPartitionStates = new HashMap<>(assignments.size());
+        for (TopicPartition tp : assignments) {
+            TopicPartitionState state = this.assignment.stateValue(tp);
+            if (state == null)
+                state = new TopicPartitionState();
+            assignedPartitionStates.put(tp, state);
+        }
 
-        Map<TopicPartition, TopicPartitionState> assignedPartitionStates = partitionToStateMap(assignments);
         assignmentId++;
         this.assignment.set(assignedPartitionStates);
     }
@@ -674,13 +680,6 @@ public class SubscriptionState {
         return rebalanceListener;
     }
 
-    private static Map<TopicPartition, TopicPartitionState> partitionToStateMap(Collection<TopicPartition> assignments) {
-        Map<TopicPartition, TopicPartitionState> map = new HashMap<>(assignments.size());
-        for (TopicPartition tp : assignments)
-            map.put(tp, new TopicPartitionState());
-        return map;
-    }
-
     private static class TopicPartitionState {
 
         private FetchState fetchState;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index ff0afe9..9e281a7 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -848,7 +848,7 @@ public class FetcherTest {
     }
 
     @Test
-    public void testFetchDuringRebalance() {
+    public void testFetchDuringEagerRebalance() {
         buildFetcher();
 
         subscriptions.subscribe(singleton(topicName), listener);
@@ -859,7 +859,9 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
 
-        // Now the rebalance happens and fetch positions are cleared
+        // Now the eager rebalance happens and fetch positions are cleared
+        subscriptions.assignFromSubscribed(Collections.emptyList());
+
         subscriptions.assignFromSubscribed(singleton(tp0));
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
@@ -869,6 +871,31 @@ public class FetcherTest {
     }
 
     @Test
+    public void testFetchDuringCooperativeRebalance() {
+        buildFetcher();
+
+        subscriptions.subscribe(singleton(topicName), listener);
+        subscriptions.assignFromSubscribed(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        client.updateMetadata(initialUpdateResponse);
+
+        assertEquals(1, fetcher.sendFetches());
+
+        // Now the cooperative rebalance happens and fetch positions are NOT cleared for unrevoked partitions
+        subscriptions.assignFromSubscribed(singleton(tp0));
+
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        consumerClient.poll(time.timer(0));
+
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
+
+        // The active fetch should NOT be ignored since the position for tp0 is still valid
+        assertEquals(1, fetchedRecords.size());
+        assertEquals(3, fetchedRecords.get(tp0).size());
+    }
+
+    @Test
     public void testInFlightFetchOnPausedPartition() {
         buildFetcher();
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedStreamsTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedStreamsTasks.java
index da0fc20..65c4c95 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedStreamsTasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedStreamsTasks.java
@@ -74,11 +74,15 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
     @Override
     void closeTask(final StreamTask task, final boolean clean) {
         if (suspended.containsKey(task.id())) {
-            task.closeSuspended(clean, false, null);
+            task.closeSuspended(clean, null);
         } else {
             task.close(clean, false);
         }
     }
+
+    boolean hasRestoringTasks() {
+        return !restoring.isEmpty();
+    }
     
     Set<TaskId> suspendedTaskIds() {
         return suspended.keySet();
@@ -107,7 +111,7 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
             } else if (restoring.containsKey(task)) {
                 revokedRestoringTasks.add(task);
             } else if (!suspended.containsKey(task)) {
-                log.warn("Task {} was revoked but cannot be found in the assignment", task);
+                log.warn("Task {} was revoked but cannot be found in the assignment, may have been closed due to error", task);
             }
         }
 
@@ -131,7 +135,7 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
                 task.suspend();
                 suspended.put(id, task);
             } catch (final TaskMigratedException closeAsZombieAndSwallow) {
-                // as we suspend a task, we are either shutting down or rebalancing, thus, we swallow and move on
+                // swallow and move on since we are rebalancing
                 log.info("Failed to suspend {} {} since it got migrated to another thread already. " +
                     "Closing it as zombie and move on.", taskTypeName, id);
                 firstException.compareAndSet(null, closeZombieTask(task));
@@ -248,7 +252,7 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
 
         try {
             final boolean clean = !isZombie;
-            task.closeSuspended(clean, isZombie, null);
+            task.closeSuspended(clean, null);
         } catch (final RuntimeException e) {
             log.error("Failed to close suspended {} {} due to the following error:", taskTypeName, task.id(), e);
             return e;
@@ -264,7 +268,6 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
         for (final TaskId revokedTask : revokedTasks) {
             final StreamTask suspendedTask = suspended.get(revokedTask);
 
-            // task may not be in the suspended tasks if it was closed due to some error
             if (suspendedTask != null) {
                 firstException.compareAndSet(null, closeSuspended(false, suspendedTask));
             } else {
@@ -335,7 +338,7 @@ class AssignedStreamsTasks extends AssignedTasks<StreamTask> implements Restorin
                 return true;
             } else {
                 log.warn("Couldn't resume task {} assigned partitions {}, task partitions {}", taskId, partitions, task.partitions());
-                task.closeSuspended(true, false, null);
+                task.closeSuspended(true, null);
             }
         }
         return false;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index f10c25b..fbc116a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -120,18 +120,6 @@ public class StandbyTask extends AbstractTask {
         commitNeeded = false;
     }
 
-    /**
-     * <pre>
-     * - flush store
-     * - checkpoint store
-     * </pre>
-     */
-    @Override
-    public void suspend() {
-        log.debug("Suspending");
-        flushAndCheckpointState();
-    }
-
     private void flushAndCheckpointState() {
         stateMgr.flush();
         stateMgr.checkpoint(Collections.emptyMap());
@@ -163,13 +151,6 @@ public class StandbyTask extends AbstractTask {
         taskClosed = true;
     }
 
-    @Override
-    public void closeSuspended(final boolean clean,
-                               final boolean isZombie,
-                               final RuntimeException e) {
-        close(clean, isZombie);
-    }
-
     /**
      * Updates a state store using records from one change log partition
      *
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 6c6a1c4..55a33c0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -79,8 +79,7 @@ public class StoreChangelogReader implements ChangelogReader {
             initialize(active);
         }
 
-        if (needsRestoring.isEmpty() || restoreConsumer.assignment().isEmpty()) {
-            restoreConsumer.unsubscribe();
+        if (checkForCompletedRestoration()) {
             return completedRestorers;
         }
 
@@ -116,9 +115,7 @@ public class StoreChangelogReader implements ChangelogReader {
 
         needsRestoring.removeAll(completedRestorers);
 
-        if (needsRestoring.isEmpty()) {
-            restoreConsumer.unsubscribe();
-        }
+        checkForCompletedRestoration();
 
         return completedRestorers;
     }
@@ -337,7 +334,14 @@ public class StoreChangelogReader implements ChangelogReader {
         return nextPosition;
     }
 
-
+    private boolean checkForCompletedRestoration() {
+        if (needsRestoring.isEmpty()) {
+            log.info("Finished restoring all active tasks");
+            restoreConsumer.unsubscribe();
+            return true;
+        }
+        return false;
+    }
 
     private boolean hasPartition(final TopicPartition topicPartition) {
         final List<PartitionInfo> partitions = partitionInfo.get(topicPartition.topic());
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 40466b3..780fe1f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -573,7 +573,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
      * @throws TaskMigratedException if committing offsets failed (non-EOS)
      *                               or if the task producer got fenced (EOS)
      */
-    @Override
     public void suspend() {
         log.debug("Suspending");
         suspend(true, false);
@@ -687,10 +686,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
     }
 
     // helper to avoid calling suspend() twice if a suspended task is not reassigned and closed
-    @Override
-    public void closeSuspended(final boolean clean,
-                               final boolean isZombie,
-                               RuntimeException firstException) {
+    void closeSuspended(final boolean clean, RuntimeException firstException) {
         try {
             closeStateManager(clean);
         } catch (final RuntimeException e) {
@@ -742,7 +738,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
             log.error("Could not close task due to the following error:", e);
         }
 
-        closeSuspended(clean, isZombie, firstException);
+        closeSuspended(clean, firstException);
 
         taskClosed = true;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 29e1bc7..c71ff27 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -132,13 +132,13 @@ public class StreamThread extends Thread {
      */
     public enum State implements ThreadStateTransitionValidator {
 
-        CREATED(1, 5),                   // 0
-        STARTING(2, 3, 5),               // 1
-        PARTITIONS_REVOKED(3, 5),        // 2
-        PARTITIONS_ASSIGNED(2, 3, 4, 5), // 3
-        RUNNING(2, 3, 5),                // 4
-        PENDING_SHUTDOWN(6),             // 5
-        DEAD;                            // 6
+        CREATED(1, 5),                    // 0
+        STARTING(2, 3, 5),                // 1
+        PARTITIONS_REVOKED(2, 3, 5),      // 2
+        PARTITIONS_ASSIGNED(2, 3, 4, 5),  // 3
+        RUNNING(2, 3, 5),                 // 4
+        PENDING_SHUTDOWN(6),              // 5
+        DEAD;                             // 6
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -734,9 +734,9 @@ public class StreamThread extends Thread {
             // to unblock the restoration as soon as possible
             records = pollRequests(Duration.ZERO);
         } else if (state == State.PARTITIONS_REVOKED) {
-            // try to fetch some records with normal poll time
-            // in order to wait long enough to get the join response
-            records = pollRequests(pollTime);
+            // try to fetch some records with zero poll millis to unblock
+            // other useful work while waiting for the join response
+            records = pollRequests(Duration.ZERO);
         } else if (state == State.RUNNING || state == State.STARTING) {
             // try to fetch some records with normal poll time
             // in order to get long polling
@@ -970,7 +970,12 @@ public class StreamThread extends Thread {
                 }
             }
 
-            lastCommitMs = now;
+            if (committed == -1) {
+                log.trace("Unable to commit as we are in the middle of a rebalance, will try again when it completes.");
+            } else {
+                lastCommitMs = now;
+            }
+            
             processStandbyRecords = true;
         } else {
             committed = taskManager.maybeCommitActiveTasksPerUserRequested();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index 2e6c9c0..8b2c95a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -47,15 +47,18 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.TreeMap;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
+import static java.util.UUID.randomUUID;
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.EARLIEST_PROBEABLE_VERSION;
@@ -63,20 +66,21 @@ import static org.apache.kafka.streams.processor.internals.assignment.StreamsAss
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_FIVE;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_FOUR;
-import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_ONE;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_THREE;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_TWO;
+import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.VERSION_ONE;
 
 public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Configurable {
+
     private Logger log;
     private String logPrefix;
 
     private static class AssignedPartition implements Comparable<AssignedPartition> {
+
         private final TaskId taskId;
         private final TopicPartition partition;
 
-        AssignedPartition(final TaskId taskId,
-                          final TopicPartition partition) {
+        AssignedPartition(final TaskId taskId, final TopicPartition partition) {
             this.taskId = taskId;
             this.partition = partition;
         }
@@ -103,6 +107,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     }
 
     private static class ClientMetadata {
+
         private final HostInfo hostInfo;
         private final Set<String> consumers;
         private final ClientState state;
@@ -132,12 +137,15 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             state = new ClientState();
         }
 
-        void addConsumer(final String consumerMemberId,
-                         final SubscriptionInfo info) {
+        void addConsumer(final String consumerMemberId, final List<TopicPartition> ownedPartitions) {
             consumers.add(consumerMemberId);
-            state.addPreviousActiveTasks(consumerMemberId, info.prevTasks());
-            state.addPreviousStandbyTasks(consumerMemberId, info.standbyTasks());
             state.incrementCapacity();
+            state.addOwnedPartitions(ownedPartitions, consumerMemberId);
+        }
+
+        void addPreviousTasks(final SubscriptionInfo info) {
+            state.addPreviousActiveTasks(info.prevTasks());
+            state.addPreviousStandbyTasks(info.standbyTasks());
         }
 
         @Override
@@ -177,9 +185,9 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     }
 
     /**
-     * We need to have the PartitionAssignor and its StreamThread to be mutually accessible
-     * since the former needs later's cached metadata while sending subscriptions,
-     * and the latter needs former's returned assignment when adding tasks.
+     * We need to have the PartitionAssignor and its StreamThread to be mutually accessible since the former needs
+     * the latter's cached metadata while sending subscriptions, and the latter needs the former's returned assignment when
+     * adding tasks.
      *
      * @throws KafkaException if the stream thread is not specified
      */
@@ -189,7 +197,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         logPrefix = assignorConfiguration.logPrefix();
         log = new LogContext(logPrefix).logger(getClass());
-        usedSubscriptionMetadataVersion = assignorConfiguration.configuredMetadataVersion(usedSubscriptionMetadataVersion);
+        usedSubscriptionMetadataVersion = assignorConfiguration
+            .configuredMetadataVersion(usedSubscriptionMetadataVersion);
         taskManager = assignorConfiguration.getTaskManager();
         assignmentErrorCode = assignorConfiguration.getAssignmentErrorCode(configs);
         numStandbyReplicas = assignorConfiguration.getNumStandbyReplicas();
@@ -221,37 +230,64 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         // 1. Client UUID (a unique id assigned to an instance of KafkaStreams)
         // 2. Task ids of previously running tasks
         // 3. Task ids of valid local states on the client's state directory.
-
-        final Set<TaskId> previousActiveTasks = taskManager.previousRunningTaskIds();
         final Set<TaskId> standbyTasks = taskManager.cachedTasksIds();
-        standbyTasks.removeAll(previousActiveTasks);
-        final SubscriptionInfo data = new SubscriptionInfo(
+        final Set<TaskId> activeTasks = prepareForSubscription(taskManager,
+            topics,
+            standbyTasks,
+            rebalanceProtocol);
+        return new SubscriptionInfo(
             usedSubscriptionMetadataVersion,
             taskManager.processId(),
-            previousActiveTasks,
+            activeTasks,
             standbyTasks,
-            userEndPoint);
+            userEndPoint)
+            .encode();
+    }
+
+    protected static Set<TaskId> prepareForSubscription(final TaskManager taskManager,
+        final Set<String> topics,
+        final Set<TaskId> standbyTasks,
+        final RebalanceProtocol rebalanceProtocol) {
+        // Any tasks that are not yet running are counted as standby tasks for assignment purposes,
+        // along with any old tasks for which we still found state on disk
+        final Set<TaskId> activeTasks;
+
+        switch (rebalanceProtocol) {
+            case EAGER:
+                // In eager, onPartitionsRevoked is called first and we must get the previously saved running task ids
+                activeTasks = taskManager.previousRunningTaskIds();
+                standbyTasks.removeAll(activeTasks);
+                break;
+            case COOPERATIVE:
+                // In cooperative, we will use the encoded ownedPartitions to determine the running tasks
+                activeTasks = Collections.emptySet();
+                standbyTasks.removeAll(taskManager.activeTaskIds());
+                break;
+            default:
+                throw new IllegalStateException("Streams partition assignor's rebalance protocol is unknown");
+        }
 
         taskManager.updateSubscriptionsFromMetadata(topics);
+        taskManager.setRebalanceInProgress(true);
 
-        return data.encode();
+        return activeTasks;
     }
 
     private Map<String, Assignment> errorAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
                                                     final String topic,
                                                     final int errorCode) {
         log.error("{} is unknown yet during rebalance," +
-                      " please make sure they have been pre-created before starting the Streams application.", topic);
+            " please make sure they have been pre-created before starting the Streams application.", topic);
         final Map<String, Assignment> assignment = new HashMap<>();
         for (final ClientMetadata clientMetadata : clientsMetadata.values()) {
             for (final String consumerId : clientMetadata.consumers) {
                 assignment.put(consumerId, new Assignment(
                     Collections.emptyList(),
                     new AssignmentInfo(LATEST_SUPPORTED_VERSION,
-                                       Collections.emptyList(),
-                                       Collections.emptyMap(),
-                                       Collections.emptyMap(),
-                                       errorCode).encode()
+                        Collections.emptyList(),
+                        Collections.emptyMap(),
+                        Collections.emptyMap(),
+                        errorCode).encode()
                 ));
             }
         }
@@ -283,7 +319,12 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         final Map<String, Subscription> subscriptions = groupSubscription.groupSubscription();
         // construct the client metadata from the decoded subscription info
         final Map<UUID, ClientMetadata> clientMetadataMap = new HashMap<>();
-        final Set<String> futureConsumers = new HashSet<>();
+        final Set<TopicPartition> allOwnedPartitions = new HashSet<>();
+
+        // keep track of any future consumers in a "dummy" Client since we can't decipher their subscription
+        final UUID futureId = randomUUID();
+        final ClientMetadata futureClient = new ClientMetadata(null);
+        clientMetadataMap.put(futureId, futureClient);
 
         int minReceivedMetadataVersion = LATEST_SUPPORTED_VERSION;
         int minSupportedMetadataVersion = LATEST_SUPPORTED_VERSION;
@@ -292,58 +333,59 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         for (final Map.Entry<String, Subscription> entry : subscriptions.entrySet()) {
             final String consumerId = entry.getKey();
             final Subscription subscription = entry.getValue();
-
             final SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData());
             final int usedVersion = info.version();
+
+            minReceivedMetadataVersion = updateMinReceivedVersion(usedVersion, minReceivedMetadataVersion);
+            minSupportedMetadataVersion = updateMinSupportedVersion(info.latestSupportedVersion(), minSupportedMetadataVersion);
+
+            final UUID processId;
             if (usedVersion > LATEST_SUPPORTED_VERSION) {
                 futureMetadataVersion = usedVersion;
-                futureConsumers.add(consumerId);
-                continue;
-            }
-            if (usedVersion < minReceivedMetadataVersion) {
-                minReceivedMetadataVersion = usedVersion;
+                processId = futureId;
+            } else {
+                processId = info.processId();
             }
 
-            final int latestSupportedVersion = info.latestSupportedVersion();
-            if (latestSupportedVersion < minSupportedMetadataVersion) {
-                minSupportedMetadataVersion = latestSupportedVersion;
-            }
+            ClientMetadata clientMetadata = clientMetadataMap.get(processId);
 
             // create the new client metadata if necessary
-            ClientMetadata clientMetadata = clientMetadataMap.get(info.processId());
-
             if (clientMetadata == null) {
                 clientMetadata = new ClientMetadata(info.userEndPoint());
                 clientMetadataMap.put(info.processId(), clientMetadata);
             }
 
-            // add the consumer to the client
-            clientMetadata.addConsumer(consumerId, info);
+            // add the consumer and any info in its subscription to the client
+            clientMetadata.addConsumer(consumerId, subscription.ownedPartitions());
+            allOwnedPartitions.addAll(subscription.ownedPartitions());
+            if (info.prevTasks() != null && info.standbyTasks() != null) {
+                clientMetadata.addPreviousTasks(info);
+            }
         }
 
         final boolean versionProbing;
         if (futureMetadataVersion == UNKNOWN) {
             versionProbing = false;
+            clientMetadataMap.remove(futureId);
+        } else if (minReceivedMetadataVersion >= EARLIEST_PROBEABLE_VERSION) {
+            versionProbing = true;
+            log.info("Received a future (version probing) subscription (version: {})."
+                    + " Sending assignment back (with supported version {}).",
+                futureMetadataVersion,
+                minSupportedMetadataVersion);
+
         } else {
-            if (minReceivedMetadataVersion >= EARLIEST_PROBEABLE_VERSION) {
-                log.info("Received a future (version probing) subscription (version: {})."
-                             + " Sending empty assignment back (with supported version {}).",
-                         futureMetadataVersion,
-                         LATEST_SUPPORTED_VERSION);
-                versionProbing = true;
-            } else {
-                throw new IllegalStateException(
-                    "Received a future (version probing) subscription (version: " + futureMetadataVersion
-                        + ") and an incompatible pre Kafka 2.0 subscription (version: " + minReceivedMetadataVersion
-                        + ") at the same time."
-                );
-            }
+            throw new IllegalStateException(
+                "Received a future (version probing) subscription (version: " + futureMetadataVersion
+                    + ") and an incompatible pre Kafka 2.0 subscription (version: " + minReceivedMetadataVersion
+                    + ") at the same time."
+            );
         }
 
         if (minReceivedMetadataVersion < LATEST_SUPPORTED_VERSION) {
             log.info("Downgrading metadata to version {}. Latest supported version is {}.",
-                     minReceivedMetadataVersion,
-                     LATEST_SUPPORTED_VERSION);
+                minReceivedMetadataVersion,
+                LATEST_SUPPORTED_VERSION);
         }
 
         log.debug("Constructed client metadata {} from the member subscriptions.", clientMetadataMap);
@@ -361,9 +403,10 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 if (!topicsInfo.repartitionSourceTopics.keySet().contains(topic) &&
                     !metadata.topics().contains(topic)) {
                     log.error("Missing source topic {} during assignment. Returning error {}.",
-                              topic, AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.name());
+                        topic, AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.name());
                     return new GroupAssignment(
-                        errorAssignment(clientMetadataMap, topic, AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code())
+                        errorAssignment(clientMetadataMap, topic,
+                            AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code())
                     );
                 }
             }
@@ -378,7 +421,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
             for (final InternalTopologyBuilder.TopicsInfo topicsInfo : topicGroups.values()) {
                 for (final String topicName : topicsInfo.repartitionSourceTopics.keySet()) {
-                    final Optional<Integer> maybeNumPartitions = repartitionTopicMetadata.get(topicName).numberOfPartitions();
+                    final Optional<Integer> maybeNumPartitions = repartitionTopicMetadata.get(topicName)
+                        .numberOfPartitions();
                     Integer numPartitions = null;
 
                     if (!maybeNumPartitions.isPresent()) {
@@ -395,7 +439,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                     // map().join().join(map())
                                     if (repartitionTopicMetadata.containsKey(sourceTopicName)) {
                                         if (repartitionTopicMetadata.get(sourceTopicName).numberOfPartitions().isPresent()) {
-                                            numPartitionsCandidate = repartitionTopicMetadata.get(sourceTopicName).numberOfPartitions().get();
+                                            numPartitionsCandidate =
+                                                repartitionTopicMetadata.get(sourceTopicName).numberOfPartitions().get();
                                         }
                                     } else {
                                         final Integer count = metadata.partitionCountForTopic(sourceTopicName);
@@ -427,7 +472,6 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             }
         } while (numPartitionsNeeded);
 
-
         // ensure the co-partitioning topics within the group have the same number of partitions,
         // and enforce the number of partitions for those repartition topics to be the same if they
         // are co-partitioned as well.
@@ -470,19 +514,23 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         final Map<TaskId, Set<TopicPartition>> partitionsForTask =
             partitionGrouper.partitionGroups(sourceTopicsByGroup, fullMetadata);
 
+        final Map<TopicPartition, TaskId> taskForPartition = new HashMap<>();
+
         // check if all partitions are assigned, and there are no duplicates of partitions in multiple tasks
         final Set<TopicPartition> allAssignedPartitions = new HashSet<>();
         final Map<Integer, Set<TaskId>> tasksByTopicGroup = new HashMap<>();
         for (final Map.Entry<TaskId, Set<TopicPartition>> entry : partitionsForTask.entrySet()) {
+            final TaskId id = entry.getKey();
             final Set<TopicPartition> partitions = entry.getValue();
+
             for (final TopicPartition partition : partitions) {
+                taskForPartition.put(partition, id);
                 if (allAssignedPartitions.contains(partition)) {
                     log.warn("Partition {} is assigned to more than one tasks: {}", partition, partitionsForTask);
                 }
             }
             allAssignedPartitions.addAll(partitions);
 
-            final TaskId id = entry.getKey();
             tasksByTopicGroup.computeIfAbsent(id.topicGroupId, k -> new HashSet<>()).add(id);
         }
         for (final String topic : allSourceTopics) {
@@ -491,13 +539,15 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 log.warn("No partitions found for topic {}", topic);
             } else {
                 for (final PartitionInfo partitionInfo : partitionInfoList) {
-                    final TopicPartition partition = new TopicPartition(partitionInfo.topic(), partitionInfo.partition());
+                    final TopicPartition partition = new TopicPartition(partitionInfo.topic(),
+                        partitionInfo.partition());
                     if (!allAssignedPartitions.contains(partition)) {
                         log.warn("Partition {} is not assigned to any tasks: {}"
-                                     + " Possible causes of a partition not getting assigned"
-                                     + " is that another topic defined in the topology has not been"
-                                     + " created when starting your streams application,"
-                                     + " resulting in no tasks created for this topology at all.", partition, partitionsForTask);
+                                + " Possible causes of a partition not getting assigned"
+                                + " is that another topic defined in the topology has not been"
+                                + " created when starting your streams application,"
+                                + " resulting in no tasks created for this topology at all.", partition,
+                            partitionsForTask);
                     }
                 }
             }
@@ -533,15 +583,32 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         // ---------------- Step Two ---------------- //
 
-        // assign tasks to clients
         final Map<UUID, ClientState> states = new HashMap<>();
         for (final Map.Entry<UUID, ClientMetadata> entry : clientMetadataMap.entrySet()) {
-            states.put(entry.getKey(), entry.getValue().state);
+            final ClientState state = entry.getValue().state;
+            states.put(entry.getKey(), state);
+
+            // Either the active tasks (eager) OR the owned partitions (cooperative) were encoded in the subscription
+            // according to the rebalancing protocol, so convert any partitions in a client to tasks where necessary
+            if (!state.ownedPartitions().isEmpty()) {
+                final Set<TaskId> previousActiveTasks = new HashSet<>();
+                for (final Map.Entry<TopicPartition, String> partitionEntry : state.ownedPartitions().entrySet()) {
+                    final TopicPartition tp = partitionEntry.getKey();
+                    final TaskId task = taskForPartition.get(tp);
+                    if (task != null) {
+                        previousActiveTasks.add(task);
+                    } else {
+                        log.error("No task found for topic partition {}", tp);
+                    }
+                }
+                state.addPreviousActiveTasks(previousActiveTasks);
+            }
         }
 
         log.debug("Assigning tasks {} to clients {} with number of replicas {}",
-                  partitionsForTask.keySet(), states, numStandbyReplicas);
+            partitionsForTask.keySet(), states, numStandbyReplicas);
 
+        // assign tasks to clients
         final StickyTaskAssignor<UUID> taskAssignor = new StickyTaskAssignor<>(states, partitionsForTask.keySet());
         taskAssignor.assign(numStandbyReplicas);
 
@@ -577,7 +644,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 clientMetadataMap,
                 partitionsForTask,
                 partitionsByHostState,
-                futureConsumers,
+                allOwnedPartitions,
                 minReceivedMetadataVersion,
                 minSupportedMetadataVersion
             );
@@ -586,6 +653,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 clientMetadataMap,
                 partitionsForTask,
                 partitionsByHostState,
+                allOwnedPartitions,
                 minReceivedMetadataVersion,
                 minSupportedMetadataVersion
             );
@@ -594,132 +662,165 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         return new GroupAssignment(assignment);
     }
 
-    private static Map<String, Assignment> computeNewAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
-                                                                final Map<TaskId, Set<TopicPartition>> partitionsForTask,
-                                                                final Map<HostInfo, Set<TopicPartition>> partitionsByHostState,
-                                                                final int minUserMetadataVersion,
-                                                                final int minSupportedMetadataVersion) {
+    private Map<String, Assignment> computeNewAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
+                                                         final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                                         final Map<HostInfo, Set<TopicPartition>> partitionsByHostState,
+                                                         final Set<TopicPartition> allOwnedPartitions,
+                                                         final int minUserMetadataVersion,
+                                                         final int minSupportedMetadataVersion) {
+        // keep track of whether a 2nd rebalance is unavoidable so we can skip trying to get a completely sticky assignment
+        boolean rebalanceRequired = false;
         final Map<String, Assignment> assignment = new HashMap<>();
 
         // within the client, distribute tasks to its owned consumers
-        for (final Map.Entry<UUID, ClientMetadata> entry : clientsMetadata.entrySet()) {
-            final Set<String> consumers = entry.getValue().consumers;
-            final ClientState state = entry.getValue().state;
-
-            final List<List<TaskId>> interleavedActive =
-                interleaveTasksByGroupId(state.activeTasks(), consumers.size());
-            final List<List<TaskId>> interleavedStandby =
-                interleaveTasksByGroupId(state.standbyTasks(), consumers.size());
-
-            int consumerTaskIndex = 0;
-
-            for (final String consumer : consumers) {
-                final List<TaskId> activeTasks = interleavedActive.get(consumerTaskIndex);
-
-                // These will be filled in by buildAssignedActiveTaskAndPartitionsList below
-                final List<TopicPartition> activePartitionsList = new ArrayList<>();
-                final List<TaskId> assignedActiveList = new ArrayList<>();
-
-                buildAssignedActiveTaskAndPartitionsList(activeTasks, activePartitionsList, assignedActiveList, partitionsForTask);
+        for (final ClientMetadata clientMetadata : clientsMetadata.values()) {
+            final ClientState state = clientMetadata.state;
+            final Set<String> consumers = clientMetadata.consumers;
+            Map<String, List<TaskId>> activeTaskAssignments;
+
+            // Try to avoid triggering another rebalance by giving active tasks back to their previous owners within a
+            // client, without violating load balance. If we already know another rebalance will be required, or the
+            // client had no owned partitions, try to balance the workload as evenly as possible by interleaving the
+            // tasks among consumers and hopefully spreading the heavier subtopologies evenly across threads.
+            if (rebalanceRequired || state.ownedPartitions().isEmpty()) {
+                activeTaskAssignments = interleaveConsumerTasksByGroupId(state.activeTasks(), consumers);
+            } else if ((activeTaskAssignments = tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, allOwnedPartitions))
+                        .equals(Collections.emptyMap())) {
+                rebalanceRequired = true;
+                activeTaskAssignments = interleaveConsumerTasksByGroupId(state.activeTasks(), consumers);
+            }
 
-                final Map<TaskId, Set<TopicPartition>> standby = new HashMap<>();
-                if (!state.standbyTasks().isEmpty()) {
-                    final List<TaskId> assignedStandbyList = interleavedStandby.get(consumerTaskIndex);
-                    for (final TaskId taskId : assignedStandbyList) {
-                        standby.computeIfAbsent(taskId, k -> new HashSet<>()).addAll(partitionsForTask.get(taskId));
-                    }
-                }
+            final Map<String, List<TaskId>> interleavedStandby =
+                interleaveConsumerTasksByGroupId(state.standbyTasks(), consumers);
 
-                consumerTaskIndex++;
-
-                // finally, encode the assignment before sending back to coordinator
-                assignment.put(
-                    consumer,
-                    new Assignment(
-                        activePartitionsList,
-                        new AssignmentInfo(
-                            minUserMetadataVersion,
-                            minSupportedMetadataVersion,
-                            assignedActiveList,
-                            standby,
-                            partitionsByHostState,
-                            0
-                        ).encode()
-                    )
-                );
-            }
+            addClientAssignments(
+                assignment,
+                clientMetadata,
+                partitionsForTask,
+                partitionsByHostState,
+                allOwnedPartitions,
+                activeTaskAssignments,
+                interleavedStandby,
+                minUserMetadataVersion,
+                minSupportedMetadataVersion);
         }
 
         return assignment;
     }
 
-    private static Map<String, Assignment> versionProbingAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
-                                                                    final Map<TaskId, Set<TopicPartition>> partitionsForTask,
-                                                                    final Map<HostInfo, Set<TopicPartition>> partitionsByHostState,
-                                                                    final Set<String> futureConsumers,
-                                                                    final int minUserMetadataVersion,
-                                                                    final int minSupportedMetadataVersion) {
+    private Map<String, Assignment> versionProbingAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
+                                                             final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                                             final Map<HostInfo, Set<TopicPartition>> partitionsByHostState,
+                                                             final Set<TopicPartition> allOwnedPartitions,
+                                                             final int minUserMetadataVersion,
+                                                             final int minSupportedMetadataVersion) {
         final Map<String, Assignment> assignment = new HashMap<>();
 
-        // assign previously assigned tasks to "old consumers"
+        // Since we know another rebalance will be triggered anyway, just try and generate a balanced assignment
+        // (without violating cooperative protocol) now so that on the second rebalance we can just give tasks
+        // back to their previous owners
+        // within the client, distribute tasks to its owned consumers
         for (final ClientMetadata clientMetadata : clientsMetadata.values()) {
-            for (final String consumerId : clientMetadata.consumers) {
+            final ClientState state = clientMetadata.state;
 
-                if (futureConsumers.contains(consumerId)) {
-                    continue;
-                }
+            final Map<String, List<TaskId>> interleavedActive =
+                interleaveConsumerTasksByGroupId(state.activeTasks(), clientMetadata.consumers);
+            final Map<String, List<TaskId>> interleavedStandby =
+                interleaveConsumerTasksByGroupId(state.standbyTasks(), clientMetadata.consumers);
 
-                // Return the same active tasks that were claimed in the subscription
-                final List<TaskId> activeTasks = new ArrayList<>(clientMetadata.state.prevActiveTasksForConsumer(consumerId));
-
-                // These will be filled in by buildAssignedActiveTaskAndPartitionsList below
-                final List<TopicPartition> activePartitionsList = new ArrayList<>();
-                final List<TaskId> assignedActiveList = new ArrayList<>();
-
-                buildAssignedActiveTaskAndPartitionsList(activeTasks, activePartitionsList, assignedActiveList, partitionsForTask);
+            addClientAssignments(
+                assignment,
+                clientMetadata,
+                partitionsForTask,
+                partitionsByHostState,
+                allOwnedPartitions,
+                interleavedActive,
+                interleavedStandby,
+                minUserMetadataVersion,
+                minSupportedMetadataVersion);
+        }
 
-                // Return the same standby tasks that were claimed in the subscription
-                final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-                for (final TaskId taskId : clientMetadata.state.prevStandbyTasksForConsumer(consumerId)) {
-                    standbyTasks.put(taskId, partitionsForTask.get(taskId));
-                }
+        return assignment;
+    }
 
-                assignment.put(consumerId, new Assignment(
+    private void addClientAssignments(final Map<String, Assignment> assignment,
+                                      final ClientMetadata clientMetadata,
+                                      final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                      final Map<HostInfo, Set<TopicPartition>> partitionsByHostState,
+                                      final Set<TopicPartition> allOwnedPartitions,
+                                      final Map<String, List<TaskId>> activeTaskAssignments,
+                                      final Map<String, List<TaskId>> standbyTaskAssignments,
+                                      final int minUserMetadataVersion,
+                                      final int minSupportedMetadataVersion) {
+
+        // Loop through the consumers and build their assignment
+        for (final String consumer : clientMetadata.consumers) {
+            final List<TaskId> activeTasksForConsumer = activeTaskAssignments.get(consumer);
+
+            // These will be filled in by buildAssignedActiveTaskAndPartitionsList below
+            final List<TopicPartition> activePartitionsList = new ArrayList<>();
+            final List<TaskId> assignedActiveList = new ArrayList<>();
+
+            buildAssignedActiveTaskAndPartitionsList(consumer,
+                                                     clientMetadata.state,
+                                                     activeTasksForConsumer,
+                                                     partitionsForTask,
+                                                     allOwnedPartitions,
+                                                     activePartitionsList,
+                                                     assignedActiveList);
+
+            final Map<TaskId, Set<TopicPartition>> standbyTaskMap =
+                buildStandbyTaskMap(standbyTaskAssignments.get(consumer), partitionsForTask);
+
+            // finally, encode the assignment and insert into map with all assignments
+            assignment.put(
+                consumer,
+                new Assignment(
                     activePartitionsList,
                     new AssignmentInfo(
                         minUserMetadataVersion,
                         minSupportedMetadataVersion,
                         assignedActiveList,
-                        standbyTasks,
+                        standbyTaskMap,
                         partitionsByHostState,
-                        0)
-                        .encode()
-                ));
-            }
-        }
-
-        // add empty assignment for "future version" clients (ie, empty version probing response)
-        for (final String consumerId : futureConsumers) {
-            assignment.put(consumerId, new Assignment(
-                Collections.emptyList(),
-                new AssignmentInfo(minUserMetadataVersion, minSupportedMetadataVersion).encode()
-            ));
+                        AssignorError.NONE.code()
+                    ).encode()
+                )
+            );
         }
-
-        return assignment;
     }
 
-    private static void buildAssignedActiveTaskAndPartitionsList(final List<TaskId> activeTasks,
-                                                                 final List<TopicPartition> activePartitionsList,
-                                                                 final List<TaskId> assignedActiveList,
-                                                                 final Map<TaskId, Set<TopicPartition>> partitionsForTask) {
+    private void buildAssignedActiveTaskAndPartitionsList(final String consumer,
+                                                          final ClientState clientState,
+                                                          final List<TaskId> activeTasksForConsumer,
+                                                          final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                                          final Set<TopicPartition> allOwnedPartitions,
+                                                          final List<TopicPartition> activePartitionsList,
+                                                          final List<TaskId> assignedActiveList) {
         final List<AssignedPartition> assignedPartitions = new ArrayList<>();
 
         // Build up list of all assigned partition-task pairs
-        for (final TaskId taskId : activeTasks) {
+        for (final TaskId taskId : activeTasksForConsumer) {
+            final List<AssignedPartition> assignedPartitionsForTask = new ArrayList<>();
             for (final TopicPartition partition : partitionsForTask.get(taskId)) {
-                assignedPartitions.add(new AssignedPartition(taskId, partition));
+                final String oldOwner = clientState.ownedPartitions().get(partition);
+                final boolean newPartitionForConsumer = oldOwner == null || !oldOwner.equals(consumer);
+
+                // If the partition is new to this consumer but is still owned by another, remove from the assignment
+                // until it has been revoked and can safely be reassigned according to the COOPERATIVE protocol
+                if (newPartitionForConsumer && allOwnedPartitions.contains(partition)) {
+                    log.debug("Removing task {} from assignment until it is safely revoked", taskId);
+                    clientState.removeFromAssignment(taskId);
+                    // Clear the assigned partitions list for this task if any partition can not safely be assigned,
+                    // so as not to encode a partial task
+                    assignedPartitionsForTask.clear();
+                    break;
+                } else {
+                    assignedPartitionsForTask.add(new AssignedPartition(taskId, partition));
+                }
             }
+            // assignedPartitionsForTask will either contain all partitions for the task or be empty, so just add all
+            assignedPartitions.addAll(assignedPartitionsForTask);
         }
 
         // Add one copy of a task for each corresponding partition, so the receiver can determine the task <-> tp mapping
@@ -730,17 +831,175 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         }
     }
 
-    // visible for testing
-    static List<List<TaskId>> interleaveTasksByGroupId(final Collection<TaskId> taskIds, final int numberThreads) {
+    private static Map<TaskId, Set<TopicPartition>> buildStandbyTaskMap(final Collection<TaskId> standbys,
+                                                                        final Map<TaskId, Set<TopicPartition>> partitionsForTask) {
+        final Map<TaskId, Set<TopicPartition>> standbyTaskMap = new HashMap<>();
+        for (final TaskId task : standbys) {
+            standbyTaskMap.put(task, partitionsForTask.get(task));
+        }
+        return standbyTaskMap;
+    }
+
+    /**
+     * Generates an assignment that tries to satisfy two conditions: no active task previously owned by a consumer
+     * is assigned to another (i.e. nothing gets revoked), and the tasks are evenly distributed among the consumers
+     * of the client.
+     * <p>
+     * If it is impossible to satisfy both constraints we abort early and return an empty map so we can use a
+     * different assignment strategy that tries to distribute tasks of a single subtopology across different threads.
+     *
+     * @param state state for this client
+     * @param consumers the consumers in this client
+     * @param partitionsForTask mapping from task to its associated partitions
+     * @param allOwnedPartitions set of all partitions claimed as owned by the group
+     * @return task assignment for the consumers of this client
+     *         empty map if it is not possible to generate a balanced assignment without moving a task to a new consumer
+     */
+    Map<String, List<TaskId>> tryStickyAndBalancedTaskAssignmentWithinClient(final ClientState state,
+                                                                             final Set<String> consumers,
+                                                                             final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                                                             final Set<TopicPartition> allOwnedPartitions) {
+        final Map<String, List<TaskId>> assignments = new HashMap<>();
+        final LinkedList<TaskId> newTasks = new LinkedList<>();
+        final Set<String> unfilledConsumers = new HashSet<>(consumers);
+
+        final int maxTasksPerClient = (int) Math.ceil(((double) state.activeTaskCount()) / consumers.size());
+
+        // initialize task list for consumers
+        for (final String consumer : consumers) {
+            assignments.put(consumer, new ArrayList<>());
+        }
+
+        for (final TaskId task : state.activeTasks()) {
+            final Set<String> previousConsumers = previousConsumersOfTaskPartitions(partitionsForTask.get(task), state.ownedPartitions(), allOwnedPartitions);
+
+            // If this task's partitions were owned by different consumers, we can't avoid revoking partitions
+            if (previousConsumers.size() > 1) {
+                log.warn("The partitions of task {} were claimed as owned by different StreamThreads. " +
+                    "This indicates the mapping from partitions to tasks has changed!", task);
+                return Collections.emptyMap();
+            }
+
+            // If this is a new task, or its old consumer no longer exists, it can be freely (re)assigned
+            if (previousConsumers.isEmpty()) {
+                log.debug("Task {} was not previously owned by any consumers still in the group. It's owner may " +
+                    "have died or it may be a new task", task);
+                newTasks.add(task);
+            } else {
+                final String consumer = previousConsumers.iterator().next();
+
+                // If the previous consumer was from another client, these partitions will have to be revoked
+                if (!consumers.contains(consumer)) {
+                    log.debug("This client was assigned a task {} whose partition(s) were previously owned by another " +
+                        "client, falling back to an interleaved assignment since a rebalance is inevitable.", task);
+                    return Collections.emptyMap();
+                }
+
+                // If this consumer previously owned more tasks than it has capacity for, some must be revoked
+                if (assignments.get(consumer).size() >= maxTasksPerClient) {
+                    log.debug("Cannot create a sticky and balanced assignment as this client's consumers owned more " +
+                        "previous tasks than it has capacity for during this assignment, falling back to interleaved " +
+                        "assignment since a realance is inevitable.");
+                    return Collections.emptyMap();
+                }
+
+                assignments.get(consumer).add(task);
+
+                // If we have now reached capacity, remove it from set of consumers who still need more tasks
+                if (assignments.get(consumer).size() == maxTasksPerClient) {
+                    unfilledConsumers.remove(consumer);
+                }
+            }
+        }
+
+        // Interleave any remaining tasks by groupId among the consumers with remaining capacity. For further
+        // explanation, see the javadocs for #interleaveConsumerTasksByGroupId
+        Collections.sort(newTasks);
+        while (!newTasks.isEmpty()) {
+            if (unfilledConsumers.isEmpty()) {
+                throw new IllegalStateException("Some tasks could not be distributed");
+            }
+
+            final Iterator<String> consumerIt = unfilledConsumers.iterator();
+
+            // Loop through the unfilled consumers and distribute tasks until newTasks is empty
+            while (consumerIt.hasNext()) {
+                final String consumer = consumerIt.next();
+                final List<TaskId> consumerAssignment = assignments.get(consumer);
+                final TaskId task = newTasks.poll();
+                if (task == null) {
+                    break;
+                }
+
+                consumerAssignment.add(task);
+                if (consumerAssignment.size() == maxTasksPerClient) {
+                    consumerIt.remove();
+                }
+            }
+        }
+
+        return assignments;
+    }
+
+    /**
+     * Get the previous consumer for the partitions of a task
+     *
+     * @param taskPartitions the TopicPartitions for a single given task
+     * @param clientOwnedPartitions the partitions owned by all consumers in a client
+     * @param allOwnedPartitions all partitions claimed as owned by any consumer in any client
+     * @return set of consumer(s) that previously owned the partitions in this task
+     *         empty set signals that it is a new task, or its previous owner is no longer in the group
+     */
+    Set<String> previousConsumersOfTaskPartitions(final Set<TopicPartition> taskPartitions,
+                                                  final Map<TopicPartition, String> clientOwnedPartitions,
+                                                  final Set<TopicPartition> allOwnedPartitions) {
+        // this "foreignConsumer" indicates a partition was owned by someone from another client -- we don't really care who
+        final String foreignConsumer = "";
+        final Set<String> previousConsumers = new HashSet<>();
+
+        for (final TopicPartition tp : taskPartitions) {
+            final String currentPartitionConsumer = clientOwnedPartitions.get(tp);
+            if (currentPartitionConsumer != null) {
+                previousConsumers.add(currentPartitionConsumer);
+            } else if (allOwnedPartitions.contains(tp)) {
+                previousConsumers.add(foreignConsumer);
+            }
+        }
+
+        return previousConsumers;
+    }
+
+    /**
+     * Generate an assignment that attempts to maximize load balance without regard for stickiness, by spreading
+     * tasks of the same groupId (subtopology) over different consumers.
+     *
+     * @param taskIds the set of tasks to be distributed
+     * @param consumers the set of consumers to receive tasks
+     * @return a map of task assignments keyed by the consumer id
+     */
+    static Map<String, List<TaskId>> interleaveConsumerTasksByGroupId(final Collection<TaskId> taskIds,
+                                                                      final Set<String> consumers) {
+        // First we make a sorted list of the tasks, grouping them by groupId
         final LinkedList<TaskId> sortedTasks = new LinkedList<>(taskIds);
         Collections.sort(sortedTasks);
-        final List<List<TaskId>> taskIdsForConsumerAssignment = new ArrayList<>(numberThreads);
-        for (int i = 0; i < numberThreads; i++) {
-            taskIdsForConsumerAssignment.add(new ArrayList<>());
+
+        // Initialize the assignment map and task list for each consumer. We use a TreeMap here for a consistent
+        // ordering of the consumers in the hope they will end up with the same set of tasks in subsequent assignments
+        final Map<String, List<TaskId>> taskIdsForConsumerAssignment = new TreeMap<>();
+        for (final String consumer : consumers) {
+            taskIdsForConsumerAssignment.put(consumer, new ArrayList<>());
         }
+
+        // We loop until the tasks have all been assigned, removing them from the list when they are given to a
+        // consumer. To interleave the tasks, we loop through the consumers and give each one task from the head
+        // of the list. When we finish going through the list of consumers we start over at the beginning of the
+        // consumers list, continuing until we run out of tasks.
         while (!sortedTasks.isEmpty()) {
-            for (final List<TaskId> taskIdList : taskIdsForConsumerAssignment) {
+            for (final Map.Entry<String, List<TaskId>> consumerTaskIds : taskIdsForConsumerAssignment.entrySet()) {
+                final List<TaskId> taskIdList = consumerTaskIds.getValue();
                 final TaskId taskId = sortedTasks.poll();
+
+                // Check for null here as we may run out of tasks before giving every consumer exactly the same number
                 if (taskId == null) {
                     break;
                 }
@@ -774,7 +1033,6 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     protected boolean maybeUpdateSubscriptionVersion(final int receivedAssignmentMetadataVersion,
                                                      final int latestCommonlySupportedVersion) {
         if (receivedAssignmentMetadataVersion >= EARLIEST_PROBEABLE_VERSION) {
-
             // If the latest commonly supported version is now greater than our used version, this indicates we have just
             // completed the rolling upgrade and can now update our subscription version for the final rebalance
             if (latestCommonlySupportedVersion > usedSubscriptionMetadataVersion) {
@@ -880,6 +1138,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         taskManager.setPartitionsToTaskId(partitionsToTaskId);
         taskManager.setAssignmentMetadata(activeTasks, info.standbyTasks());
         taskManager.updateSubscriptionsFromAssignment(partitions);
+        taskManager.setRebalanceInProgress(false);
     }
 
     private static void processVersionOneAssignment(final String logPrefix,
@@ -973,12 +1232,23 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         }
     }
 
+    private int updateMinReceivedVersion(final int usedVersion, final int minReceivedMetadataVersion) {
+        return usedVersion < minReceivedMetadataVersion ? usedVersion : minReceivedMetadataVersion;
+    }
+
+    private int updateMinSupportedVersion(final int supportedVersion, final int minSupportedMetadataVersion) {
+        return supportedVersion < minSupportedMetadataVersion ? supportedVersion : minSupportedMetadataVersion;
+    }
+
     protected void setAssignmentErrorCode(final Integer errorCode) {
         assignmentErrorCode.set(errorCode);
     }
 
-
     // following functions are for test only
+    void setRebalanceProtocol(final RebalanceProtocol rebalanceProtocol) {
+        this.rebalanceProtocol = rebalanceProtocol;
+    }
+
     void setInternalTopicManager(final InternalTopicManager internalTopicManager) {
         this.internalTopicManager = internalTopicManager;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index da9e656..af1b4bd 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -27,7 +27,7 @@ import java.util.Set;
 
 public interface Task {
     /**
-     * Initialize the task and return {@code true} if the task is ready to run, i.e, it has not state stores
+     * Initialize the task and return {@code true} if the task is ready to run, i.e., it has no state stores
      * @return true if this task has no state stores that may need restoring.
      * @throws IllegalStateException If store gets registered after initialized is already finished
      * @throws StreamsException if the store's change log does not contain the partition
@@ -40,14 +40,8 @@ public interface Task {
 
     void commit();
 
-    void suspend();
-
     void resume();
 
-    void closeSuspended(final boolean clean,
-                        final boolean isZombie,
-                        final RuntimeException e);
-
     void close(final boolean clean,
                final boolean isZombie);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index cd90fad..a2dac40 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -60,6 +60,7 @@ public class TaskManager {
 
     private final Admin adminClient;
     private DeleteRecordsResult deleteRecordsResult;
+    private boolean rebalanceInProgress = false;  // if we are in the middle of a rebalance, it is not safe to commit
 
     // the restore consumer is only ever assigned changelogs from restoring tasks or standbys (but not both)
     private boolean restoreConsumerAssignedStandbys = false;
@@ -144,7 +145,7 @@ public class TaskManager {
                     addedActiveTasks.put(taskId, partitions);
                 }
             } catch (final StreamsException e) {
-                log.error("Failed to resume an active task {} due to the following error:", taskId, e);
+                log.error("Failed to resume a suspended active task {} due to the following error:", taskId, e);
                 throw e;
             }
         }
@@ -303,7 +304,11 @@ public class TaskManager {
         }
     }
 
-    Set<TaskId> activeTaskIds() {
+    public Set<TaskId> previousRunningTaskIds() {
+        return active.previousRunningTaskIds();
+    }
+
+    public Set<TaskId> activeTaskIds() {
         return active.allAssignedTaskIds();
     }
 
@@ -319,10 +324,6 @@ public class TaskManager {
         return revokedStandbyTasks.keySet();
     }
 
-    public Set<TaskId> previousRunningTaskIds() {
-        return active.previousRunningTaskIds();
-    }
-
     Set<TaskId> previousActiveTaskIds() {
         final HashSet<TaskId> previousActiveTasks = new HashSet<>(assignedActiveTasks.keySet());
         previousActiveTasks.addAll(revokedActiveTasks.keySet());
@@ -378,9 +379,11 @@ public class TaskManager {
         active.initializeNewTasks();
         standby.initializeNewTasks();
 
-        final Collection<TopicPartition> restored = changelogReader.restore(active);
-        active.updateRestored(restored);
-        removeChangelogsFromRestoreConsumer(restored, false);
+        if (active.hasRestoringTasks()) {
+            final Collection<TopicPartition> restored = changelogReader.restore(active);
+            active.updateRestored(restored);
+            removeChangelogsFromRestoreConsumer(restored, false);
+        }
 
         if (active.allTasksRunning()) {
             final Set<TopicPartition> assignment = consumer.assignment();
@@ -420,6 +423,10 @@ public class TaskManager {
         }
     }
 
+    public void setRebalanceInProgress(final boolean rebalanceInProgress) {
+        this.rebalanceInProgress = rebalanceInProgress;
+    }
+
     public void setClusterMetadata(final Cluster cluster) {
         this.cluster = cluster;
     }
@@ -493,10 +500,10 @@ public class TaskManager {
     /**
      * @throws TaskMigratedException if committing offsets failed (non-EOS)
      *                               or if the task producer got fenced (EOS)
+     * @return number of committed offsets, or -1 if we are in the middle of a rebalance and cannot commit
      */
     int commitAll() {
-        final int committed = active.commit();
-        return committed + standby.commit();
+        return rebalanceInProgress ? -1 : active.commit() + standby.commit();
     }
 
     /**
@@ -518,7 +525,7 @@ public class TaskManager {
      *                               or if the task producer got fenced (EOS)
      */
     int maybeCommitActiveTasksPerUserRequested() {
-        return active.maybeCommitPerUserRequested();
+        return rebalanceInProgress ? -1 : active.maybeCommitPerUserRequested();
     }
 
     void maybePurgeCommitedRecords() {
@@ -528,7 +535,8 @@ public class TaskManager {
         if (deleteRecordsResult == null || deleteRecordsResult.all().isDone()) {
 
             if (deleteRecordsResult != null && deleteRecordsResult.all().isCompletedExceptionally()) {
-                log.debug("Previous delete-records request has failed: {}. Try sending the new request now", deleteRecordsResult.lowWatermarks());
+                log.debug("Previous delete-records request has failed: {}. Try sending the new request now",
+                    deleteRecordsResult.lowWatermarks());
             }
 
             final Map<TopicPartition, RecordsToDelete> recordsToDelete = new HashMap<>();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
index ac88f2f..1e406e2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
@@ -153,7 +153,8 @@ public final class AssignorConfiguration {
                     throw new IllegalArgumentException("Unknown configuration value for parameter 'upgrade.from': " + upgradeFrom);
             }
         }
-        return RebalanceProtocol.EAGER;
+
+        return RebalanceProtocol.COOPERATIVE;
     }
 
     public String logPrefix() {
@@ -181,14 +182,19 @@ public final class AssignorConfiguration {
                         upgradeFrom
                     );
                     return VERSION_TWO;
+                case StreamsConfig.UPGRADE_FROM_20:
+                case StreamsConfig.UPGRADE_FROM_21:
+                case StreamsConfig.UPGRADE_FROM_22:
+                case StreamsConfig.UPGRADE_FROM_23:
+                    // These configs are for cooperative rebalancing and should not affect the metadata version
+                    break;
                 default:
                     throw new IllegalArgumentException(
                         "Unknown configuration value for parameter 'upgrade.from': " + upgradeFrom
                     );
             }
-        } else {
-            return priorVersion;
         }
+        return priorVersion;
     }
 
     public int getNumStandbyReplicas() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index ab213d5..df42b14 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -16,11 +16,13 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.HashMap;
-import java.util.Map;
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.processor.TaskId;
 
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
 public class ClientState {
@@ -31,8 +33,7 @@ public class ClientState {
     private final Set<TaskId> prevStandbyTasks;
     private final Set<TaskId> prevAssignedTasks;
 
-    private final Map<String, Set<TaskId>> prevActiveTasksByConsumer;
-    private final Map<String, Set<TaskId>> prevStandbyTasksByConsumer;
+    private final Map<TopicPartition, String> ownedPartitions;
 
     private int capacity;
 
@@ -48,7 +49,6 @@ public class ClientState {
              new HashSet<>(),
              new HashSet<>(),
              new HashMap<>(),
-             new HashMap<>(),
              capacity);
     }
 
@@ -58,8 +58,7 @@ public class ClientState {
                         final Set<TaskId> prevActiveTasks,
                         final Set<TaskId> prevStandbyTasks,
                         final Set<TaskId> prevAssignedTasks,
-                        final Map<String, Set<TaskId>> prevActiveTasksByConsumer,
-                        final Map<String, Set<TaskId>> prevStandbyTasksByConsumer,
+                        final Map<TopicPartition, String> ownedPartitions,
                         final int capacity) {
         this.activeTasks = activeTasks;
         this.standbyTasks = standbyTasks;
@@ -67,8 +66,7 @@ public class ClientState {
         this.prevActiveTasks = prevActiveTasks;
         this.prevStandbyTasks = prevStandbyTasks;
         this.prevAssignedTasks = prevAssignedTasks;
-        this.prevActiveTasksByConsumer = prevActiveTasksByConsumer;
-        this.prevStandbyTasksByConsumer = prevStandbyTasksByConsumer;
+        this.ownedPartitions = ownedPartitions;
         this.capacity = capacity;
     }
 
@@ -80,8 +78,7 @@ public class ClientState {
             new HashSet<>(prevActiveTasks),
             new HashSet<>(prevStandbyTasks),
             new HashSet<>(prevAssignedTasks),
-            new HashMap<>(prevActiveTasksByConsumer),
-            new HashMap<>(prevStandbyTasksByConsumer),
+            new HashMap<>(ownedPartitions),
             capacity);
     }
 
@@ -111,6 +108,10 @@ public class ClientState {
         return prevStandbyTasks;
     }
 
+    public Map<TopicPartition, String> ownedPartitions() {
+        return ownedPartitions;
+    }
+
     @SuppressWarnings("WeakerAccess")
     public int assignedTaskCount() {
         return assignedTasks.size();
@@ -125,24 +126,25 @@ public class ClientState {
         return activeTasks.size();
     }
 
-    public void addPreviousActiveTasks(final String consumer, final Set<TaskId> prevTasks) {
+    public void addPreviousActiveTasks(final Set<TaskId> prevTasks) {
         prevActiveTasks.addAll(prevTasks);
         prevAssignedTasks.addAll(prevTasks);
-        prevActiveTasksByConsumer.put(consumer, prevTasks);
     }
 
-    public void addPreviousStandbyTasks(final String consumer, final Set<TaskId> standbyTasks) {
+    public void addPreviousStandbyTasks(final Set<TaskId> standbyTasks) {
         prevStandbyTasks.addAll(standbyTasks);
         prevAssignedTasks.addAll(standbyTasks);
-        prevStandbyTasksByConsumer.put(consumer, standbyTasks);
     }
 
-    public Set<TaskId> prevActiveTasksForConsumer(final String consumer) {
-        return prevActiveTasksByConsumer.get(consumer);
+    public void addOwnedPartitions(final Collection<TopicPartition> ownedPartitions, final String consumer) {
+        for (final TopicPartition tp : ownedPartitions) {
+            this.ownedPartitions.put(tp, consumer);
+        }
     }
 
-    public Set<TaskId> prevStandbyTasksForConsumer(final String consumer) {
-        return prevStandbyTasksByConsumer.get(consumer);
+    public void removeFromAssignment(final TaskId task) {
+        activeTasks.remove(task);
+        assignedTasks.remove(task);
     }
 
     @Override
@@ -153,6 +155,7 @@ public class ClientState {
                 ") prevActiveTasks: (" + prevActiveTasks +
                 ") prevStandbyTasks: (" + prevStandbyTasks +
                 ") prevAssignedTasks: (" + prevAssignedTasks +
+                ") prevOwnedPartitionsByConsumerId: (" + ownedPartitions.keySet() +
                 ") capacity: " + capacity +
                 "]";
     }
@@ -182,16 +185,6 @@ public class ClientState {
         }
     }
 
-    Set<TaskId> previousStandbyTasks() {
-        final Set<TaskId> standby = new HashSet<>(prevAssignedTasks);
-        standby.removeAll(prevActiveTasks);
-        return standby;
-    }
-
-    Set<TaskId> previousActiveTasks() {
-        return prevActiveTasks;
-    }
-
     boolean hasAssignedTask(final TaskId taskId) {
         return assignedTasks.contains(taskId);
     }
@@ -212,4 +205,9 @@ public class ClientState {
     boolean hasUnfulfilledQuota(final int tasksPerThread) {
         return activeTasks.size() < capacity * tasksPerThread;
     }
+
+    // the following methods are used for testing only
+    public void assignActiveTasks(final Collection<TaskId> tasks) {
+        activeTasks.addAll(tasks);
+    }
 }
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
index 157497d..d1da8b8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
@@ -228,14 +228,12 @@ public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
 
     private void mapPreviousTaskAssignment(final Map<ID, ClientState> clients) {
         for (final Map.Entry<ID, ClientState> clientState : clients.entrySet()) {
-            for (final TaskId activeTask : clientState.getValue().previousActiveTasks()) {
+            for (final TaskId activeTask : clientState.getValue().prevActiveTasks()) {
                 previousActiveTaskAssignment.put(activeTask, clientState.getKey());
             }
 
-            for (final TaskId prevAssignedTask : clientState.getValue().previousStandbyTasks()) {
-                if (!previousStandbyTaskAssignment.containsKey(prevAssignedTask)) {
-                    previousStandbyTaskAssignment.put(prevAssignedTask, new HashSet<>());
-                }
+            for (final TaskId prevAssignedTask : clientState.getValue().prevStandbyTasks()) {
+                previousStandbyTaskAssignment.computeIfAbsent(prevAssignedTask, t -> new HashSet<>());
                 previousStandbyTaskAssignment.get(prevAssignedTask).add(clientState.getKey());
             }
         }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
index a8526bf..7ce9712 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
@@ -227,15 +227,9 @@ public class AbstractTaskTest {
             public void commit() {}
 
             @Override
-            public void suspend() {}
-
-            @Override
             public void close(final boolean clean, final boolean isZombie) {}
 
             @Override
-            public void closeSuspended(final boolean clean, final boolean isZombie, final RuntimeException e) {}
-
-            @Override
             public boolean initializeStateStores() {
                 return false;
             }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 88d4d6f..c2b9cdb 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -443,7 +443,6 @@ public class StandbyTaskTest {
             singletonList(makeWindowedConsumerRecord(changelogName, 10, 1, 0L, 60_000L))
         );
 
-        task.suspend();
         task.close(true, false);
 
         final File taskDir = stateDirectory.directoryForTask(taskId);
@@ -817,26 +816,4 @@ public class StandbyTaskTest {
         final double expectedCloseTaskMetric = 1.0;
         verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
     }
-
-    @Test
-    public void shouldRecordTaskClosedMetricOnCloseSuspended() throws IOException {
-        final MetricName metricName = setupCloseTaskMetric();
-        final StandbyTask task = new StandbyTask(
-            taskId,
-            ktablePartitions,
-            ktableTopology,
-            consumer,
-            changelogReader,
-            createConfig(baseDir),
-            streamsMetrics,
-            stateDirectory
-        );
-
-        final boolean clean = true;
-        final boolean isZombie = false;
-        task.closeSuspended(clean, isZombie, new RuntimeException());
-
-        final double expectedCloseTaskMetric = 1.0;
-        verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
-    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
index e997f2e..9ff6e33 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
@@ -16,8 +16,10 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.Arrays;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Node;
@@ -37,6 +39,7 @@ import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.kstream.ValueJoiner;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.ClientState;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
 import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.test.MockClientSupplier;
@@ -66,11 +69,17 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 @SuppressWarnings("unchecked")
 public class StreamsPartitionAssignorTest {
+    private final String c1 = "consumer1";
+    private final String c2 = "consumer2";
+    private final String c3 = "consumer3";
+    private final String c4 = "consumer4";
 
     private final TopicPartition t1p0 = new TopicPartition("topic1", 0);
     private final TopicPartition t1p1 = new TopicPartition("topic1", 1);
@@ -84,6 +93,40 @@ public class StreamsPartitionAssignorTest {
     private final TopicPartition t3p1 = new TopicPartition("topic3", 1);
     private final TopicPartition t3p2 = new TopicPartition("topic3", 2);
     private final TopicPartition t3p3 = new TopicPartition("topic3", 3);
+    private final TopicPartition t4p0 = new TopicPartition("topic4", 0);
+    private final TopicPartition t4p1 = new TopicPartition("topic4", 1);
+    private final TopicPartition t4p2 = new TopicPartition("topic4", 2);
+    private final TopicPartition t4p3 = new TopicPartition("topic4", 3);
+
+    private final TaskId task0_0 = new TaskId(0, 0);
+    private final TaskId task0_1 = new TaskId(0, 1);
+    private final TaskId task0_2 = new TaskId(0, 2);
+    private final TaskId task0_3 = new TaskId(0, 3);
+    private final TaskId task1_0 = new TaskId(1, 0);
+    private final TaskId task1_1 = new TaskId(1, 1);
+    private final TaskId task1_2 = new TaskId(1, 2);
+    private final TaskId task1_3 = new TaskId(1, 3);
+    private final TaskId task2_0 = new TaskId(2, 0);
+    private final TaskId task2_1 = new TaskId(2, 1);
+    private final TaskId task2_2 = new TaskId(2, 2);
+    private final TaskId task2_3 = new TaskId(2, 3);
+
+    private final Map<TaskId, Set<TopicPartition>> partitionsForTask = new HashMap<TaskId, Set<TopicPartition>>() {{
+            put(task0_0, Utils.mkSet(t1p0, t2p0));
+            put(task0_1, Utils.mkSet(t1p1, t2p1));
+            put(task0_2, Utils.mkSet(t1p2, t2p2));
+            put(task0_3, Utils.mkSet(t1p3, t2p3));
+
+            put(task1_0, Utils.mkSet(t3p0));
+            put(task1_1, Utils.mkSet(t3p1));
+            put(task1_2, Utils.mkSet(t3p2));
+            put(task1_3, Utils.mkSet(t3p3));
+
+            put(task2_0, Utils.mkSet(t4p0));
+            put(task2_1, Utils.mkSet(t4p1));
+            put(task2_2, Utils.mkSet(t4p2));
+            put(task2_3, Utils.mkSet(t4p3));
+        }};
 
     private final Set<String> allTopics = Utils.mkSet("topic1", "topic2");
 
@@ -109,10 +152,6 @@ public class StreamsPartitionAssignorTest {
         Collections.emptySet(),
         Collections.emptySet());
 
-    private final TaskId task0 = new TaskId(0, 0);
-    private final TaskId task1 = new TaskId(0, 1);
-    private final TaskId task2 = new TaskId(0, 2);
-    private final TaskId task3 = new TaskId(0, 3);
     private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor();
     private final MockClientSupplier mockClientSupplier = new MockClientSupplier();
     private final InternalTopologyBuilder builder = new InternalTopologyBuilder();
@@ -137,6 +176,11 @@ public class StreamsPartitionAssignorTest {
         partitionAssignor.configure(configurationMap);
     }
 
+    private void configureDefault() {
+        createMockTaskManager();
+        partitionAssignor.configure(configProps());
+    }
+
     private void createMockTaskManager() {
         final StreamsBuilder builder = new StreamsBuilder();
         final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(builder.build());
@@ -153,6 +197,7 @@ public class StreamsPartitionAssignorTest {
         EasyMock.expect(taskManager.adminClient()).andReturn(null).anyTimes();
         EasyMock.expect(taskManager.builder()).andReturn(builder).anyTimes();
         EasyMock.expect(taskManager.previousRunningTaskIds()).andReturn(prevTasks).anyTimes();
+        EasyMock.expect(taskManager.activeTaskIds()).andReturn(prevTasks).anyTimes();
         EasyMock.expect(taskManager.cachedTasksIds()).andReturn(cachedTasks).anyTimes();
         EasyMock.expect(taskManager.processId()).andReturn(processId).anyTimes();
     }
@@ -169,6 +214,141 @@ public class StreamsPartitionAssignorTest {
     }
 
     @Test
+    public void shouldUseEagerRebalancingProtocol() {
+        createMockTaskManager();
+        final Map<String, Object> props = configProps();
+        props.put(StreamsConfig.UPGRADE_FROM_CONFIG, StreamsConfig.UPGRADE_FROM_23);
+        partitionAssignor.configure(props);
+
+        assertEquals(1, partitionAssignor.supportedProtocols().size());
+        assertTrue(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.EAGER));
+        assertFalse(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.COOPERATIVE));
+    }
+
+    @Test
+    public void shouldUseCooperativeRebalancingProtocol() {
+        createMockTaskManager();
+        final Map<String, Object> props = configProps();
+        partitionAssignor.configure(props);
+
+        assertEquals(2, partitionAssignor.supportedProtocols().size());
+        assertTrue(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.COOPERATIVE));
+    }
+
+    @Test
+    public void shouldProduceStickyAndBalancedAssignmentWhenNothingChanges() {
+        configureDefault();
+        final ClientState state = new ClientState();
+        final List<TaskId> allTasks = Arrays.asList(task0_0, task0_1, task0_2, task0_3, task1_0, task1_1, task1_2, task1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = new HashMap<String, List<TaskId>>() {{
+                put(c1, Arrays.asList(task0_0, task1_1, task1_3));
+                put(c2, Arrays.asList(task0_3, task1_0));
+                put(c3, Arrays.asList(task0_1, task0_2, task1_2));
+            }};
+
+        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
+            for (final TaskId task : entry.getValue()) {
+                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
+            }
+        }
+
+        final Set<String> consumers = Utils.mkSet(c1, c2, c3);
+        state.assignActiveTasks(allTasks);
+
+        assertEquivalentAssignment(previousAssignment,
+            partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, Collections.emptySet()));
+    }
+
+    @Test
+    public void shouldProduceStickyAndBalancedAssignmentWhenNewTasksAreAdded() {
+        configureDefault();
+        final ClientState state = new ClientState();
+
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2, task0_3, task1_0, task1_1, task1_2, task1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = new HashMap<String, List<TaskId>>() {{
+                put(c1, new ArrayList<>(Arrays.asList(task0_0, task1_1, task1_3)));
+                put(c2, new ArrayList<>(Arrays.asList(task0_3, task1_0)));
+                put(c3, new ArrayList<>(Arrays.asList(task0_1, task0_2, task1_2)));
+            }};
+
+        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
+            for (final TaskId task : entry.getValue()) {
+                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
+            }
+        }
+
+        final Set<String> consumers = Utils.mkSet(c1, c2, c3);
+
+        // We should be able to add a new task without sacrificing stickiness
+        final TaskId newTask = task2_0;
+        allTasks.add(newTask);
+        state.assignActiveTasks(allTasks);
+
+        final Map<String, List<TaskId>> newAssignment = partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, Collections.emptySet());
+
+        previousAssignment.get(c2).add(newTask);
+        assertEquivalentAssignment(previousAssignment, newAssignment);
+    }
+
+    @Test
+    public void shouldReturnEmptyMapWhenStickyAndBalancedAssignmentIsNotPossibleBecauseNewConsumerJoined() {
+        configureDefault();
+        final ClientState state = new ClientState();
+
+        final List<TaskId> allTasks = Arrays.asList(task0_0, task0_1, task0_2, task0_3, task1_0, task1_1, task1_2, task1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = new HashMap<String, List<TaskId>>() {{
+                put(c1, Arrays.asList(task0_0, task1_1, task1_3));
+                put(c2, Arrays.asList(task0_3, task1_0));
+                put(c3, Arrays.asList(task0_1, task0_2, task1_2));
+            }};
+
+        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
+            for (final TaskId task : entry.getValue()) {
+                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
+            }
+        }
+
+        // If we add a new consumer here, we cannot produce an assignment that is both sticky and balanced
+        final Set<String> consumers = Utils.mkSet(c1, c2, c3, c4);
+        state.assignActiveTasks(allTasks);
+
+        assertThat(partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, Collections.emptySet()),
+            equalTo(Collections.emptyMap()));
+    }
+
+    @Test
+    public void shouldReturnEmptyMapWhenStickyAndBalancedAssignmentIsNotPossibleBecauseOtherClientOwnedPartition() {
+        configureDefault();
+        final ClientState state = new ClientState();
+
+        final List<TaskId> allTasks = Arrays.asList(task0_0, task0_1, task0_2, task0_3, task1_0, task1_1, task1_2, task1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = new HashMap<String, List<TaskId>>() {{
+                put(c1, new ArrayList<>(Arrays.asList(task1_1, task1_3)));
+                put(c2, new ArrayList<>(Arrays.asList(task0_3, task1_0)));
+                put(c3, new ArrayList<>(Arrays.asList(task0_1, task0_2, task1_2)));
+            }};
+
+        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
+            for (final TaskId task : entry.getValue()) {
+                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
+            }
+        }
+
+        // Add the partitions of task0_0 to allOwnedPartitions but not c1's ownedPartitions/previousAssignment
+        final Set<TopicPartition> allOwnedPartitions = new HashSet<>(partitionsForTask.get(task0_0));
+
+        final Set<String> consumers = Utils.mkSet(c1, c2, c3);
+        state.assignActiveTasks(allTasks);
+
+        assertThat(partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, allOwnedPartitions),
+            equalTo(Collections.emptyMap()));
+    }
+
+    @Test
     public void shouldInterleaveTasksByGroupId() {
         final TaskId taskIdA0 = new TaskId(0, 0);
         final TaskId taskIdA1 = new TaskId(0, 1);
@@ -182,21 +362,31 @@ public class StreamsPartitionAssignorTest {
         final TaskId taskIdC0 = new TaskId(2, 0);
         final TaskId taskIdC1 = new TaskId(2, 1);
 
+        final String c1 = "c1";
+        final String c2 = "c2";
+        final String c3 = "c3";
+
+        final Set<String> consumers = Utils.mkSet(c1, c2, c3);
+
         final List<TaskId> expectedSubList1 = asList(taskIdA0, taskIdA3, taskIdB2);
         final List<TaskId> expectedSubList2 = asList(taskIdA1, taskIdB0, taskIdC0);
         final List<TaskId> expectedSubList3 = asList(taskIdA2, taskIdB1, taskIdC1);
-        final List<List<TaskId>> embeddedList = asList(expectedSubList1, expectedSubList2, expectedSubList3);
+
+        final Map<String, List<TaskId>> assignment = new HashMap<>();
+        assignment.put(c1, expectedSubList1);
+        assignment.put(c2, expectedSubList2);
+        assignment.put(c3, expectedSubList3);
 
         final List<TaskId> tasks = asList(taskIdC0, taskIdC1, taskIdB0, taskIdB1, taskIdB2, taskIdA0, taskIdA1, taskIdA2, taskIdA3);
         Collections.shuffle(tasks);
 
-        final List<List<TaskId>> interleavedTaskIds = StreamsPartitionAssignor.interleaveTasksByGroupId(tasks, 3);
+        final Map<String, List<TaskId>> interleavedTaskIds = StreamsPartitionAssignor.interleaveConsumerTasksByGroupId(tasks, consumers);
 
-        assertThat(interleavedTaskIds, equalTo(embeddedList));
+        assertThat(interleavedTaskIds, equalTo(assignment));
     }
 
     @Test
-    public void testSubscription() {
+    public void testEagerSubscription() {
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addSource(null, "source2", null, null, null, "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
@@ -212,6 +402,7 @@ public class StreamsPartitionAssignorTest {
         EasyMock.replay(taskManager);
 
         configurePartitionAssignor(Collections.emptyMap());
+        partitionAssignor.setRebalanceProtocol(RebalanceProtocol.EAGER);
 
         final Set<String> topics = Utils.mkSet("topic1", "topic2");
         final ConsumerPartitionAssignor.Subscription subscription = new ConsumerPartitionAssignor.Subscription(new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics));
@@ -222,24 +413,60 @@ public class StreamsPartitionAssignorTest {
         final Set<TaskId> standbyTasks = new HashSet<>(cachedTasks);
         standbyTasks.removeAll(prevTasks);
 
+        // When following the eager protocol, we must encode the previous tasks ourselves since we must revoke
+        // everything and thus the "ownedPartitions" field in the subscription will be empty
         final SubscriptionInfo info = new SubscriptionInfo(processId, prevTasks, standbyTasks, null);
         assertEquals(info.encode(), subscription.userData());
     }
 
     @Test
+    public void testCooperativeSubscription() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addSource(null, "source2", null, null, null, "topic2");
+        builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
+
+        final Set<TaskId> prevTasks = Utils.mkSet(
+            new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1));
+        final Set<TaskId> cachedTasks = Utils.mkSet(
+            new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1),
+            new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2));
+
+        final UUID processId = UUID.randomUUID();
+        createMockTaskManager(prevTasks, cachedTasks, processId, builder);
+        EasyMock.replay(taskManager);
+
+        configurePartitionAssignor(Collections.emptyMap());
+
+        final Set<String> topics = Utils.mkSet("topic1", "topic2");
+        final ConsumerPartitionAssignor.Subscription subscription = new ConsumerPartitionAssignor.Subscription(
+            new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics));
+
+        Collections.sort(subscription.topics());
+        assertEquals(asList("topic1", "topic2"), subscription.topics());
+
+        final Set<TaskId> standbyTasks = new HashSet<>(cachedTasks);
+        standbyTasks.removeAll(prevTasks);
+
+        // We don't encode the active tasks when following the cooperative protocol, as these are inferred from the
+        // ownedPartitions encoded in the subscription
+        final SubscriptionInfo info = new SubscriptionInfo(processId, Collections.emptySet(), standbyTasks, null);
+        assertEquals(info.encode(), subscription.userData());
+    }
+
+    @Test
     public void testAssignBasic() {
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addSource(null, "source2", null, null, null, "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
         final List<String> topics = asList("topic1", "topic2");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
-        final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
-        final Set<TaskId> prevTasks11 = Utils.mkSet(task1);
-        final Set<TaskId> prevTasks20 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks10 = Utils.mkSet(task1);
-        final Set<TaskId> standbyTasks11 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks20 = Utils.mkSet(task0);
+        final Set<TaskId> prevTasks10 = Utils.mkSet(task0_0);
+        final Set<TaskId> prevTasks11 = Utils.mkSet(task0_1);
+        final Set<TaskId> prevTasks20 = Utils.mkSet(task0_2);
+        final Set<TaskId> standbyTasks10 = Utils.mkSet(task0_1);
+        final Set<TaskId> standbyTasks11 = Utils.mkSet(task0_2);
+        final Set<TaskId> standbyTasks20 = Utils.mkSet(task0_0);
 
         final UUID uuid1 = UUID.randomUUID();
         final UUID uuid2 = UUID.randomUUID();
@@ -251,14 +478,20 @@ public class StreamsPartitionAssignorTest {
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
 
         subscriptions.put("consumer10",
-                new ConsumerPartitionAssignor.Subscription(topics,
-                        new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10, userEndPoint).encode()));
+                new ConsumerPartitionAssignor.Subscription(
+                    topics,
+                    new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10, userEndPoint).encode(),
+                    Collections.singletonList(t1p0)));
         subscriptions.put("consumer11",
-                new ConsumerPartitionAssignor.Subscription(topics,
-                        new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11, userEndPoint).encode()));
+                new ConsumerPartitionAssignor.Subscription(
+                    topics,
+                    new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11, userEndPoint).encode(),
+                    Collections.singletonList(t1p1)));
         subscriptions.put("consumer20",
-                new ConsumerPartitionAssignor.Subscription(topics,
-                        new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20, userEndPoint).encode()));
+                new ConsumerPartitionAssignor.Subscription(
+                    topics,
+                    new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20, userEndPoint).encode(),
+                    Collections.singletonList(t1p2)));
 
         final Map<String, ConsumerPartitionAssignor.Assignment> assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
 
@@ -277,7 +510,7 @@ public class StreamsPartitionAssignorTest {
         final AssignmentInfo info11 = checkAssignment(allTopics, assignments.get("consumer11"));
         allActiveTasks.addAll(info11.activeTasks());
 
-        assertEquals(Utils.mkSet(task0, task1), allActiveTasks);
+        assertEquals(Utils.mkSet(task0_0, task0_1), allActiveTasks);
 
         // the third consumer
         final AssignmentInfo info20 = checkAssignment(allTopics, assignments.get("consumer20"));
@@ -351,12 +584,12 @@ public class StreamsPartitionAssignorTest {
         // the first consumer
         final AssignmentInfo info10 = AssignmentInfo.decode(assignments.get("consumer10").userData());
 
-        final List<TaskId> expectedInfo10TaskIds = asList(taskIdA1, taskIdA3, taskIdB1, taskIdB3);
+        final List<TaskId> expectedInfo10TaskIds = asList(taskIdA0, taskIdA2, taskIdB0, taskIdB2);
         assertEquals(expectedInfo10TaskIds, info10.activeTasks());
 
         // the second consumer
         final AssignmentInfo info11 = AssignmentInfo.decode(assignments.get("consumer11").userData());
-        final List<TaskId> expectedInfo11TaskIds = asList(taskIdA0, taskIdA2, taskIdB0, taskIdB2);
+        final List<TaskId> expectedInfo11TaskIds = asList(taskIdA1, taskIdA3, taskIdB1, taskIdB3);
 
         assertEquals(expectedInfo11TaskIds, info11.activeTasks());
     }
@@ -371,7 +604,7 @@ public class StreamsPartitionAssignorTest {
         builder.addProcessor("processor2", new MockProcessorSupplier(), "source2");
         builder.addStateStore(new MockKeyValueStoreBuilder("store2", false), "processor2");
         final List<String> topics = asList("topic1", "topic2");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
         final UUID uuid1 = UUID.randomUUID();
 
@@ -403,10 +636,10 @@ public class StreamsPartitionAssignorTest {
         builder.addSource(null, "source2", null, null, null, "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
         final List<String> topics = asList("topic1", "topic2");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
-        final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
-        final Set<TaskId> standbyTasks10 = Utils.mkSet(task1);
+        final Set<TaskId> prevTasks10 = Utils.mkSet(task0_0);
+        final Set<TaskId> standbyTasks10 = Utils.mkSet(task0_1);
         final  Cluster emptyMetadata = new Cluster("cluster", Collections.singletonList(Node.noNode()),
             Collections.emptySet(),
             Collections.emptySet(),
@@ -459,12 +692,12 @@ public class StreamsPartitionAssignorTest {
         builder.addSource(null, "source3", null, null, null, "topic3");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2", "source3");
         final List<String> topics = asList("topic1", "topic2", "topic3");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2, task3);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2, task0_3);
 
         // assuming that previous tasks do not have topic3
-        final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
-        final Set<TaskId> prevTasks11 = Utils.mkSet(task1);
-        final Set<TaskId> prevTasks20 = Utils.mkSet(task2);
+        final Set<TaskId> prevTasks10 = Utils.mkSet(task0_0);
+        final Set<TaskId> prevTasks11 = Utils.mkSet(task0_1);
+        final Set<TaskId> prevTasks20 = Utils.mkSet(task0_2);
 
         final UUID uuid1 = UUID.randomUUID();
         final UUID uuid2 = UUID.randomUUID();
@@ -606,15 +839,15 @@ public class StreamsPartitionAssignorTest {
         builder.addSource(null, "source2", null, null, null, "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
         final List<String> topics = asList("topic1", "topic2");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
 
-        final Set<TaskId> prevTasks00 = Utils.mkSet(task0);
-        final Set<TaskId> prevTasks01 = Utils.mkSet(task1);
-        final Set<TaskId> prevTasks02 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks01 = Utils.mkSet(task1);
-        final Set<TaskId> standbyTasks02 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks00 = Utils.mkSet(task0);
+        final Set<TaskId> prevTasks00 = Utils.mkSet(task0_0);
+        final Set<TaskId> prevTasks01 = Utils.mkSet(task0_1);
+        final Set<TaskId> prevTasks02 = Utils.mkSet(task0_2);
+        final Set<TaskId> standbyTasks01 = Utils.mkSet(task0_1);
+        final Set<TaskId> standbyTasks02 = Utils.mkSet(task0_2);
+        final Set<TaskId> standbyTasks00 = Utils.mkSet(task0_0);
 
         final UUID uuid1 = UUID.randomUUID();
         final UUID uuid2 = UUID.randomUUID();
@@ -651,8 +884,8 @@ public class StreamsPartitionAssignorTest {
         assertNotEquals("same processId has same set of standby tasks", info11.standbyTasks().keySet(), info10.standbyTasks().keySet());
 
         // check active tasks assigned to the first client
-        assertEquals(Utils.mkSet(task0, task1), new HashSet<>(allActiveTasks));
-        assertEquals(Utils.mkSet(task2), new HashSet<>(allStandbyTasks));
+        assertEquals(Utils.mkSet(task0_0, task0_1), new HashSet<>(allActiveTasks));
+        assertEquals(Utils.mkSet(task0_2), new HashSet<>(allStandbyTasks));
 
         // the third consumer
         final AssignmentInfo info20 = checkAssignment(allTopics, assignments.get("consumer20"));
@@ -678,11 +911,11 @@ public class StreamsPartitionAssignorTest {
         EasyMock.expectLastCall();
 
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(task0, Utils.mkSet(t3p0));
-        activeTasks.put(task3, Utils.mkSet(t3p3));
+        activeTasks.put(task0_0, Utils.mkSet(t3p0));
+        activeTasks.put(task0_3, Utils.mkSet(t3p3));
         final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        standbyTasks.put(task1, Utils.mkSet(t3p1));
-        standbyTasks.put(task2, Utils.mkSet(t3p2));
+        standbyTasks.put(task0_1, Utils.mkSet(t3p1));
+        standbyTasks.put(task0_2, Utils.mkSet(t3p2));
         taskManager.setAssignmentMetadata(activeTasks, standbyTasks);
         EasyMock.expectLastCall();
 
@@ -693,7 +926,7 @@ public class StreamsPartitionAssignorTest {
         EasyMock.replay(taskManager);
 
         configurePartitionAssignor(Collections.emptyMap());
-        final List<TaskId> activeTaskList = asList(task0, task3);
+        final List<TaskId> activeTaskList = asList(task0_0, task0_3);
         final AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks, hostState);
         final ConsumerPartitionAssignor.Assignment assignment = new ConsumerPartitionAssignor.Assignment(asList(t3p0, t3p3), info.encode());
 
@@ -715,7 +948,7 @@ public class StreamsPartitionAssignorTest {
         builder.addSource(null, "source2", null, null, null, "topicX");
         builder.addProcessor("processor2", new MockProcessorSupplier(), "source2");
         final List<String> topics = asList("topic1", applicationId + "-topicX");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
         final UUID uuid1 = UUID.randomUUID();
         createMockTaskManager(emptyTasks, emptyTasks, uuid1, builder);
@@ -749,7 +982,7 @@ public class StreamsPartitionAssignorTest {
         builder.addSink("sink2", "topicZ", null, null, null, "processor2");
         builder.addSource(null, "source3", null, null, null, "topicZ");
         final List<String> topics = asList("topic1", "test-topicX", "test-topicZ");
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
 
         final UUID uuid1 = UUID.randomUUID();
         createMockTaskManager(emptyTasks, emptyTasks, uuid1, builder);
@@ -1197,36 +1430,95 @@ public class StreamsPartitionAssignorTest {
     }
 
     @Test
-    public void shouldReturnUnchangedAssignmentForOldInstancesAndEmptyAssignmentForFutureInstances() {
+    public void shouldReturnInterleavedAssignmentWithUnrevokedPartitionsRemovedWhenNewConsumerJoins() {
         builder.addSource(null, "source1", null, null, null, "topic1");
 
-        final Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
+
+        subscriptions.put(c1,
+            new ConsumerPartitionAssignor.Subscription(
+                Collections.singletonList("topic1"),
+                new SubscriptionInfo(UUID.randomUUID(), allTasks, Collections.emptySet(), null).encode(),
+                Arrays.asList(t1p0, t1p1, t1p2))
+        );
+        subscriptions.put(c2,
+            new ConsumerPartitionAssignor.Subscription(
+                Collections.singletonList("topic1"),
+                new SubscriptionInfo(UUID.randomUUID(), Collections.emptySet(), Collections.emptySet(), null).encode(),
+                Collections.emptyList())
+        );
 
-        final Set<TaskId> activeTasks = Utils.mkSet(task0, task1);
-        final Set<TaskId> standbyTasks = Utils.mkSet(task2);
+        createMockTaskManager(allTasks, allTasks, UUID.randomUUID(), builder);
+        EasyMock.replay(taskManager);
+        partitionAssignor.configure(configProps());
+
+        final Map<String, ConsumerPartitionAssignor.Assignment> assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        assertThat(assignment.size(), equalTo(2));
+
+        assertThat(assignment.get(c1).partitions(), equalTo(asList(t1p0, t1p2)));
+        assertThat(
+            AssignmentInfo.decode(assignment.get(c1).userData()),
+            equalTo(new AssignmentInfo(
+                Arrays.asList(task0_0, task0_2),
+                Collections.emptyMap(),
+                Collections.emptyMap()
+            )));
+
+        // The new consumer's assignment should be empty until c1 has the chance to revoke its partitions/tasks
+        assertThat(assignment.get(c2).partitions(), equalTo(Collections.emptyList()));
+        assertThat(
+            AssignmentInfo.decode(assignment.get(c2).userData()),
+            equalTo(new AssignmentInfo(
+                Collections.emptyList(),
+                Collections.emptyMap(),
+                Collections.emptyMap()
+            )));
+    }
+
+    @Test
+    public void shouldReturnNormalAssignmentForOldAndFutureInstancesDuringVersionProbing() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
+
+        final Set<TaskId> activeTasks = Utils.mkSet(task0_0, task0_1);
+        final Set<TaskId> standbyTasks = Utils.mkSet(task0_2);
         final Map<TaskId, Set<TopicPartition>> standbyTaskMap = new HashMap<TaskId, Set<TopicPartition>>() {
             {
-                put(task2, Collections.singleton(t1p2));
+                put(task0_2, Collections.singleton(t1p2));
+            }
+        };
+        final Map<TaskId, Set<TopicPartition>> futureStandbyTaskMap = new HashMap<TaskId, Set<TopicPartition>>() {
+            {
+                put(task0_0, Collections.singleton(t1p0));
+                put(task0_1, Collections.singleton(t1p1));
             }
         };
 
         subscriptions.put("consumer1",
                 new ConsumerPartitionAssignor.Subscription(
                         Collections.singletonList("topic1"),
-                        new SubscriptionInfo(UUID.randomUUID(), activeTasks, standbyTasks, null).encode())
+                        new SubscriptionInfo(UUID.randomUUID(), activeTasks, standbyTasks, null).encode(),
+                        Arrays.asList(t1p0, t1p1))
         );
         subscriptions.put("future-consumer",
                 new ConsumerPartitionAssignor.Subscription(
                         Collections.singletonList("topic1"),
-                        encodeFutureSubscription())
+                        encodeFutureSubscription(),
+                        Collections.singletonList(t1p2))
         );
 
         createMockTaskManager(allTasks, allTasks, UUID.randomUUID(), builder);
         EasyMock.replay(taskManager);
-        partitionAssignor.configure(configProps());
+        final Map<String, Object> props = configProps();
+        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        partitionAssignor.configure(props);
         final Map<String, ConsumerPartitionAssignor.Assignment> assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
 
         assertThat(assignment.size(), equalTo(2));
+
+        assertThat(assignment.get("consumer1").partitions(), equalTo(asList(t1p0, t1p1)));
         assertThat(
             AssignmentInfo.decode(assignment.get("consumer1").userData()),
             equalTo(new AssignmentInfo(
@@ -1234,10 +1526,64 @@ public class StreamsPartitionAssignorTest {
                 standbyTaskMap,
                 Collections.emptyMap()
             )));
-        assertThat(assignment.get("consumer1").partitions(), equalTo(asList(t1p0, t1p1)));
 
-        assertThat(AssignmentInfo.decode(assignment.get("future-consumer").userData()), equalTo(new AssignmentInfo(LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION)));
-        assertThat(assignment.get("future-consumer").partitions().size(), equalTo(0));
+
+        assertThat(assignment.get("future-consumer").partitions(), equalTo(Collections.singletonList(t1p2)));
+        assertThat(
+            AssignmentInfo.decode(assignment.get("future-consumer").userData()),
+            equalTo(new AssignmentInfo(
+                Collections.singletonList(task0_2),
+                futureStandbyTaskMap,
+                Collections.emptyMap()
+            )));
+    }
+
+    @Test
+    public void shouldReturnInterleavedAssignmentForOnlyFutureInstancesDuringVersionProbing() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+
+        final Set<TaskId> allTasks = Utils.mkSet(task0_0, task0_1, task0_2);
+
+        subscriptions.put(c1,
+            new ConsumerPartitionAssignor.Subscription(
+                Collections.singletonList("topic1"),
+                encodeFutureSubscription(),
+                Collections.emptyList())
+        );
+        subscriptions.put(c2,
+            new ConsumerPartitionAssignor.Subscription(
+                Collections.singletonList("topic1"),
+                encodeFutureSubscription(),
+                Collections.emptyList())
+        );
+
+        createMockTaskManager(allTasks, allTasks, UUID.randomUUID(), builder);
+        EasyMock.replay(taskManager);
+        final Map<String, Object> props = configProps();
+        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        partitionAssignor.configure(props);
+        final Map<String, ConsumerPartitionAssignor.Assignment> assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        assertThat(assignment.size(), equalTo(2));
+
+        assertThat(assignment.get(c1).partitions(), equalTo(asList(t1p0, t1p2)));
+        assertThat(
+            AssignmentInfo.decode(assignment.get(c1).userData()),
+            equalTo(new AssignmentInfo(
+                Arrays.asList(task0_0, task0_2),
+                Collections.emptyMap(),
+                Collections.emptyMap()
+            )));
+
+
+        assertThat(assignment.get(c2).partitions(), equalTo(Collections.singletonList(t1p1)));
+        assertThat(
+            AssignmentInfo.decode(assignment.get(c2).userData()),
+            equalTo(new AssignmentInfo(
+                Collections.singletonList(task0_1),
+                Collections.emptyMap(),
+                Collections.emptyMap()
+            )));
     }
 
     @Test
@@ -1334,4 +1680,21 @@ public class StreamsPartitionAssignorTest {
 
         return info;
     }
+
+    private void assertEquivalentAssignment(final Map<String, List<TaskId>> thisAssignment,
+                                            final Map<String, List<TaskId>> otherAssignment) {
+        assertEquals(thisAssignment.size(), otherAssignment.size());
+        for (final Map.Entry<String, List<TaskId>> entry : thisAssignment.entrySet()) {
+            final String consumer = entry.getKey();
+            assertTrue(otherAssignment.containsKey(consumer));
+
+            final List<TaskId> thisTaskList = entry.getValue();
+            Collections.sort(thisTaskList);
+            final List<TaskId> otherTaskList = otherAssignment.get(consumer);
+            Collections.sort(otherTaskList);
+
+            assertThat(thisTaskList, equalTo(otherTaskList));
+        }
+    }
+
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 7e1ca7f..e46a4cc 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -42,7 +42,6 @@ import org.junit.runner.RunWith;
 
 import java.io.File;
 import java.io.IOException;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -245,7 +244,8 @@ public class TaskManagerTest {
     @Test
     public void shouldCloseActiveUnAssignedSuspendedTasksWhenClosingRevokedTasks() {
         mockSingleActiveTask();
-        EasyMock.expect(active.closeNotAssignedSuspendedTasks(taskId0Assignment.keySet())).andReturn(null).once();
+
+        expect(active.closeNotAssignedSuspendedTasks(taskId0Assignment.keySet())).andReturn(null).once();
         expect(restoreConsumer.assignment()).andReturn(Collections.emptySet());
 
         replay();
@@ -281,6 +281,7 @@ public class TaskManagerTest {
         // Need to call this twice so task manager doesn't consider all partitions "new"
         taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+
         taskManager.setPartitionsToTaskId(taskId0PartitionToTaskId);
         taskManager.createTasks(taskId0Partitions);
 
@@ -404,9 +405,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldInitializeNewActiveTasks() {
-        EasyMock.expect(restoreConsumer.assignment()).andReturn(Collections.emptySet()).once();
-        EasyMock.expect(changeLogReader.restore(active)).andReturn(taskId0Partitions).once();
-        active.updateRestored(EasyMock.<Collection<TopicPartition>>anyObject());
+        active.initializeNewTasks();
         expectLastCall();
         replay();
 
@@ -416,9 +415,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldInitializeNewStandbyTasks() {
-        EasyMock.expect(restoreConsumer.assignment()).andReturn(Collections.emptySet()).once();
-        EasyMock.expect(changeLogReader.restore(active)).andReturn(taskId0Partitions).once();
-        active.updateRestored(EasyMock.<Collection<TopicPartition>>anyObject());
+        standby.initializeNewTasks();
         expectLastCall();
         replay();
 
@@ -428,6 +425,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldRestoreStateFromChangeLogReader() {
+        EasyMock.expect(active.hasRestoringTasks()).andReturn(true).once();
         EasyMock.expect(restoreConsumer.assignment()).andReturn(taskId0Partitions).once();
         expect(changeLogReader.restore(active)).andReturn(taskId0Partitions);
         active.updateRestored(taskId0Partitions);
@@ -440,11 +438,9 @@ public class TaskManagerTest {
 
     @Test
     public void shouldResumeRestoredPartitions() {
-        EasyMock.expect(restoreConsumer.assignment()).andReturn(taskId0Partitions).once();
-        expect(changeLogReader.restore(active)).andReturn(taskId0Partitions);
-        expect(active.allTasksRunning()).andReturn(true);
+        expect(active.allTasksRunning()).andReturn(true).once();
         expect(consumer.assignment()).andReturn(taskId0Partitions);
-        expect(standby.running()).andReturn(Collections.<StandbyTask>emptySet());
+        expect(standby.running()).andReturn(Collections.emptySet());
 
         consumer.resume(taskId0Partitions);
         expectLastCall();
@@ -666,6 +662,7 @@ public class TaskManagerTest {
     }
 
     private void mockAssignStandbyPartitions(final long offset) {
+        expect(active.hasRestoringTasks()).andReturn(true).once();
         final StandbyTask task = EasyMock.createNiceMock(StandbyTask.class);
         expect(active.allTasksRunning()).andReturn(true);
         expect(standby.running()).andReturn(Collections.singletonList(task));
@@ -679,13 +676,6 @@ public class TaskManagerTest {
         EasyMock.expect(changeLogReader.restore(active)).andReturn(taskId0Partitions).once();
     }
 
-    private void mockStandbyTaskExpectations() {
-        expect(standbyTaskCreator.createTasks(EasyMock.<Consumer<byte[], byte[]>>anyObject(),
-                                                   EasyMock.eq(taskId0Assignment)))
-                .andReturn(Collections.singletonList(standbyTask));
-
-    }
-
     private void mockSingleActiveTask() {
         expect(activeTaskCreator.createTasks(EasyMock.<Consumer<byte[], byte[]>>anyObject(),
                                                   EasyMock.eq(taskId0Assignment)))
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
index dc54c86..1443edf 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -70,8 +70,8 @@ public class ClientStateTest {
         final TaskId tid1 = new TaskId(0, 1);
         final TaskId tid2 = new TaskId(0, 2);
 
-        client.addPreviousActiveTasks("consumer", Utils.mkSet(tid1, tid2));
-        assertThat(client.previousActiveTasks(), equalTo(Utils.mkSet(tid1, tid2)));
+        client.addPreviousActiveTasks(Utils.mkSet(tid1, tid2));
+        assertThat(client.prevActiveTasks(), equalTo(Utils.mkSet(tid1, tid2)));
         assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(tid1, tid2)));
     }
 
@@ -80,8 +80,8 @@ public class ClientStateTest {
         final TaskId tid1 = new TaskId(0, 1);
         final TaskId tid2 = new TaskId(0, 2);
 
-        client.addPreviousStandbyTasks("consumer", Utils.mkSet(tid1, tid2));
-        assertThat(client.previousActiveTasks().size(), equalTo(0));
+        client.addPreviousStandbyTasks(Utils.mkSet(tid1, tid2));
+        assertThat(client.prevActiveTasks().size(), equalTo(0));
         assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(tid1, tid2)));
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
index 19d7730..17d403f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -207,11 +207,11 @@ public class StickyTaskAssignorTest {
     @Test
     public void shouldAssignTasksToClientWithPreviousStandbyTasks() {
         final ClientState client1 = createClient(p1, 1);
-        client1.addPreviousStandbyTasks("consumer", Utils.mkSet(task02));
+        client1.addPreviousStandbyTasks(Utils.mkSet(task02));
         final ClientState client2 = createClient(p2, 1);
-        client2.addPreviousStandbyTasks("consumer", Utils.mkSet(task01));
+        client2.addPreviousStandbyTasks(Utils.mkSet(task01));
         final ClientState client3 = createClient(p3, 1);
-        client3.addPreviousStandbyTasks("consumer", Utils.mkSet(task00));
+        client3.addPreviousStandbyTasks(Utils.mkSet(task00));
 
         final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
 
@@ -225,9 +225,9 @@ public class StickyTaskAssignorTest {
     @Test
     public void shouldAssignBasedOnCapacityWhenMultipleClientHaveStandbyTasks() {
         final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task00);
-        c1.addPreviousStandbyTasks("consumer", Utils.mkSet(task01));
+        c1.addPreviousStandbyTasks(Utils.mkSet(task01));
         final ClientState c2 = createClientWithPreviousActiveTasks(p2, 2, task02);
-        c2.addPreviousStandbyTasks("consumer", Utils.mkSet(task01));
+        c2.addPreviousStandbyTasks(Utils.mkSet(task01));
 
         final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
 
@@ -455,9 +455,9 @@ public class StickyTaskAssignorTest {
     @Test
     public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousStandbyTasks() {
         final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task01, task02);
-        c1.addPreviousStandbyTasks("consumer", Utils.mkSet(task03, task00));
+        c1.addPreviousStandbyTasks(Utils.mkSet(task03, task00));
         final ClientState c2 = createClientWithPreviousActiveTasks(p2, 1, task03, task00);
-        c2.addPreviousStandbyTasks("consumer", Utils.mkSet(task01, task02));
+        c2.addPreviousStandbyTasks(Utils.mkSet(task01, task02));
 
         createClient(p3, 1);
         createClient(p4, 1);
@@ -577,14 +577,14 @@ public class StickyTaskAssignorTest {
         final TaskId task23 = new TaskId(2, 3);
 
         final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task01, task12, task13);
-        c1.addPreviousStandbyTasks("consumer", Utils.mkSet(task00, task11, task20, task21, task23));
+        c1.addPreviousStandbyTasks(Utils.mkSet(task00, task11, task20, task21, task23));
         final ClientState c2 = createClientWithPreviousActiveTasks(p2, 1, task00, task11, task22);
-        c2.addPreviousStandbyTasks("consumer", Utils.mkSet(task01, task10, task02, task20, task03, task12, task21, task13, task23));
+        c2.addPreviousStandbyTasks(Utils.mkSet(task01, task10, task02, task20, task03, task12, task21, task13, task23));
         final ClientState c3 = createClientWithPreviousActiveTasks(p3, 1, task20, task21, task23);
-        c3.addPreviousStandbyTasks("consumer", Utils.mkSet(task02, task12));
+        c3.addPreviousStandbyTasks(Utils.mkSet(task02, task12));
 
         final ClientState newClient = createClient(p4, 1);
-        newClient.addPreviousStandbyTasks("consumer", Utils.mkSet(task00, task10, task01, task02, task11, task20, task03, task12, task21, task13, task22, task23));
+        newClient.addPreviousStandbyTasks(Utils.mkSet(task00, task10, task01, task02, task11, task20, task03, task12, task21, task13, task22, task23));
 
         final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task10, task01, task02, task11, task20, task03, task12, task21, task13, task22, task23);
         taskAssignor.assign(0);
@@ -607,15 +607,15 @@ public class StickyTaskAssignorTest {
         final TaskId task23 = new TaskId(2, 3);
 
         final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task01, task12, task13);
-        c1.addPreviousStandbyTasks("c1onsumer", Utils.mkSet(task00, task11, task20, task21, task23));
+        c1.addPreviousStandbyTasks(Utils.mkSet(task00, task11, task20, task21, task23));
         final ClientState c2 = createClientWithPreviousActiveTasks(p2, 1, task00, task11, task22);
-        c2.addPreviousStandbyTasks("consumer", Utils.mkSet(task01, task10, task02, task20, task03, task12, task21, task13, task23));
+        c2.addPreviousStandbyTasks(Utils.mkSet(task01, task10, task02, task20, task03, task12, task21, task13, task23));
 
         final ClientState bounce1 = createClient(p3, 1);
-        bounce1.addPreviousStandbyTasks("consumer", Utils.mkSet(task20, task21, task23));
+        bounce1.addPreviousStandbyTasks(Utils.mkSet(task20, task21, task23));
 
         final ClientState bounce2 = createClient(p4, 1);
-        bounce2.addPreviousStandbyTasks("consumer", Utils.mkSet(task02, task03, task10));
+        bounce2.addPreviousStandbyTasks(Utils.mkSet(task02, task03, task10));
 
         final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task10, task01, task02, task11, task20, task03, task12, task21, task13, task22, task23);
         taskAssignor.assign(0);
@@ -658,7 +658,7 @@ public class StickyTaskAssignorTest {
         final TaskId task06 = new TaskId(0, 6);
         final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task00, task01, task02, task06);
         final ClientState c2 = createClient(p2, 1);
-        c2.addPreviousStandbyTasks("consumer", Utils.mkSet(task03, task04, task05));
+        c2.addPreviousStandbyTasks(Utils.mkSet(task03, task04, task05));
         final ClientState newClient = createClient(p3, 1);
 
         final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02, task03, task04, task05, task06);
@@ -705,7 +705,7 @@ public class StickyTaskAssignorTest {
 
     private ClientState createClientWithPreviousActiveTasks(final Integer processId, final int capacity, final TaskId... taskIds) {
         final ClientState clientState = new ClientState(capacity);
-        clientState.addPreviousActiveTasks("consumer", Utils.mkSet(taskIds));
+        clientState.addPreviousActiveTasks(Utils.mkSet(taskIds));
         clients.put(processId, clientState);
         return clientState;
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java
index 98e6e8f..496f89a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java
@@ -569,7 +569,7 @@ public class SmokeTestDriver extends SmokeTestUtil {
                 }
 
                 if (entry.getValue().getLast().value().longValue() != expectedCount) {
-                    resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expected.get(key));
+                    resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount);
                     resultStream.println("\t outputEvents: " + entry.getValue());
                     return false;
                 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java
index 0e07cac..185fa7c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java
@@ -20,6 +20,7 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.PartitionInfo;
@@ -61,6 +62,8 @@ import static org.apache.kafka.streams.processor.internals.assignment.StreamsAss
 
 public class StreamsUpgradeTest {
 
+    private static final RebalanceProtocol REBALANCE_PROTOCOL = RebalanceProtocol.COOPERATIVE;
+
     @SuppressWarnings("unchecked")
     public static void main(final String[] args) throws Exception {
         if (args.length < 1) {
@@ -123,26 +126,26 @@ public class StreamsUpgradeTest {
             // 1. Client UUID (a unique id assigned to an instance of KafkaStreams)
             // 2. Task ids of previously running tasks
             // 3. Task ids of valid local states on the client's state directory.
-
             final TaskManager taskManager = taskManger();
-            final Set<TaskId> previousActiveTasks = taskManager.previousRunningTaskIds();
+
             final Set<TaskId> standbyTasks = taskManager.cachedTasksIds();
-            standbyTasks.removeAll(previousActiveTasks);
-            final FutureSubscriptionInfo data = new FutureSubscriptionInfo(
+            final Set<TaskId> activeTasks = prepareForSubscription(taskManager,
+                                                                   topics,
+                                                                   standbyTasks,
+                                                                   REBALANCE_PROTOCOL);
+            return new FutureSubscriptionInfo(
                 usedSubscriptionMetadataVersion,
                 LATEST_SUPPORTED_VERSION + 1,
                 taskManager.processId(),
-                previousActiveTasks,
+                activeTasks,
                 standbyTasks,
-                userEndPoint());
-
-            taskManager.updateSubscriptionsFromMetadata(topics);
-
-            return data.encode();
+                userEndPoint())
+                .encode();
         }
 
         @Override
-        public void onAssignment(final ConsumerPartitionAssignor.Assignment assignment, final ConsumerGroupMetadata metadata) {
+        public void onAssignment(final ConsumerPartitionAssignor.Assignment assignment,
+                                 final ConsumerGroupMetadata metadata) {
             try {
                 super.onAssignment(assignment, metadata);
                 return;
@@ -193,6 +196,7 @@ public class StreamsUpgradeTest {
             taskManager.setPartitionsToTaskId(partitionsToTaskId);
             taskManager.setAssignmentMetadata(activeTasks, info.standbyTasks());
             taskManager.updateSubscriptionsFromAssignment(partitions);
+            taskManager.setRebalanceInProgress(false);
         }
 
         @Override
diff --git a/tests/kafkatest/tests/streams/streams_eos_test.py b/tests/kafkatest/tests/streams/streams_eos_test.py
index 7e5cc26..428db9b 100644
--- a/tests/kafkatest/tests/streams/streams_eos_test.py
+++ b/tests/kafkatest/tests/streams/streams_eos_test.py
@@ -159,7 +159,7 @@ class StreamsEosTest(KafkaTest):
 
     def wait_for_startup(self, monitor, processor):
         self.wait_for(monitor, processor, "StateChange: REBALANCING -> RUNNING")
-        self.wait_for(monitor, processor, "processed 500 records from topic")
+        self.wait_for(monitor, processor, "processed [0-9]* records from topic")
 
     def wait_for(self, monitor, processor, output):
         monitor.wait_until(output,


Mime
View raw message