kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject kafka git commit: KAFKA-4677: Avoid unnecessary task movement across threads during rebalance
Date Wed, 01 Feb 2017 04:16:55 GMT
Repository: kafka
Updated Branches:
  refs/heads/trunk 82744414d -> 0b48ea1c8


KAFKA-4677: Avoid unnecessary task movement across threads during rebalance

Makes task assignment more sticky by preferring to assign each task to the client that previously had it as an active task. If no client previously had the task active, then search for a client that previously had it as a standby task. Finally, fall back to the least loaded client.

Author: Damian Guy <damian.guy@gmail.com>

Reviewers: Matthias J. Sax, Guozhang Wang

Closes #2429 from dguy/kafka-4677


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/0b48ea1c
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/0b48ea1c
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/0b48ea1c

Branch: refs/heads/trunk
Commit: 0b48ea1c81f22465cf32a19c012e0fb3c849afcc
Parents: 8274441
Author: Damian Guy <damian.guy@gmail.com>
Authored: Tue Jan 31 20:16:47 2017 -0800
Committer: Guozhang Wang <wangguoz@gmail.com>
Committed: Tue Jan 31 20:16:47 2017 -0800

----------------------------------------------------------------------
 .../internals/StreamPartitionAssignor.java      |  34 +-
 .../processor/internals/StreamThread.java       |  12 +-
 .../internals/assignment/ClientState.java       | 116 ++++-
 .../assignment/StickyTaskAssignor.java          | 283 ++++++++++
 .../internals/assignment/TaskAssignor.java      | 208 +-------
 .../kstream/internals/KTableAggregateTest.java  |   5 +
 .../internals/StreamPartitionAssignorTest.java  |  20 +-
 .../processor/internals/StreamThreadTest.java   |   4 +-
 .../internals/assignment/ClientStateTest.java   | 151 ++++++
 .../assignment/StickyTaskAssignorTest.java      | 515 +++++++++++++++++++
 .../internals/assignment/TaskAssignorTest.java  | 312 -----------
 11 files changed, 1085 insertions(+), 575 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
index 1ad6dbc..e17d96b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
@@ -33,8 +33,8 @@ import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
+import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
-import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
 import org.apache.kafka.streams.state.HostInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -105,13 +105,10 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable
         }
 
         void addConsumer(final String consumerMemberId, final SubscriptionInfo info) {
-
             consumers.add(consumerMemberId);
-
-            state.prevActiveTasks.addAll(info.prevTasks);
-            state.prevAssignedTasks.addAll(info.prevTasks);
-            state.prevAssignedTasks.addAll(info.standbyTasks);
-            state.capacity = state.capacity + 1d;
+            state.addPreviousActiveTasks(info.prevTasks);
+            state.addPreviousStandbyTasks(info.standbyTasks);
+            state.incrementCapacity();
         }
 
         @Override
@@ -228,10 +225,10 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable
         // 2. Task ids of previously running tasks
         // 3. Task ids of valid local states on the client's state directory.
 
-        Set<TaskId> prevTasks = streamThread.prevTasks();
+        final Set<TaskId> previousActiveTasks = streamThread.prevActiveTasks();
         Set<TaskId> standbyTasks = streamThread.cachedTasks();
-        standbyTasks.removeAll(prevTasks);
-        SubscriptionInfo data = new SubscriptionInfo(streamThread.processId, prevTasks, standbyTasks, this.userEndPoint);
+        standbyTasks.removeAll(previousActiveTasks);
+        SubscriptionInfo data = new SubscriptionInfo(streamThread.processId, previousActiveTasks, standbyTasks, this.userEndPoint);
 
         if (streamThread.builder.sourceTopicPattern() != null) {
             SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
@@ -461,7 +458,8 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable
         log.debug("stream-thread [{}] Assigning tasks {} to clients {} with number of replicas {}",
                 streamThread.getName(), partitionsForTask.keySet(), states, numStandbyReplicas);
 
-        TaskAssignor.assign(states, partitionsForTask.keySet(), numStandbyReplicas);
+        final StickyTaskAssignor<UUID> taskAssignor = new StickyTaskAssignor<>(states, partitionsForTask.keySet());
+        taskAssignor.assign(numStandbyReplicas);
 
         log.info("stream-thread [{}] Assigned tasks to clients as {}.", streamThread.getName(), states);
 
@@ -476,7 +474,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable
                 final Set<TopicPartition> topicPartitions = new HashSet<>();
                 final ClientState<TaskId> state = entry.getValue().state;
 
-                for (TaskId id : state.activeTasks) {
+                for (final TaskId id : state.activeTasks()) {
                     topicPartitions.addAll(partitionsForTask.get(id));
                 }
 
@@ -487,14 +485,14 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable
         // within the client, distribute tasks to its owned consumers
         Map<String, Assignment> assignment = new HashMap<>();
         for (Map.Entry<UUID, ClientMetadata> entry : clientsMetadata.entrySet()) {
-            Set<String> consumers = entry.getValue().consumers;
-            ClientState<TaskId> state = entry.getValue().state;
+            final Set<String> consumers = entry.getValue().consumers;
+            final ClientState<TaskId> state = entry.getValue().state;
 
-            ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size());
-            final int numActiveTasks = state.activeTasks.size();
+            final ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTaskCount());
+            final int numActiveTasks = state.activeTaskCount();
 
-            taskIds.addAll(state.activeTasks);
-            taskIds.addAll(state.standbyTasks);
+            taskIds.addAll(state.activeTasks());
+            taskIds.addAll(state.standbyTasks());
 
             final int numConsumers = consumers.size();
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 0128142..9bc268f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -191,7 +191,7 @@ public class StreamThread extends Thread {
     private final Map<TaskId, StandbyTask> standbyTasks;
     private final Map<TopicPartition, StreamTask> activeTasksByPartition;
     private final Map<TopicPartition, StandbyTask> standbyTasksByPartition;
-    private final Set<TaskId> prevTasks;
+    private final Set<TaskId> prevActiveTasks;
     private final Map<TaskId, StreamTask> suspendedTasks;
     private final Map<TaskId, StandbyTask> suspendedStandbyTasks;
     private final Time time;
@@ -331,7 +331,7 @@ public class StreamThread extends Thread {
         this.standbyTasks = new HashMap<>();
         this.activeTasksByPartition = new HashMap<>();
         this.standbyTasksByPartition = new HashMap<>();
-        this.prevTasks = new HashSet<>();
+        this.prevActiveTasks = new HashSet<>();
         this.suspendedTasks = new HashMap<>();
         this.suspendedStandbyTasks = new HashMap<>();
 
@@ -790,8 +790,8 @@ public class StreamThread extends Thread {
     /**
      * Returns ids of tasks that were being executed before the rebalance.
      */
-    public Set<TaskId> prevTasks() {
-        return Collections.unmodifiableSet(prevTasks);
+    public Set<TaskId> prevActiveTasks() {
+        return Collections.unmodifiableSet(prevActiveTasks);
     }
 
     /**
@@ -1019,8 +1019,8 @@ public class StreamThread extends Thread {
         log.info("{} Removing all active tasks [{}]", logPrefix, activeTasks.keySet());
 
         try {
-            prevTasks.clear();
-            prevTasks.addAll(activeTasks.keySet());
+            prevActiveTasks.clear();
+            prevActiveTasks.addAll(activeTasks.keySet());
 
             activeTasks.clear();
             activeTasksByPartition.clear();

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index 0746cab..c5577e5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -17,40 +17,34 @@
 
 package org.apache.kafka.streams.processor.internals.assignment;
 
+
 import java.util.HashSet;
 import java.util.Set;
 
 public class ClientState<T> {
+    private final Set<T> activeTasks;
+    private final Set<T> standbyTasks;
+    private final Set<T> assignedTasks;
+    private final Set<T> prevActiveTasks;
+    private final Set<T> prevAssignedTasks;
 
-    final static double COST_ACTIVE = 0.1;
-    final static double COST_STANDBY  = 0.2;
-    final static double COST_LOAD = 0.5;
-
-    public final Set<T> activeTasks;
-    public final Set<T> standbyTasks;
-    public final Set<T> assignedTasks;
-    public final Set<T> prevActiveTasks;
-    public final Set<T> prevAssignedTasks;
-
-    public double capacity;
-    public double cost;
+    private int capacity;
 
     public ClientState() {
-        this(0d);
+        this(0);
     }
 
-    public ClientState(double capacity) {
+    ClientState(final int capacity) {
         this(new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), capacity);
     }
 
-    private ClientState(Set<T> activeTasks, Set<T> standbyTasks, Set<T> assignedTasks, Set<T> prevActiveTasks, Set<T> prevAssignedTasks, double capacity) {
+    private ClientState(Set<T> activeTasks, Set<T> standbyTasks, Set<T> assignedTasks, Set<T> prevActiveTasks, Set<T> prevAssignedTasks, int capacity) {
         this.activeTasks = activeTasks;
         this.standbyTasks = standbyTasks;
         this.assignedTasks = assignedTasks;
         this.prevActiveTasks = prevActiveTasks;
         this.prevAssignedTasks = prevAssignedTasks;
         this.capacity = capacity;
-        this.cost = 0d;
     }
 
     public ClientState<T> copy() {
@@ -58,29 +52,103 @@ public class ClientState<T> {
                 new HashSet<>(prevActiveTasks), new HashSet<>(prevAssignedTasks), capacity);
     }
 
-    public void assign(T taskId, boolean active) {
-        if (active)
+    public void assign(final T taskId, final boolean active) {
+        if (active) {
             activeTasks.add(taskId);
-        else
+        } else {
             standbyTasks.add(taskId);
+        }
 
         assignedTasks.add(taskId);
+    }
+
+    public Set<T> activeTasks() {
+        return activeTasks;
+    }
+
+    public Set<T> standbyTasks() {
+        return standbyTasks;
+    }
+
+    public int assignedTaskCount() {
+        return assignedTasks.size();
+    }
 
-        double cost = COST_LOAD;
-        cost = prevAssignedTasks.remove(taskId) ? COST_STANDBY : cost;
-        cost = prevActiveTasks.remove(taskId) ? COST_ACTIVE : cost;
+    public void incrementCapacity() {
+        capacity++;
+    }
 
-        this.cost += cost;
+    public int activeTaskCount() {
+        return activeTasks.size();
+    }
+
+    public void addPreviousActiveTasks(final Set<T> prevTasks) {
+        prevActiveTasks.addAll(prevTasks);
+        prevAssignedTasks.addAll(prevTasks);
+    }
+
+    public void addPreviousStandbyTasks(final Set<T> standbyTasks) {
+        prevAssignedTasks.addAll(standbyTasks);
     }
 
     @Override
     public String toString() {
         return "[activeTasks: (" + activeTasks +
+            ") standbyTasks: (" + standbyTasks +
             ") assignedTasks: (" + assignedTasks +
             ") prevActiveTasks: (" + prevActiveTasks +
             ") prevAssignedTasks: (" + prevAssignedTasks +
             ") capacity: " + capacity +
-            " cost: " + cost +
             "]";
     }
+
+    boolean reachedCapacity() {
+        return assignedTasks.size() >= capacity;
+    }
+
+    boolean hasMoreAvailableCapacityThan(final ClientState<T> other) {
+        if (this.capacity <= 0) {
+            throw new IllegalStateException("Capacity of this ClientState must be greater than 0.");
+        }
+
+        if (other.capacity <= 0) {
+            throw new IllegalStateException("Capacity of other ClientState must be greater than 0");
+        }
+
+        final double otherLoad = (double) other.assignedTaskCount() / other.capacity;
+        final double thisLoad = (double) assignedTaskCount() / capacity;
+
+        if (thisLoad == otherLoad) {
+            return capacity > other.capacity;
+        }
+
+        return thisLoad < otherLoad;
+    }
+
+    Set<T> previousStandbyTasks() {
+        final Set<T> standby = new HashSet<>(prevAssignedTasks);
+        standby.removeAll(prevActiveTasks);
+        return standby;
+    }
+
+    Set<T> previousActiveTasks() {
+        return prevActiveTasks;
+    }
+
+    boolean hasAssignedTask(final T taskId) {
+        return assignedTasks.contains(taskId);
+    }
+
+    // Visible for testing
+    Set<T> assignedTasks() {
+        return assignedTasks;
+    }
+
+    Set<T> previousAssignedTasks() {
+        return prevAssignedTasks;
+    }
+
+    int capacity() {
+        return capacity;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
new file mode 100644
index 0000000..6d49b72
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
+
+    private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class);
+    private final Map<ID, ClientState<TaskId>> clients;
+    private final Set<TaskId> taskIds;
+    private final Map<TaskId, ID> previousActiveTaskAssignment = new HashMap<>();
+    private final Map<TaskId, Set<ID>> previousStandbyTaskAssignment = new HashMap<>();
+    private final TaskPairs taskPairs;
+    private final int availableCapacity;
+    private final boolean hasNewTasks;
+
+    public StickyTaskAssignor(final Map<ID, ClientState<TaskId>> clients, final Set<TaskId> taskIds) {
+        this.clients = clients;
+        this.taskIds = taskIds;
+        this.availableCapacity = sumCapacity(clients.values());
+        taskPairs = new TaskPairs(taskIds.size() * (taskIds.size() - 1) / 2);
+        mapPreviousTaskAssignment(clients);
+        this.hasNewTasks = !previousActiveTaskAssignment.keySet().containsAll(taskIds);
+    }
+
+    @Override
+    public void assign(final int numStandbyReplicas) {
+        assignActive();
+        assignStandby(numStandbyReplicas);
+    }
+
+    private void assignStandby(final int numStandbyReplicas) {
+        for (final TaskId taskId : taskIds) {
+            for (int i = 0; i < numStandbyReplicas; i++) {
+                final Set<ID> ids = findClientsWithoutAssignedTask(taskId);
+                if (ids.isEmpty()) {
+                    log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
+                                     "There is not enough available capacity. You should " +
+                                     "increase the number of threads and/or application instances " +
+                                     "to maintain the requested number of standby replicas.",
+                             numStandbyReplicas - i,
+                             numStandbyReplicas, taskId);
+                    break;
+                }
+                assign(taskId, ids, false);
+            }
+        }
+    }
+
+    private void assignActive() {
+        final Set<TaskId> previouslyAssignedTaskIds = new HashSet<>(previousActiveTaskAssignment.keySet());
+        previouslyAssignedTaskIds.addAll(previousStandbyTaskAssignment.keySet());
+        previouslyAssignedTaskIds.retainAll(taskIds);
+
+        // assign previously assigned tasks first
+        for (final TaskId taskId : previouslyAssignedTaskIds) {
+            assign(taskId, clients.keySet(), true);
+        }
+
+        final Set<TaskId> newTasks  = new HashSet<>(taskIds);
+        newTasks.removeAll(previouslyAssignedTaskIds);
+
+        for (final TaskId taskId : newTasks) {
+            assign(taskId, clients.keySet(), true);
+        }
+    }
+
+    private void assign(final TaskId taskId, final Set<ID> clientsWithin, final boolean active) {
+        final ClientState<TaskId> client = findClient(taskId, clientsWithin);
+        taskPairs.addPairs(taskId, client.assignedTasks());
+        client.assign(taskId, active);
+    }
+
+    private Set<ID> findClientsWithoutAssignedTask(final TaskId taskId) {
+        final Set<ID> clientIds = new HashSet<>();
+        for (final Map.Entry<ID, ClientState<TaskId>> client : clients.entrySet()) {
+            if (!client.getValue().hasAssignedTask(taskId)) {
+                clientIds.add(client.getKey());
+            }
+        }
+        return clientIds;
+    }
+
+
+    private ClientState<TaskId> findClient(final TaskId taskId,
+                                           final Set<ID> clientsWithin) {
+        // optimize the case where there is only 1 id to search within.
+        if (clientsWithin.size() == 1) {
+            return clients.get(clientsWithin.iterator().next());
+        }
+
+        final ClientState<TaskId> previous = findClientsWithPreviousAssignedTask(taskId, clientsWithin);
+        if (previous == null) {
+            return leastLoaded(taskId, clientsWithin);
+        }
+
+        if (shouldBalanceLoad(previous)) {
+            final ClientState<TaskId> standby = findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin);
+            if (standby == null
+                    || shouldBalanceLoad(standby)) {
+                return leastLoaded(taskId, clientsWithin);
+            }
+            return standby;
+        }
+
+        return previous;
+    }
+
+    private boolean shouldBalanceLoad(final ClientState<TaskId> client) {
+        return !hasNewTasks
+                && client.reachedCapacity()
+                && hasClientsWithMoreAvailableCapacity(client);
+    }
+
+    private boolean hasClientsWithMoreAvailableCapacity(final ClientState<TaskId> client) {
+        for (ClientState<TaskId> clientState : clients.values()) {
+            if (clientState.hasMoreAvailableCapacityThan(client)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private ClientState<TaskId> findClientsWithPreviousAssignedTask(final TaskId taskId,
+                                                                    final Set<ID> clientsWithin) {
+        final ID previous = previousActiveTaskAssignment.get(taskId);
+        if (previous != null && clientsWithin.contains(previous)) {
+            return clients.get(previous);
+        }
+        return findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin);
+    }
+
+    private ClientState<TaskId> findLeastLoadedClientWithPreviousStandByTask(final TaskId taskId, final Set<ID> clientsWithin) {
+        final Set<ID> ids = previousStandbyTaskAssignment.get(taskId);
+        if (ids == null) {
+            return null;
+        }
+        final HashSet<ID> constrainTo = new HashSet<>(ids);
+        constrainTo.retainAll(clientsWithin);
+        return leastLoaded(taskId, constrainTo);
+    }
+
+    private ClientState<TaskId> leastLoaded(final TaskId taskId, final Set<ID> clientIds) {
+        final ClientState<TaskId> leastLoaded = findLeastLoaded(taskId, clientIds, true);
+        if (leastLoaded == null) {
+            return findLeastLoaded(taskId, clientIds, false);
+        }
+        return leastLoaded;
+    }
+
+    private ClientState<TaskId> findLeastLoaded(final TaskId taskId,
+                                                final Set<ID> clientIds,
+                                                boolean checkTaskPairs) {
+        ClientState<TaskId> leastLoaded = null;
+        for (final ID id : clientIds) {
+            final ClientState<TaskId> client = clients.get(id);
+            if (client.assignedTaskCount() == 0) {
+                return client;
+            }
+
+            if (leastLoaded == null || client.hasMoreAvailableCapacityThan(leastLoaded)) {
+                if (!checkTaskPairs) {
+                    leastLoaded = client;
+                } else if (taskPairs.hasNewPair(taskId, client.assignedTasks())) {
+                    leastLoaded = client;
+                }
+            }
+
+        }
+        return leastLoaded;
+
+    }
+
+    private void mapPreviousTaskAssignment(final Map<ID, ClientState<TaskId>> clients) {
+        for (final Map.Entry<ID, ClientState<TaskId>> clientState : clients.entrySet()) {
+            for (final TaskId activeTask : clientState.getValue().previousActiveTasks()) {
+                previousActiveTaskAssignment.put(activeTask, clientState.getKey());
+            }
+
+            for (final TaskId prevAssignedTask : clientState.getValue().previousStandbyTasks()) {
+                if (!previousStandbyTaskAssignment.containsKey(prevAssignedTask)) {
+                    previousStandbyTaskAssignment.put(prevAssignedTask, new HashSet<ID>());
+                }
+                previousStandbyTaskAssignment.get(prevAssignedTask).add(clientState.getKey());
+            }
+        }
+    }
+
+    private int sumCapacity(final Collection<ClientState<TaskId>> values) {
+        int capacity = 0;
+        for (ClientState<TaskId> client : values) {
+            capacity += client.capacity();
+        }
+        return capacity;
+    }
+
+
+    private static class TaskPairs {
+        private final Set<Pair> pairs;
+        private final int maxPairs;
+
+        TaskPairs(final int maxPairs) {
+            this.maxPairs = maxPairs;
+            this.pairs = new HashSet<>(maxPairs);
+        }
+
+        boolean hasNewPair(final TaskId task1, final Set<TaskId> taskIds) {
+            if (pairs.size() == maxPairs) {
+                return false;
+            }
+            for (final TaskId taskId : taskIds) {
+                if (!pairs.contains(pair(task1, taskId))) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        void addPairs(final TaskId taskId, final Set<TaskId> assigned) {
+            for (final TaskId id : assigned) {
+                pairs.add(pair(id, taskId));
+            }
+        }
+
+        Pair pair(final TaskId task1, final TaskId task2) {
+            if (task1.compareTo(task2) < 0) {
+                return new Pair(task1, task2);
+            }
+            return new Pair(task2, task1);
+        }
+
+        class Pair {
+            private final TaskId task1;
+            private final TaskId task2;
+
+            Pair(final TaskId task1, final TaskId task2) {
+                this.task1 = task1;
+                this.task2 = task2;
+            }
+
+            @Override
+            public boolean equals(final Object o) {
+                if (this == o) return true;
+                if (o == null || getClass() != o.getClass()) return false;
+                final Pair pair = (Pair) o;
+                return Objects.equals(task1, pair.task1) &&
+                        Objects.equals(task2, pair.task2);
+            }
+
+            @Override
+            public int hashCode() {
+                return Objects.hash(task1, task2);
+            }
+        }
+
+
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
index e807c4e..b846ae0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -5,215 +5,17 @@
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import org.apache.kafka.streams.errors.TaskAssignmentException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
-
-public class TaskAssignor<C, T extends Comparable<T>> {
-
-    private static final Logger log = LoggerFactory.getLogger(TaskAssignor.class);
-
-    public static <C, T extends Comparable<T>> void assign(Map<C, ClientState<T>> states, Set<T> tasks, int numStandbyReplicas) {
-        long seed = 0L;
-        for (C client : states.keySet()) {
-            seed += client.hashCode();
-        }
-
-        TaskAssignor<C, T> assignor = new TaskAssignor<>(states, tasks, seed);
-
-        // assign active tasks
-        assignor.assignTasks();
-
-        // assign standby tasks
-        if (numStandbyReplicas > 0)
-            assignor.assignStandbyTasks(numStandbyReplicas);
-    }
-
-    private final Random rand;
-    private final Map<C, ClientState<T>> states;
-    private final Set<TaskPair<T>> taskPairs;
-    private final int maxNumTaskPairs;
-    private final ArrayList<T> tasks;
-    private boolean prevAssignmentBalanced = true;
-    private boolean prevClientsUnchanged = true;
-
-    private TaskAssignor(Map<C, ClientState<T>> states, Set<T> tasks, long randomSeed) {
-        this.rand = new Random(randomSeed);
-        this.tasks = new ArrayList<>(tasks);
-        this.states = states;
-
-        int avgNumTasks = tasks.size() / states.size();
-        Set<T> existingTasks = new HashSet<>();
-        for (Map.Entry<C, ClientState<T>> entry : states.entrySet()) {
-            Set<T> oldTasks = entry.getValue().prevAssignedTasks;
-
-            // make sure the previous assignment is balanced
-            prevAssignmentBalanced = prevAssignmentBalanced &&
-                oldTasks.size() < 2 * avgNumTasks && oldTasks.size() > avgNumTasks / 2;
-
-            // make sure there are no duplicates
-            for (T task : oldTasks) {
-                prevClientsUnchanged = prevClientsUnchanged && !existingTasks.contains(task);
-            }
-            existingTasks.addAll(oldTasks);
-        }
-
-        // make sure the existing assignment didn't miss out any task
-        prevClientsUnchanged = prevClientsUnchanged && existingTasks.equals(tasks);
-
-        int numTasks = tasks.size();
-        this.maxNumTaskPairs = numTasks * (numTasks - 1) / 2;
-        this.taskPairs = new HashSet<>(this.maxNumTaskPairs);
-    }
-
-    private void assignTasks() {
-        assignTasks(true);
-    }
-
-    private void assignStandbyTasks(int numStandbyReplicas) {
-        int numReplicas = Math.min(numStandbyReplicas, states.size() - 1);
-        for (int i = 0; i < numReplicas; i++) {
-            assignTasks(false);
-        }
-    }
-
-    private void assignTasks(boolean active) {
-        Collections.shuffle(this.tasks, rand);
-
-        for (T task : tasks) {
-            ClientState<T> state = findClientFor(task);
-
-            if (state != null) {
-                state.assign(task, active);
-            } else {
-                TaskAssignmentException ex = new TaskAssignmentException("failed to find an assignable client");
-                log.error(ex.getMessage(), ex);
-                throw ex;
-            }
-        }
-    }
-
-    private ClientState<T> findClientFor(T task) {
-        boolean checkTaskPairs = taskPairs.size() < maxNumTaskPairs;
-
-        ClientState<T> state = findClientByAdditionCost(task, checkTaskPairs);
-
-        if (state == null && checkTaskPairs)
-            state = findClientByAdditionCost(task, false);
-
-        if (state != null)
-            addTaskPairs(task, state);
-
-        return state;
-    }
-
-    private ClientState<T> findClientByAdditionCost(T task, boolean checkTaskPairs) {
-        ClientState<T> candidate = null;
-        double candidateAdditionCost = 0d;
-
-        for (ClientState<T> state : states.values()) {
-            if (prevAssignmentBalanced && prevClientsUnchanged &&
-                state.prevAssignedTasks.contains(task)) {
-                return state;
-            }
-            if (!state.assignedTasks.contains(task)) {
-                // if checkTaskPairs flag is on, skip this client if this task doesn't introduce a new task combination
-                if (checkTaskPairs && !state.assignedTasks.isEmpty() && !hasNewTaskPair(task, state))
-                    continue;
-
-                double additionCost = computeAdditionCost(task, state);
-                if (candidate == null ||
-                        (additionCost < candidateAdditionCost ||
-                            (additionCost == candidateAdditionCost && state.cost < candidate.cost))) {
-                    candidate = state;
-                    candidateAdditionCost = additionCost;
-                }
-            }
-        }
-
-        return candidate;
-    }
-
-    private void addTaskPairs(T task, ClientState<T> state) {
-        for (T other : state.assignedTasks) {
-            taskPairs.add(pair(task, other));
-        }
-    }
-
-    private boolean hasNewTaskPair(T task, ClientState<T> state) {
-        for (T other : state.assignedTasks) {
-            if (!taskPairs.contains(pair(task, other)))
-                return true;
-        }
-        return false;
-    }
-
-    private double computeAdditionCost(T task, ClientState<T> state) {
-        double cost = Math.floor((double) state.assignedTasks.size() / state.capacity);
-
-        if (state.prevAssignedTasks.contains(task)) {
-            if (state.prevActiveTasks.contains(task)) {
-                cost += ClientState.COST_ACTIVE;
-            } else {
-                cost += ClientState.COST_STANDBY;
-            }
-        } else {
-            cost += ClientState.COST_LOAD;
-        }
-
-        return cost;
-    }
-
-    private TaskPair<T> pair(T task1, T task2) {
-        if (task1.compareTo(task2) < 0) {
-            return new TaskPair<>(task1, task2);
-        } else {
-            return new TaskPair<>(task2, task1);
-        }
-    }
-
-    private static class TaskPair<T> {
-        final T task1;
-        final T task2;
-
-        TaskPair(T task1, T task2) {
-            this.task1 = task1;
-            this.task2 = task2;
-        }
-
-        @Override
-        public int hashCode() {
-            return task1.hashCode() ^ task2.hashCode();
-        }
-
-        @SuppressWarnings("unchecked")
-        @Override
-        public boolean equals(Object o) {
-            if (o instanceof TaskPair) {
-                TaskPair<T> other = (TaskPair<T>) o;
-                return this.task1.equals(other.task1) && this.task2.equals(other.task2);
-            }
-            return false;
-        }
-    }
-
+public interface TaskAssignor<C, T extends Comparable<T>> {
+    void assign(int numStandbyReplicas);
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
index 39baa4e..68700cb 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
@@ -21,6 +21,7 @@ import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.kstream.Aggregator;
 import org.apache.kafka.streams.kstream.ForeachAction;
 import org.apache.kafka.streams.kstream.Initializer;
@@ -38,6 +39,7 @@ import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 
 
@@ -55,6 +57,9 @@ public class KTableAggregateTest {
     private KStreamTestDriver driver = null;
     private File stateDir = null;
 
+    @Rule
+    public EmbeddedKafkaCluster cluster = null;
+
     @After
     public void tearDown() {
         if (driver != null) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
index 6503038..36d652a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
@@ -135,7 +135,7 @@ public class StreamPartitionAssignorTest {
         StreamThread thread = new StreamThread(builder, config, new MockClientSupplier(), "test", clientId, processId, new Metrics(), Time.SYSTEM, new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
                                                0) {
             @Override
-            public Set<TaskId> prevTasks() {
+            public Set<TaskId> prevActiveTasks() {
                 return prevTasks;
             }
             @Override
@@ -482,12 +482,12 @@ public class StreamPartitionAssignorTest {
         Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
 
 
-        final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
-        final Set<TaskId> prevTasks11 = Utils.mkSet(task1);
-        final Set<TaskId> prevTasks20 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks10 = Utils.mkSet(task1);
-        final Set<TaskId> standbyTasks11 = Utils.mkSet(task2);
-        final Set<TaskId> standbyTasks20 = Utils.mkSet(task0);
+        final Set<TaskId> prevTasks00 = Utils.mkSet(task0);
+        final Set<TaskId> prevTasks01 = Utils.mkSet(task1);
+        final Set<TaskId> prevTasks02 = Utils.mkSet(task2);
+        final Set<TaskId> standbyTasks01 = Utils.mkSet(task1);
+        final Set<TaskId> standbyTasks02 = Utils.mkSet(task2);
+        final Set<TaskId> standbyTasks00 = Utils.mkSet(task0);
 
         UUID uuid1 = UUID.randomUUID();
         UUID uuid2 = UUID.randomUUID();
@@ -501,11 +501,11 @@ public class StreamPartitionAssignorTest {
 
         Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
         subscriptions.put("consumer10",
-                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10, userEndPoint).encode()));
+                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid1, prevTasks00, standbyTasks01, userEndPoint).encode()));
         subscriptions.put("consumer11",
-                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11, userEndPoint).encode()));
+                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid1, prevTasks01, standbyTasks02, userEndPoint).encode()));
         subscriptions.put("consumer20",
-                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20, userEndPoint).encode()));
+                new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid2, prevTasks02, standbyTasks00, "any:9097").encode()));
 
         Map<String, PartitionAssignor.Assignment> assignments = partitionAssignor.assign(metadata, subscriptions);
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 0e98f56..250abc1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -407,8 +407,8 @@ public class StreamThreadTest {
 
         assertThat(thread1.tasks().keySet(), equalTo(originalTaskAssignmentThread2));
         assertThat(thread2.tasks().keySet(), equalTo(originalTaskAssignmentThread1));
-        assertThat(thread1.prevTasks(), equalTo(originalTaskAssignmentThread1));
-        assertThat(thread2.prevTasks(), equalTo(originalTaskAssignmentThread2));
+        assertThat(thread1.prevActiveTasks(), equalTo(originalTaskAssignmentThread1));
+        assertThat(thread2.prevActiveTasks(), equalTo(originalTaskAssignmentThread2));
     }
 
     private class MockStreamsPartitionAssignor extends StreamPartitionAssignor {

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
new file mode 100644
index 0000000..6a12191
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.common.utils.Utils;
+import org.junit.Test;
+
+import java.util.Collections;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class ClientStateTest {
+
+    private final ClientState<Integer> client = new ClientState<>(1);
+
+    @Test
+    public void shouldHaveNotReachedCapacityWhenAssignedTasksLessThanCapacity() throws Exception {
+        assertFalse(client.reachedCapacity());
+    }
+
+    @Test
+    public void shouldHaveReachedCapacityWhenAssignedTasksGreaterThanOrEqualToCapacity() throws Exception {
+        client.assign(1, true);
+        assertTrue(client.reachedCapacity());
+    }
+
+
+    @Test
+    public void shouldAddActiveTasksToBothAssignedAndActive() throws Exception {
+        client.assign(1, true);
+        assertThat(client.activeTasks(), equalTo(Collections.singleton(1)));
+        assertThat(client.assignedTasks(), equalTo(Collections.singleton(1)));
+        assertThat(client.assignedTaskCount(), equalTo(1));
+        assertThat(client.standbyTasks().size(), equalTo(0));
+    }
+
+    @Test
+    public void shouldAddStandbyTasksToBothStandbyAndActive() throws Exception {
+        client.assign(1, false);
+        assertThat(client.assignedTasks(), equalTo(Collections.singleton(1)));
+        assertThat(client.standbyTasks(), equalTo(Collections.singleton(1)));
+        assertThat(client.assignedTaskCount(), equalTo(1));
+        assertThat(client.activeTasks().size(), equalTo(0));
+    }
+
+    @Test
+    public void shouldAddPreviousActiveTasksToPreviousAssignedAndPreviousActive() throws Exception {
+        client.addPreviousActiveTasks(Utils.mkSet(1, 2));
+        assertThat(client.previousActiveTasks(), equalTo(Utils.mkSet(1, 2)));
+        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(1, 2)));
+    }
+
+    @Test
+    public void shouldAddPreviousStandbyTasksToPreviousAssigned() throws Exception {
+        client.addPreviousStandbyTasks(Utils.mkSet(1, 2));
+        assertThat(client.previousActiveTasks().size(), equalTo(0));
+        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(1, 2)));
+    }
+
+    @Test
+    public void shouldHaveAssignedTaskIfActiveTaskAssigned() throws Exception {
+        client.assign(2, true);
+        assertTrue(client.hasAssignedTask(2));
+    }
+
+    @Test
+    public void shouldHaveAssignedTaskIfStandbyTaskAssigned() throws Exception {
+        client.assign(2, false);
+        assertTrue(client.hasAssignedTask(2));
+    }
+
+    @Test
+    public void shouldNotHaveAssignedTaskIfTaskNotAssigned() throws Exception {
+        client.assign(2, true);
+        assertFalse(client.hasAssignedTask(3));
+    }
+
+    @Test
+    public void shouldHaveMoreAvailableCapacityWhenCapacityTheSameButFewerAssignedTasks() throws Exception {
+        final ClientState<Integer> c2 = new ClientState<>(1);
+        client.assign(1, true);
+        assertTrue(c2.hasMoreAvailableCapacityThan(client));
+        assertFalse(client.hasMoreAvailableCapacityThan(c2));
+    }
+
+    @Test
+    public void shouldHaveMoreAvailableCapacityWhenCapacityHigherAndSameAssignedTaskCount() throws Exception {
+        final ClientState<Integer> c2 = new ClientState<>(2);
+        assertTrue(c2.hasMoreAvailableCapacityThan(client));
+        assertFalse(client.hasMoreAvailableCapacityThan(c2));
+    }
+
+    @Test
+    public void shouldUseMultiplesOfCapacityToDetermineClientWithMoreAvailableCapacity() throws Exception {
+        final ClientState<Integer> c2 = new ClientState<>(2);
+
+        for (int i = 0; i < 7; i++) {
+            c2.assign(i, true);
+        }
+
+        for (int i = 7; i < 11; i++) {
+            client.assign(i, true);
+        }
+
+        assertTrue(c2.hasMoreAvailableCapacityThan(client));
+    }
+
+    @Test
+    public void shouldHaveMoreAvailableCapacityWhenCapacityIsTheSameButAssignedTasksIsLess() throws Exception {
+        final ClientState<Integer> c1 = new ClientState<>(3);
+        final ClientState<Integer> c2 = new ClientState<>(3);
+        for (int i = 0; i < 4; i++) {
+            c1.assign(i, true);
+            c2.assign(i, true);
+        }
+        c2.assign(5, true);
+        assertTrue(c1.hasMoreAvailableCapacityThan(c2));
+    }
+
+    @Test(expected = IllegalStateException.class)
+    public void shouldThrowIllegalStateExceptionIfCapacityOfThisClientStateIsZero() throws Exception {
+        final ClientState<Integer> c1 = new ClientState<>(0);
+        c1.hasMoreAvailableCapacityThan(new ClientState<Integer>(1));
+    }
+
+    @Test(expected = IllegalStateException.class)
+    public void shouldThrowIllegalStateExceptionIfCapacityOfOtherClientStateIsZero() throws Exception {
+        final ClientState<Integer> c1 = new ClientState<>(1);
+        c1.hasMoreAvailableCapacityThan(new ClientState<Integer>(0));
+    }
+
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
new file mode 100644
index 0000000..a119d18
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -0,0 +1,515 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsCollectionContaining.hasItems;
+import static org.hamcrest.core.IsNot.not;
+import static org.junit.Assert.assertTrue;
+
+public class StickyTaskAssignorTest {
+
+    private final TaskId task00 = new TaskId(0, 0);
+    private final TaskId task01 = new TaskId(0, 1);
+    private final TaskId task02 = new TaskId(0, 2);
+    private final TaskId task03 = new TaskId(0, 3);
+    private final Map<Integer, ClientState<TaskId>> clients = new TreeMap<>();
+    private final Integer p1 = 1;
+    private final Integer p2 = 2;
+    private final Integer p3 = 3;
+    private final Integer p4 = 4;
+
+    @Test
+    public void shouldAssignOneActiveTaskToEachProcessWhenTaskCountSameAsProcessCount() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p3, 1);
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(0);
+
+        for (final Integer processId : clients.keySet()) {
+            assertThat(clients.get(processId).activeTaskCount(), equalTo(1));
+        }
+    }
+
+    @Test
+    public void shouldNotMigrateActiveTaskToOtherProcess() throws Exception {
+        createClientWithPreviousActiveTasks(p1, 1, task00);
+        createClientWithPreviousActiveTasks(p2, 1, task01);
+
+        final StickyTaskAssignor firstAssignor = createTaskAssignor(task00, task01, task02);
+        firstAssignor.assign(0);
+
+        assertThat(clients.get(p1).activeTasks(), hasItems(task00));
+        assertThat(clients.get(p2).activeTasks(), hasItems(task01));
+        assertThat(allActiveTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+
+        clients.clear();
+
+        // flip the previous active tasks assignment around.
+        createClientWithPreviousActiveTasks(p1, 1, task01);
+        createClientWithPreviousActiveTasks(p2, 1, task02);
+
+        final StickyTaskAssignor secondAssignor = createTaskAssignor(task00, task01, task02);
+        secondAssignor.assign(0);
+
+        assertThat(clients.get(p1).activeTasks(), hasItems(task01));
+        assertThat(clients.get(p2).activeTasks(), hasItems(task02));
+        assertThat(allActiveTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+    }
+
+    @Test
+    public void shouldMigrateActiveTasksToNewProcessWithoutChangingAllAssignments() throws Exception {
+        createClientWithPreviousActiveTasks(p1, 1, task00, task02);
+        createClientWithPreviousActiveTasks(p2, 1, task01);
+        createClient(p3, 1);
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+
+        taskAssignor.assign(0);
+
+        assertThat(clients.get(p2).activeTasks(), equalTo(Collections.singleton(task01)));
+        assertThat(clients.get(p1).activeTasks().size(), equalTo(1));
+        assertThat(clients.get(p3).activeTasks().size(), equalTo(1));
+        assertThat(allActiveTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+    }
+
+    @Test
+    public void shouldAssignBasedOnCapacity() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 2);
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+
+        taskAssignor.assign(0);
+        assertThat(clients.get(p1).activeTasks().size(), equalTo(1));
+        assertThat(clients.get(p2).activeTasks().size(), equalTo(2));
+    }
+
+    @Test
+    public void shouldKeepActiveTaskStickynessWhenMoreClientThanActiveTasks() {
+        final int p5 = 5;
+        createClientWithPreviousActiveTasks(p1, 1, task00);
+        createClientWithPreviousActiveTasks(p2, 1, task02);
+        createClientWithPreviousActiveTasks(p3, 1, task01);
+        createClient(p4, 1);
+        createClient(p5, 1);
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(0);
+
+        assertThat(clients.get(p1).activeTasks(), equalTo(Collections.singleton(task00)));
+        assertThat(clients.get(p2).activeTasks(), equalTo(Collections.singleton(task02)));
+        assertThat(clients.get(p3).activeTasks(), equalTo(Collections.singleton(task01)));
+
+        // change up the assignment and make sure it is still sticky
+        clients.clear();
+        createClient(p1, 1);
+        createClientWithPreviousActiveTasks(p2, 1, task00);
+        createClient(p3, 1);
+        createClientWithPreviousActiveTasks(p4, 1, task02);
+        createClientWithPreviousActiveTasks(p5, 1, task01);
+
+        final StickyTaskAssignor secondAssignor = createTaskAssignor(task00, task01, task02);
+        secondAssignor.assign(0);
+
+        assertThat(clients.get(p2).activeTasks(), equalTo(Collections.singleton(task00)));
+        assertThat(clients.get(p4).activeTasks(), equalTo(Collections.singleton(task02)));
+        assertThat(clients.get(p5).activeTasks(), equalTo(Collections.singleton(task01)));
+
+
+    }
+
+    @Test
+    public void shouldAssignTasksToClientWithPreviousStandbyTasks() throws Exception {
+        final ClientState<TaskId> client1 = createClient(p1, 1);
+        client1.addPreviousStandbyTasks(Utils.mkSet(task02));
+        final ClientState<TaskId> client2 = createClient(p2, 1);
+        client2.addPreviousStandbyTasks(Utils.mkSet(task01));
+        final ClientState<TaskId> client3 = createClient(p3, 1);
+        client3.addPreviousStandbyTasks(Utils.mkSet(task00));
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+
+        taskAssignor.assign(0);
+
+        assertThat(clients.get(p1).activeTasks(), equalTo(Collections.singleton(task02)));
+        assertThat(clients.get(p2).activeTasks(), equalTo(Collections.singleton(task01)));
+        assertThat(clients.get(p3).activeTasks(), equalTo(Collections.singleton(task00)));
+    }
+
+    @Test
+    public void shouldAssignBasedOnCapacityWhenMultipleClientHaveStandbyTasks() throws Exception {
+        final ClientState<TaskId> c1 = createClientWithPreviousActiveTasks(p1, 1, task00);
+        c1.addPreviousStandbyTasks(Utils.mkSet(task01));
+        final ClientState<TaskId> c2 = createClientWithPreviousActiveTasks(p2, 2, task02);
+        c2.addPreviousStandbyTasks(Utils.mkSet(task01));
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+
+        taskAssignor.assign(0);
+
+        assertThat(clients.get(p1).activeTasks(), equalTo(Collections.singleton(task00)));
+        assertThat(clients.get(p2).activeTasks(), equalTo(Utils.mkSet(task02, task01)));
+    }
+
+    @Test
+    public void shouldAssignStandbyTasksToDifferentClientThanCorrespondingActiveTaskIsAssingedTo() throws Exception {
+        createClientWithPreviousActiveTasks(p1, 1, task00);
+        createClientWithPreviousActiveTasks(p2, 1, task01);
+        createClientWithPreviousActiveTasks(p3, 1, task02);
+        createClientWithPreviousActiveTasks(p4, 1, task03);
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02, task03);
+        taskAssignor.assign(1);
+
+        assertThat(clients.get(p1).standbyTasks(), not(hasItems(task00)));
+        assertTrue(clients.get(p1).standbyTasks().size() <= 2);
+        assertThat(clients.get(p2).standbyTasks(), not(hasItems(task01)));
+        assertTrue(clients.get(p2).standbyTasks().size() <= 2);
+        assertThat(clients.get(p3).standbyTasks(), not(hasItems(task02)));
+        assertTrue(clients.get(p3).standbyTasks().size() <= 2);
+        assertThat(clients.get(p4).standbyTasks(), not(hasItems(task03)));
+        assertTrue(clients.get(p4).standbyTasks().size() <= 2);
+
+        int nonEmptyStandbyTaskCount = 0;
+        for (final Integer client : clients.keySet()) {
+            nonEmptyStandbyTaskCount += clients.get(client).standbyTasks().isEmpty() ? 0 : 1;
+        }
+
+        assertTrue(nonEmptyStandbyTaskCount >= 3);
+        assertThat(allStandbyTasks(), equalTo(Arrays.asList(task00, task01, task02, task03)));
+    }
+
+
+
+    @Test
+    public void shouldAssignMultipleReplicasOfStandbyTask() throws Exception {
+        createClientWithPreviousActiveTasks(p1, 1, task00);
+        createClientWithPreviousActiveTasks(p2, 1, task01);
+        createClientWithPreviousActiveTasks(p3, 1, task02);
+
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(2);
+
+        assertThat(clients.get(p1).standbyTasks(), equalTo(Utils.mkSet(task01, task02)));
+        assertThat(clients.get(p2).standbyTasks(), equalTo(Utils.mkSet(task02, task00)));
+        assertThat(clients.get(p3).standbyTasks(), equalTo(Utils.mkSet(task00, task01)));
+    }
+
+    @Test
+    public void shouldNotAssignStandbyTaskReplicasWhenNoClientAvailableWithoutHavingTheTaskAssigned() throws Exception {
+        createClient(p1, 1);
+        final StickyTaskAssignor taskAssignor = createTaskAssignor(task00);
+        taskAssignor.assign(1);
+        assertThat(clients.get(p1).standbyTasks().size(), equalTo(0));
+    }
+
+    @Test
+    public void shouldAssignActiveAndStandbyTasks() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p3, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(1);
+
+        assertThat(allActiveTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+        assertThat(allStandbyTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+    }
+
+
+    @Test
+    public void shouldAssignAtLeastOneTaskToEachClientIfPossible() throws Exception {
+        createClient(p1, 3);
+        createClient(p2, 1);
+        createClient(p3, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(0);
+        assertThat(clients.get(p1).assignedTaskCount(), equalTo(1));
+        assertThat(clients.get(p2).assignedTaskCount(), equalTo(1));
+        assertThat(clients.get(p3).assignedTaskCount(), equalTo(1));
+    }
+
+    @Test
+    public void shouldAssignEachActiveTaskToOneClientWhenMoreClientsThanTasks() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p3, 1);
+        createClient(p4, 1);
+        createClient(5, 1);
+        createClient(6, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(0);
+
+        assertThat(allActiveTasks(), equalTo(Arrays.asList(task00, task01, task02)));
+    }
+
+    @Test
+    public void shouldBalanceActiveAndStandbyTasksAcrossAvailableClients() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p3, 1);
+        createClient(p4, 1);
+        createClient(5, 1);
+        createClient(6, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.assign(1);
+
+        for (final ClientState<TaskId> clientState : clients.values()) {
+            assertThat(clientState.assignedTaskCount(), equalTo(1));
+        }
+    }
+
+    @Test
+    public void shouldAssignMoreTasksToClientWithMoreCapacity() throws Exception {
+        createClient(p2, 2);
+        createClient(p1, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00,
+                                                                            task01,
+                                                                            task02,
+                                                                            new TaskId(1, 0),
+                                                                            new TaskId(1, 1),
+                                                                            new TaskId(1, 2),
+                                                                            new TaskId(2, 0),
+                                                                            new TaskId(2, 1),
+                                                                            new TaskId(2, 2),
+                                                                            new TaskId(3, 0),
+                                                                            new TaskId(3, 1),
+                                                                            new TaskId(3, 2));
+
+        taskAssignor.assign(0);
+        assertThat(clients.get(p2).assignedTaskCount(), equalTo(8));
+        assertThat(clients.get(p1).assignedTaskCount(), equalTo(4));
+    }
+
+
+    @Test
+    public void shouldNotHaveSameAssignmentOnAnyTwoHosts() throws Exception {
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p3, 1);
+        createClient(p4, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task02, task01, task03);
+        taskAssignor.assign(1);
+
+        for (int i = p1; i <= p4; i++) {
+            final Set<TaskId> taskIds = clients.get(i).assignedTasks();
+            for (int j = p1; j <= p4; j++) {
+                if (j != i) {
+                    assertThat("clients shouldn't have same task assignment", clients.get(j).assignedTasks(),
+                               not(equalTo(taskIds)));
+                }
+            }
+
+        }
+    }
+
+    @Test
+    public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousActiveTasks() throws Exception {
+        createClientWithPreviousActiveTasks(p1, 1, task01, task02);
+        createClientWithPreviousActiveTasks(p2, 1, task03);
+        createClientWithPreviousActiveTasks(p3, 1, task00);
+        createClient(p4, 1);
+        // p1..p3 previously ran all four tasks between them as active; p4 starts with no history
+        createTaskAssignor(task00, task02, task01, task03).assign(1);
+
+        // check every ordered pair of distinct clients: their assigned task sets must differ
+        for (int current = p1; current <= p4; current++) {
+            final Set<TaskId> currentTasks = clients.get(current).assignedTasks();
+            for (int other = p1; other <= p4; other++) {
+                if (other == current) {
+                    continue;
+                }
+                final Set<TaskId> otherTasks = clients.get(other).assignedTasks();
+                assertThat("clients shouldn't have same task assignment", otherTasks, not(equalTo(currentTasks)));
+            }
+        }
+    }
+
+    @Test
+    public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousStandbyTasks() throws Exception {
+        // p1 and p2 each previously had two of the tasks as active and the other two as standby
+        createClientWithPreviousActiveTasks(p1, 1, task01, task02).addPreviousStandbyTasks(Utils.mkSet(task03, task00));
+        createClientWithPreviousActiveTasks(p2, 1, task03, task00).addPreviousStandbyTasks(Utils.mkSet(task01, task02));
+
+        // p3 and p4 join with no previous tasks at all
+        createClient(p3, 1);
+        createClient(p4, 1);
+
+        createTaskAssignor(task00, task02, task01, task03).assign(1);
+
+        // check every ordered pair of distinct clients: their assigned task sets must differ
+        for (int current = p1; current <= p4; current++) {
+            final Set<TaskId> currentTasks = clients.get(current).assignedTasks();
+            for (int other = p1; other <= p4; other++) {
+                if (other == current) {
+                    continue;
+                }
+                final Set<TaskId> otherTasks = clients.get(other).assignedTasks();
+                assertThat("clients shouldn't have same task assignment", otherTasks, not(equalTo(currentTasks)));
+            }
+        }
+    }
+
+    @Test
+    public void shouldReBalanceTasksAcrossAllClientsWhenCapacityAndTaskCountTheSame() throws Exception {
+        // p3 previously ran all four tasks as active; the other three clients have no history
+        createClientWithPreviousActiveTasks(p3, 1, task00, task01, task02, task03);
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createClient(p4, 1);
+
+        createTaskAssignor(task00, task02, task01, task03).assign(0);
+
+        // four tasks over four clients of capacity 1: each client must end up with exactly one
+        for (final Integer processId : Arrays.asList(p1, p2, p3, p4)) {
+            assertThat(clients.get(processId).assignedTaskCount(), equalTo(1));
+        }
+    }
+
+    @Test
+    public void shouldReBalanceTasksAcrossClientsWhenCapacityLessThanTaskCount() throws Exception {
+        // p3 previously ran all four tasks as active; p1 and p2 have no history
+        createClientWithPreviousActiveTasks(p3, 1, task00, task01, task02, task03);
+        createClient(p1, 1);
+        createClient(p2, 1);
+        createTaskAssignor(task00, task02, task01, task03).assign(0);
+
+        // four tasks over three clients: the previous owner keeps two, each newcomer gets one
+        assertThat(clients.get(p3).assignedTaskCount(), equalTo(2));
+        assertThat(clients.get(p1).assignedTaskCount(), equalTo(1));
+        assertThat(clients.get(p2).assignedTaskCount(), equalTo(1));
+    }
+
+    @Test
+    public void shouldRebalanceTasksToClientsBasedOnCapacity() throws Exception {
+        createClientWithPreviousActiveTasks(p2, 1, task00, task03, task02);
+        createClient(p3, 2);
+        createTaskAssignor(task00, task02, task03).assign(0);
+        // p3 has twice the capacity of p2, so it is expected to take two of the three tasks despite p2's history
+        assertThat(clients.get(p2).assignedTaskCount(), equalTo(1));
+        assertThat(clients.get(p3).assignedTaskCount(), equalTo(2));
+    }
+
+    @Test
+    public void shouldMoveMinimalNumberOfTasksWhenPreviouslyAboveCapacityAndNewClientAdded() throws Exception {
+        final Set<TaskId> p1PrevTasks = Utils.mkSet(task00, task02);
+        final Set<TaskId> p2PrevTasks = Utils.mkSet(task01, task03);
+
+        createClientWithPreviousActiveTasks(p1, 1, task00, task02);
+        createClientWithPreviousActiveTasks(p2, 1, task01, task03);
+        createClientWithPreviousActiveTasks(p3, 1);
+
+        createTaskAssignor(task00, task02, task01, task03).assign(0);
+
+        // p3 must take exactly one task; the old client that did not donate it must keep its previous tasks unchanged
+        final Set<TaskId> movedToP3 = clients.get(p3).activeTasks();
+        assertThat(movedToP3.size(), equalTo(1));
+        if (Collections.disjoint(p1PrevTasks, movedToP3)) {
+            assertThat(clients.get(p1).activeTasks(), equalTo(p1PrevTasks));
+        } else {
+            assertThat(clients.get(p2).activeTasks(), equalTo(p2PrevTasks));
+        }
+    }
+
+    @Test
+    public void shouldNotMoveAnyTasksWhenNewTasksAdded() throws Exception {
+        final TaskId newTask04 = new TaskId(0, 4);
+        final TaskId newTask05 = new TaskId(0, 5);
+
+        createClientWithPreviousActiveTasks(p1, 1, task00, task01);
+        createClientWithPreviousActiveTasks(p2, 1, task02, task03);
+
+        createTaskAssignor(task03, task01, newTask04, task02, task00, newTask05).assign(0);
+
+        // hasItems checks containment only: previous tasks must stay put, the two new tasks may land on either client
+        assertThat(clients.get(p1).activeTasks(), hasItems(task00, task01));
+        assertThat(clients.get(p2).activeTasks(), hasItems(task02, task03));
+    }
+
+    @Test
+    public void shouldAssignNewTasksToNewClientWhenPreviousTasksAssignedToOldClients() throws Exception {
+        final TaskId newTask04 = new TaskId(0, 4);
+        final TaskId newTask05 = new TaskId(0, 5);
+
+        createClientWithPreviousActiveTasks(p1, 1, task02, task01);
+        createClientWithPreviousActiveTasks(p2, 1, task00, task03);
+        createClient(p3, 1);
+
+        createTaskAssignor(task03, task01, newTask04, task02, task00, newTask05).assign(0);
+
+        // old clients keep what they previously ran; the client without history picks up both new tasks
+        assertThat(clients.get(p1).activeTasks(), hasItems(task02, task01));
+        assertThat(clients.get(p2).activeTasks(), hasItems(task00, task03));
+        assertThat(clients.get(p3).activeTasks(), hasItems(newTask04, newTask05));
+    }
+
+    private StickyTaskAssignor<Integer> createTaskAssignor(final TaskId... tasks) {
+        final Set<TaskId> taskIds = new HashSet<>(Arrays.asList(tasks)); // set semantics: repeated ids collapse
+        return new StickyTaskAssignor<>(clients, taskIds); // operates on the shared clients map of this test
+    }
+
+    private List<TaskId> allActiveTasks() {
+        final List<TaskId> sortedActive = new ArrayList<>(); // a list, so a task held by two clients shows up twice
+        for (final ClientState<TaskId> state : clients.values()) {
+            sortedActive.addAll(state.activeTasks());
+        }
+        Collections.sort(sortedActive); // natural TaskId order, independent of client iteration order
+        return sortedActive;
+    }
+
+    private List<TaskId> allStandbyTasks() {
+        final List<TaskId> sortedStandby = new ArrayList<>(); // a list, so a task replicated on two clients shows up twice
+        for (final ClientState<TaskId> state : clients.values()) {
+            sortedStandby.addAll(state.standbyTasks());
+        }
+        Collections.sort(sortedStandby); // natural TaskId order, independent of client iteration order
+        return sortedStandby;
+    }
+
+    private ClientState<TaskId> createClient(final Integer processId, final int capacity) {
+        return createClientWithPreviousActiveTasks(processId, capacity); // empty varargs: client without previous active tasks
+    }
+
+    private ClientState<TaskId> createClientWithPreviousActiveTasks(final Integer processId, final int capacity, final TaskId... taskIds) {
+        final ClientState<TaskId> state = new ClientState<>(capacity);
+        clients.put(processId, state); // registered so tests can look the client up via clients.get(processId)
+        state.addPreviousActiveTasks(Utils.mkSet(taskIds)); // taskIds may be empty, see createClient
+        return state;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/0b48ea1c/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
deleted file mode 100644
index 52ca0a4..0000000
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
+++ /dev/null
@@ -1,312 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kafka.streams.processor.internals.assignment;
-
-import static org.apache.kafka.common.utils.Utils.mkList;
-import static org.apache.kafka.common.utils.Utils.mkSet;
-import org.junit.Test;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
-public class TaskAssignorTest {
-
-    private static Map<Integer, ClientState<Integer>> copyStates(Map<Integer, ClientState<Integer>> states) {
-        Map<Integer, ClientState<Integer>> copy = new HashMap<>();
-        for (Map.Entry<Integer, ClientState<Integer>> entry : states.entrySet()) {
-            copy.put(entry.getKey(), entry.getValue().copy());
-        }
-
-        return copy;
-    }
-
-    @Test
-    public void testAssignWithoutStandby() {
-        HashMap<Integer, ClientState<Integer>> statesWithNoPrevTasks = new HashMap<>();
-        for (int i = 0; i < 6; i++) {
-            statesWithNoPrevTasks.put(i, new ClientState<Integer>(1d));
-        }
-        Set<Integer> tasks;
-        int numActiveTasks;
-        int numAssignedTasks;
-
-        Map<Integer, ClientState<Integer>> states;
-
-        // # of clients and # of tasks are equal.
-        states = copyStates(statesWithNoPrevTasks);
-        tasks = mkSet(0, 1, 2, 3, 4, 5);
-        TaskAssignor.assign(states, tasks, 0);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertEquals(1, assignment.activeTasks.size());
-            assertEquals(1, assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size(), numAssignedTasks);
-
-        // # of clients < # of tasks
-        tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7);
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 0);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(1 <= assignment.activeTasks.size());
-            assertTrue(2 >= assignment.activeTasks.size());
-            assertTrue(1 <= assignment.assignedTasks.size());
-            assertTrue(2 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size(), numAssignedTasks);
-
-        // # of clients > # of tasks
-        tasks = mkSet(0, 1, 2, 3);
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 0);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(0 <= assignment.activeTasks.size());
-            assertTrue(1 >= assignment.activeTasks.size());
-            assertTrue(0 <= assignment.assignedTasks.size());
-            assertTrue(1 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size(), numAssignedTasks);
-    }
-
-    @Test
-    public void testAssignWithStandby() {
-        HashMap<Integer, ClientState<Integer>> statesWithNoPrevTasks = new HashMap<>();
-        for (int i = 0; i < 6; i++) {
-            statesWithNoPrevTasks.put(i, new ClientState<Integer>(1d));
-        }
-        Set<Integer> tasks;
-        Map<Integer, ClientState<Integer>> states;
-        int numActiveTasks;
-        int numAssignedTasks;
-
-        // # of clients and # of tasks are equal.
-        tasks = mkSet(0, 1, 2, 3, 4, 5);
-
-        // 1 standby replicas.
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 1);
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertEquals(1, assignment.activeTasks.size());
-            assertEquals(2, assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 2, numAssignedTasks);
-
-        // # of clients < # of tasks
-        tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7);
-
-        // 1 standby replicas.
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 1);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(1 <= assignment.activeTasks.size());
-            assertTrue(2 >= assignment.activeTasks.size());
-            assertTrue(2 <= assignment.assignedTasks.size());
-            assertTrue(3 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 2, numAssignedTasks);
-
-        // # of clients > # of tasks
-        tasks = mkSet(0, 1, 2, 3);
-
-        // 1 standby replicas.
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 1);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(0 <= assignment.activeTasks.size());
-            assertTrue(1 >= assignment.activeTasks.size());
-            assertTrue(1 <= assignment.assignedTasks.size());
-            assertTrue(2 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 2, numAssignedTasks);
-
-        // # of clients >> # of tasks
-        tasks = mkSet(0, 1);
-
-        // 1 standby replicas.
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 1);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(0 <= assignment.activeTasks.size());
-            assertTrue(1 >= assignment.activeTasks.size());
-            assertTrue(0 <= assignment.assignedTasks.size());
-            assertTrue(1 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 2, numAssignedTasks);
-
-        // 2 standby replicas.
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 2);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(0 <= assignment.activeTasks.size());
-            assertTrue(1 >= assignment.activeTasks.size());
-            assertTrue(1 == assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 3, numAssignedTasks);
-
-        // 3 standby replicas.
-        states = copyStates(statesWithNoPrevTasks);
-        TaskAssignor.assign(states, tasks, 3);
-        numActiveTasks = 0;
-        numAssignedTasks = 0;
-        for (ClientState<Integer> assignment : states.values()) {
-            numActiveTasks += assignment.activeTasks.size();
-            numAssignedTasks += assignment.assignedTasks.size();
-            assertTrue(0 <= assignment.activeTasks.size());
-            assertTrue(1 >= assignment.activeTasks.size());
-            assertTrue(1 <= assignment.assignedTasks.size());
-            assertTrue(2 >= assignment.assignedTasks.size());
-        }
-        assertEquals(tasks.size(), numActiveTasks);
-        assertEquals(tasks.size() * 4, numAssignedTasks);
-    }
-
-    @Test
-    public void testStickiness() {
-        List<Integer> tasks;
-        Map<Integer, ClientState<Integer>> statesWithPrevTasks;
-        Map<Integer, ClientState<Integer>> assignments;
-        int i;
-
-        // # of clients and # of tasks are equal.
-        Map<Integer, ClientState<Integer>> states;
-        tasks = mkList(0, 1, 2, 3, 4, 5);
-        Collections.shuffle(tasks);
-        statesWithPrevTasks = new HashMap<>();
-        i = 0;
-        for (int task : tasks) {
-            ClientState<Integer> state = new ClientState<>(1d);
-            state.prevActiveTasks.add(task);
-            state.prevAssignedTasks.add(task);
-            statesWithPrevTasks.put(i++, state);
-        }
-        states = copyStates(statesWithPrevTasks);
-        TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5), 0);
-        for (int client : states.keySet()) {
-            Set<Integer> oldActive = statesWithPrevTasks.get(client).prevActiveTasks;
-            Set<Integer> oldAssigned = statesWithPrevTasks.get(client).prevAssignedTasks;
-            Set<Integer> newActive = states.get(client).activeTasks;
-            Set<Integer> newAssigned = states.get(client).assignedTasks;
-
-            assertEquals(oldActive, newActive);
-            assertEquals(oldAssigned, newAssigned);
-        }
-
-        // # of clients > # of tasks
-        tasks = mkList(0, 1, 2, 3, -1, -1);
-        Collections.shuffle(tasks);
-        statesWithPrevTasks = new HashMap<>();
-        i = 0;
-        for (int task : tasks) {
-            ClientState<Integer> state = new ClientState<>(1d);
-            if (task >= 0) {
-                state.prevActiveTasks.add(task);
-                state.prevAssignedTasks.add(task);
-            }
-            statesWithPrevTasks.put(i++, state);
-        }
-        states = copyStates(statesWithPrevTasks);
-        TaskAssignor.assign(states, mkSet(0, 1, 2, 3), 0);
-        for (int client : states.keySet()) {
-            Set<Integer> oldActive = statesWithPrevTasks.get(client).prevActiveTasks;
-            Set<Integer> oldAssigned = statesWithPrevTasks.get(client).prevAssignedTasks;
-            Set<Integer> newActive = states.get(client).activeTasks;
-            Set<Integer> newAssigned = states.get(client).assignedTasks;
-
-            assertEquals(oldActive, newActive);
-            assertEquals(oldAssigned, newAssigned);
-        }
-
-        // # of clients < # of tasks
-        List<Set<Integer>> taskSets = mkList(mkSet(0, 1), mkSet(2, 3), mkSet(4, 5), mkSet(6, 7), mkSet(8, 9), mkSet(10, 11));
-        Collections.shuffle(taskSets);
-        statesWithPrevTasks = new HashMap<>();
-        i = 0;
-        for (Set<Integer> taskSet : taskSets) {
-            ClientState<Integer> state = new ClientState<>(1d);
-            state.prevActiveTasks.addAll(taskSet);
-            state.prevAssignedTasks.addAll(taskSet);
-            statesWithPrevTasks.put(i++, state);
-        }
-        states = copyStates(statesWithPrevTasks);
-        TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), 0);
-        for (int client : states.keySet()) {
-            Set<Integer> oldActive = statesWithPrevTasks.get(client).prevActiveTasks;
-            Set<Integer> oldAssigned = statesWithPrevTasks.get(client).prevAssignedTasks;
-            Set<Integer> newActive = states.get(client).activeTasks;
-            Set<Integer> newAssigned = states.get(client).assignedTasks;
-
-            Set<Integer> intersection = new HashSet<>();
-
-            intersection.addAll(oldActive);
-            intersection.retainAll(newActive);
-            assertTrue(intersection.size() > 0);
-
-            intersection.clear();
-            intersection.addAll(oldAssigned);
-            intersection.retainAll(newAssigned);
-            assertTrue(intersection.size() > 0);
-        }
-    }
-
-}


Mime
View raw message