kafka-commits mailing list archives

From damian...@apache.org
Subject [2/3] kafka git commit: KAFKA-5152; move state restoration out of rebalance and into poll loop
Date Wed, 16 Aug 2017 10:14:13 GMT
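
In short: before this patch, StreamThread restored all state stores inside the consumer rebalance callback, blocking the rebalance (and retrying with backoff on lock conflicts) until restoration finished. After it, onPartitionsAssigned only creates or resumes tasks and pauses their partitions; the poll loop then restores changelogs a chunk at a time and resumes each partition once its stores have caught up. A minimal, self-contained sketch of that pattern follows; the names here are illustrative stand-ins, the real implementation is StreamThread.runOnce plus StoreChangelogReader in the diff below.

    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.common.TopicPartition;

    import java.util.Collection;
    import java.util.HashSet;
    import java.util.Set;

    // Illustrative sketch only: restoreSomeChangelogData() stands in for
    // StoreChangelogReader.restore(), and the restore bookkeeping is reduced
    // to a Set of still-restoring partitions.
    public class PollLoopRestoreSketch implements ConsumerRebalanceListener {
        private final Consumer<byte[], byte[]> consumer;
        private final Set<TopicPartition> restoring = new HashSet<>();

        PollLoopRestoreSketch(final Consumer<byte[], byte[]> consumer) {
            this.consumer = consumer;
        }

        @Override
        public void onPartitionsAssigned(final Collection<TopicPartition> assigned) {
            // do NOT restore here: just record what needs restoring and pause
            // it, so the rebalance callback returns quickly
            restoring.addAll(assigned);
            consumer.pause(assigned);
        }

        @Override
        public void onPartitionsRevoked(final Collection<TopicPartition> revoked) {
            restoring.removeAll(revoked);
        }

        void runOnce() {
            // poll() keeps the consumer alive in the group; paused partitions
            // return no records until they are resumed
            final ConsumerRecords<byte[], byte[]> records = consumer.poll(100);

            if (!restoring.isEmpty()) {
                // restore a bounded chunk of changelog data per iteration and
                // resume whatever is now fully caught up
                final Collection<TopicPartition> caughtUp = restoreSomeChangelogData();
                restoring.removeAll(caughtUp);
                consumer.resume(caughtUp);
            }

            if (!records.isEmpty()) {
                // hand records to the running tasks (processAndPunctuate etc.)
            }
        }

        private Collection<TopicPartition> restoreSomeChangelogData() {
            // placeholder: read from the restore consumer, apply to the state
            // stores, and return the partitions whose stores are fully restored
            return new HashSet<>();
        }
    }
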
http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index d25af64..2d91e1b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
@@ -27,8 +26,6 @@ import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.config.ConfigDef;
-import org.apache.kafka.common.config.ConfigDef.Type;
 import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
@@ -40,7 +37,6 @@ import org.apache.kafka.common.metrics.stats.Rate;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.KafkaClientSupplier;
 import org.apache.kafka.streams.StreamsConfig;
-import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskIdFormatException;
 import org.apache.kafka.streams.processor.PartitionGrouper;
@@ -63,7 +59,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Pattern;
@@ -79,7 +74,7 @@ public class StreamThread extends Thread {
      * Stream thread states are the possible states that a stream thread can be in.
      * A thread must only be in one state at a time
      * The expected state transitions with the following defined states is:
-     *
+     * <p>
      * <pre>
      *                +-------------+
      *          +<--- | Created     |
@@ -98,30 +93,31 @@ public class StreamThread extends Thread {
      *          |           |              |
      *          |           v              |
      *          |     +-----+-------+      |
-     *          +<--- | Assigning   |      |
-     *          |     | Partitions  | ---->+
-     *          |     +-----+-------+
+     *          +<--- | Partitions  |      |
+     *          |     | Assigned    | ---->+
+     *          |     +-----+-------+
      *          |           |
      *          |           v
      *          |     +-----+-------+
      *          +---> | Pending     |
-     *                | Shutdown    |
-     *                +-----+-------+
-     *                      |
-     *                      v
-     *                +-----+-------+
-     *                | Dead        |
+     *          |     | Shutdown    |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Dead        |
      *                +-------------+
      * </pre>
-     *
+     * <p>
      * Note the following:
-     * - Any state can go to PENDING_SHUTDOWN followed by a subsequent transition to DEAD.
+     * - Any state can go to PENDING_SHUTDOWN. That is because streams can be closed at any time.
+     * - Any state can go to DEAD. That is because exceptions can happen at any other state,
+     * leading to the stream thread terminating.
      * - A streams thread can stay in PARTITIONS_REVOKED indefinitely, in the corner case when
-     *   the coordinator repeatedly fails in-between revoking partitions and assigning new partitions.
-     *
+     * the coordinator repeatedly fails in-between revoking partitions and assigning new partitions.
      */
     public enum State implements ThreadStateTransitionValidator {
-        CREATED(1, 4), RUNNING(2, 4), PARTITIONS_REVOKED(2, 3, 4), ASSIGNING_PARTITIONS(1, 4), PENDING_SHUTDOWN(5), DEAD;
+        CREATED(1, 4), RUNNING(2, 4), PARTITIONS_REVOKED(2, 3, 4), PARTITIONS_ASSIGNED(1, 2, 4), PENDING_SHUTDOWN(5), DEAD;
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -156,74 +152,71 @@ public class StreamThread extends Thread {
 
     private class RebalanceListener implements ConsumerRebalanceListener {
         private final Time time;
-        private final int requestTimeOut;
 
-        RebalanceListener(final Time time, final int requestTimeOut) {
+        RebalanceListener(final Time time) {
             this.time = time;
-            this.requestTimeOut = requestTimeOut;
         }
 
         @Override
         public void onPartitionsAssigned(final Collection<TopicPartition> assignment) {
             log.debug("{} at state {}: new partitions {} assigned at the end of consumer rebalance.\n" +
-                    "\tassigned active tasks: {}\n" +
-                    "\tassigned standby tasks: {}\n" +
-                    "\tcurrent suspended active tasks: {}\n" +
-                    "\tcurrent suspended standby tasks: {}\n" +
-                    "\tprevious active tasks: {}",
-                logPrefix,
-                state,
-                assignment,
-                partitionAssignor.activeTasks().keySet(),
-                partitionAssignor.standbyTasks().keySet(),
-                suspendedTasks.keySet(),
-                suspendedStandbyTasks.keySet(),
-                prevActiveTasks);
+                              "\tassigned active tasks: {}\n" +
+                              "\tassigned standby tasks: {}\n" +
+                              "\tcurrent suspended active tasks: {}\n" +
+                              "\tcurrent suspended standby tasks: {}\n",
+                      logPrefix,
+                      state,
+                      assignment,
+                      partitionAssignor.activeTasks().keySet(),
+                      partitionAssignor.standbyTasks().keySet(),
+                      active.previousTaskIds(),
+                      standby.previousTaskIds());
 
             final long start = time.milliseconds();
             try {
-                storeChangelogReader = new StoreChangelogReader(getName(), restoreConsumer, time, requestTimeOut);
-                setState(State.ASSIGNING_PARTITIONS);
+                if (!setState(State.PARTITIONS_ASSIGNED)) {
+                    return;
+                }
                 // do this first as we may have suspended standby tasks that
                 // will become active or vice versa
                 closeNonAssignedSuspendedStandbyTasks();
                 closeNonAssignedSuspendedTasks();
-                addStreamTasks(assignment, start);
-                storeChangelogReader.restore();
-                addStandbyTasks(start);
+                addStreamTasks(assignment);
+                addStandbyTasks();
                 streamsMetadataState.onChange(partitionAssignor.getPartitionsByHostState(), partitionAssignor.clusterMetadata());
-                lastCleanMs = time.milliseconds(); // start the cleaning cycle
-                setState(State.RUNNING);
+                storeChangelogReader.reset();
+                Set<TopicPartition> partitions = active.uninitializedPartitions();
+                log.trace("{} pausing partitions: {}", logPrefix, partitions);
+                consumer.pause(partitions);
             } catch (final Throwable t) {
                 rebalanceException = t;
                 throw t;
             } finally {
                 log.info("{} partition assignment took {} ms.\n" +
-                        "\tcurrent active tasks: {}\n" +
-                        "\tcurrent standby tasks: {}\n" +
-                        "\tprevious active tasks: {}\n",
-                    logPrefix,
-                    time.milliseconds() - start,
-                    activeTasks.keySet(),
-                    standbyTasks.keySet(),
-                    prevActiveTasks);
+                                 "\tcurrent active tasks: {}\n" +
+                                 "\tcurrent standby tasks: {}\n" +
+                                 "\tprevious active tasks: {}\n",
+                         logPrefix,
+                         time.milliseconds() - start,
+                         active.allAssignedTaskIds(),
+                         standby.allAssignedTaskIds(),
+                         active.previousTaskIds());
             }
         }
 
         @Override
         public void onPartitionsRevoked(final Collection<TopicPartition> assignment) {
             log.debug("{} at state {}: partitions {} revoked at the beginning of consumer rebalance.\n" +
-                    "\tcurrent assigned active tasks: {}\n" +
-                    "\tcurrent assigned standby tasks: {}\n",
-                logPrefix,
-                state,
-                assignment,
-                activeTasks.keySet(), standbyTasks.keySet());
+                              "\tcurrent assigned active tasks: {}\n" +
+                              "\tcurrent assigned standby tasks: {}\n",
+                      logPrefix,
+                      state,
+                      assignment,
+                      active.runningTaskIds(), standby.runningTaskIds());
 
             final long start = time.milliseconds();
             try {
                 setState(State.PARTITIONS_REVOKED);
-                lastCleanMs = Long.MAX_VALUE; // stop the cleaning cycle until partitions are assigned
                 // suspend active tasks
                 suspendTasksAndState();
             } catch (final Throwable t) {
@@ -231,57 +224,25 @@ public class StreamThread extends Thread {
                 throw t;
             } finally {
                 streamsMetadataState.onChange(Collections.<HostInfo, Set<TopicPartition>>emptyMap(), partitionAssignor.clusterMetadata());
-                removeStreamTasks();
-                removeStandbyTasks();
+                standbyRecords.clear();
 
                 log.info("{} partition revocation took {} ms.\n" +
-                        "\tsuspended active tasks: {}\n" +
-                        "\tsuspended standby tasks: {}",
-                    logPrefix,
-                    time.milliseconds() - start,
-                    suspendedTasks.keySet(),
-                    suspendedStandbyTasks.keySet());
+                                 "\tsuspended active tasks: {}\n" +
+                                 "\tsuspended standby tasks: {}",
+                         logPrefix,
+                         time.milliseconds() - start,
+                         active.suspendedTaskIds(),
+                         standby.suspendedTaskIds());
             }
         }
     }
 
     abstract class AbstractTaskCreator {
-        final static long MAX_BACKOFF_TIME_MS = 1000L;
-        void retryWithBackoff(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated, final long start) {
-            long backoffTimeMs = 50L;
-            final Set<TaskId> retryingTasks = new HashSet<>();
-            while (true) {
-                final Iterator<Map.Entry<TaskId, Set<TopicPartition>>> it = tasksToBeCreated.entrySet().iterator();
-                while (it.hasNext()) {
-                    final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions = it.next();
-                    final TaskId taskId = newTaskAndPartitions.getKey();
-                    final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
-
-                    try {
-                        createTask(taskId, partitions);
-                        it.remove();
-                        backoffTimeMs = 50L;
-                        retryingTasks.remove(taskId);
-                    } catch (final LockException e) {
-                        // ignore and retry
-                        if (!retryingTasks.contains(taskId)) {
-                            log.warn("{} Could not create task {} due to {}; will retry", logPrefix, taskId, e);
-                            retryingTasks.add(taskId);
-                        }
-                    }
-                }
-
-                if (tasksToBeCreated.isEmpty() || time.milliseconds() - start > rebalanceTimeoutMs) {
-                    break;
-                }
-
-                try {
-                    Thread.sleep(backoffTimeMs);
-                    backoffTimeMs <<= 1;
-                    backoffTimeMs = Math.min(backoffTimeMs, MAX_BACKOFF_TIME_MS);
-                } catch (final InterruptedException e) {
-                    // ignore
-                }
+        void createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
+            for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
+                final TaskId taskId = newTaskAndPartitions.getKey();
+                final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
+                createTask(taskId, partitions);
             }
         }
 
@@ -291,35 +252,21 @@ public class StreamThread extends Thread {
     class TaskCreator extends AbstractTaskCreator {
         @Override
         void createTask(final TaskId taskId, final Set<TopicPartition> partitions) {
-            final StreamTask task = createStreamTask(taskId, partitions);
-
-            activeTasks.put(taskId, task);
-
-            for (final TopicPartition partition : partitions) {
-                activeTasksByPartition.put(partition, task);
-            }
+            active.addNewTask(createStreamTask(taskId, partitions));
         }
     }
 
     class StandbyTaskCreator extends AbstractTaskCreator {
-        private final Map<TopicPartition, Long> checkpointedOffsets;
-
-        StandbyTaskCreator(final Map<TopicPartition, Long> checkpointedOffsets) {
-            this.checkpointedOffsets = checkpointedOffsets;
-        }
 
         @Override
         void createTask(final TaskId taskId, final Set<TopicPartition> partitions) {
             final StandbyTask task = createStandbyTask(taskId, partitions);
-            updateStandByTaskMaps(checkpointedOffsets, taskId, partitions, task);
+            if (task != null) {
+                standby.addNewTask(task);
+            }
         }
     }
 
-    interface StreamTaskAction {
-        String name();
-        void apply(final StreamTask task);
-    }
-
     /**
      * This class extends {@link StreamsMetricsImpl(Metrics, String, String, Map)} and
      * overrides one of its functions for efficiency
@@ -404,17 +351,11 @@ public class StreamThread extends Thread {
     private final String logPrefix;
     private final String threadClientId;
     private final Pattern sourceTopicPattern;
-    private final Map<TaskId, StreamTask> activeTasks;
-    private final Map<TaskId, StandbyTask> standbyTasks;
-    private final Map<TopicPartition, StreamTask> activeTasksByPartition;
-    private final Map<TopicPartition, StandbyTask> standbyTasksByPartition;
-    private final Set<TaskId> prevActiveTasks;
-    private final Map<TaskId, StreamTask> suspendedTasks;
-    private final Map<TaskId, StandbyTask> suspendedStandbyTasks;
+    private final AssignedTasks<StreamTask> active;
+    private final AssignedTasks<StandbyTask> standby;
+
     private final Time time;
-    private final int rebalanceTimeoutMs;
     private final long pollTimeMs;
-    private final long cleanTimeMs;
     private final long commitTimeMs;
     private final StreamsMetricsThreadImpl streamsMetrics;
     // TODO: this is not private only for tests, should be better refactored
@@ -422,7 +363,6 @@ public class StreamThread extends Thread {
     private String originalReset;
     private StreamPartitionAssignor partitionAssignor;
     private long timerStartedMs;
-    private long lastCleanMs;
     private long lastCommitMs;
     private Throwable rebalanceException = null;
     private final boolean eosEnabled;
@@ -431,7 +371,7 @@ public class StreamThread extends Thread {
     private boolean processStandbyRecords = false;
 
     private final ThreadCache cache;
-    private StoreChangelogReader storeChangelogReader;
+    final StoreChangelogReader storeChangelogReader;
 
     private final TaskCreator taskCreator = new TaskCreator();
 
@@ -484,32 +424,22 @@ public class StreamThread extends Thread {
         consumer = clientSupplier.getConsumer(consumerConfigs);
         log.info("{} Creating restore consumer client", logPrefix);
         restoreConsumer = clientSupplier.getRestoreConsumer(config.getRestoreConsumerConfigs(threadClientId));
-        // initialize the task list
-        // activeTasks needs to be concurrent as it can be accessed
-        // by QueryableState
-        activeTasks = new ConcurrentHashMap<>();
-        standbyTasks = new HashMap<>();
-        activeTasksByPartition = new HashMap<>();
-        standbyTasksByPartition = new HashMap<>();
-        prevActiveTasks = new HashSet<>();
-        suspendedTasks = new HashMap<>();
-        suspendedStandbyTasks = new HashMap<>();
-
         // standby KTables
         standbyRecords = new HashMap<>();
 
+
         this.stateDirectory = stateDirectory;
-        final Object maxPollInterval = consumerConfigs.get(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG);
-        rebalanceTimeoutMs =  (Integer) ConfigDef.parseType(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, maxPollInterval, Type.INT);
         pollTimeMs = config.getLong(StreamsConfig.POLL_MS_CONFIG);
         commitTimeMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG);
-        cleanTimeMs = config.getLong(StreamsConfig.STATE_CLEANUP_DELAY_MS_CONFIG);
 
         this.time = time;
         timerStartedMs = time.milliseconds();
-        lastCleanMs = Long.MAX_VALUE; // the cleaning cycle won't start until partition assignment
         lastCommitMs = timerStartedMs;
-        rebalanceListener = new RebalanceListener(time, config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG));
+        final Integer requestTimeOut = config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG);
+        rebalanceListener = new RebalanceListener(time);
+        active = new AssignedTasks<>(logPrefix, "stream task", Time.SYSTEM);
+        standby = new AssignedTasks<>(logPrefix, "standby task", Time.SYSTEM);
+        storeChangelogReader = new StoreChangelogReader(getName(), restoreConsumer, time, requestTimeOut);
     }
 
     /**
@@ -547,27 +477,52 @@ public class StreamThread extends Thread {
         consumer.subscribe(sourceTopicPattern, rebalanceListener);
 
         while (stillRunning()) {
-            timerStartedMs = time.milliseconds();
-
-            // try to fetch some records if necessary
-            final ConsumerRecords<byte[], byte[]> records = pollRequests();
-            if (records != null && !records.isEmpty() && !activeTasks.isEmpty()) {
-                streamsMetrics.pollTimeSensor.record(computeLatency(), timerStartedMs);
-                addRecordsToTasks(records);
-                final long totalProcessed = processAndPunctuate(activeTasks, recordsProcessedBeforeCommit);
-                if (totalProcessed > 0) {
-                    final long processLatency = computeLatency();
-                    streamsMetrics.processTimeSensor.record(processLatency / (double) totalProcessed,
-                        timerStartedMs);
-                    recordsProcessedBeforeCommit = adjustRecordsProcessedBeforeCommit(recordsProcessedBeforeCommit, totalProcessed,
-                        processLatency, commitTimeMs);
-                }
+            recordsProcessedBeforeCommit = runOnce(recordsProcessedBeforeCommit);
+        }
+        log.info("{} Shutting down at user request", logPrefix);
+    }
+
+    // Visible for testing
+    long runOnce(long recordsProcessedBeforeCommit) {
+        timerStartedMs = time.milliseconds();
+
+        // try to fetch some records if necessary
+        final ConsumerRecords<byte[], byte[]> records = pollRequests();
+
+        if (state == State.PARTITIONS_ASSIGNED) {
+            active.initializeNewTasks();
+            standby.initializeNewTasks();
+
+            final Collection<TopicPartition> restored = storeChangelogReader.restore();
+            final Set<TopicPartition> resumed = active.updateRestored(restored);
+
+            if (!resumed.isEmpty()) {
+                log.trace("{} resuming partitions {}", logPrefix, resumed);
+                consumer.resume(resumed);
             }
 
-            maybeCommit(timerStartedMs);
-            maybeUpdateStandbyTasks(timerStartedMs);
+            if (active.allTasksRunning()) {
+                assignStandbyPartitions();
+                setState(State.RUNNING);
+            }
         }
-        log.info("{} Shutting down at user request", logPrefix);
+
+        if (records != null && !records.isEmpty() && active.hasRunningTasks()) {
+            streamsMetrics.pollTimeSensor.record(computeLatency(), timerStartedMs);
+            addRecordsToTasks(records);
+            final long totalProcessed = processAndPunctuate(recordsProcessedBeforeCommit);
+            if (totalProcessed > 0) {
+                final long processLatency = computeLatency();
+                streamsMetrics.processTimeSensor.record(processLatency / (double) totalProcessed,
+                                                        timerStartedMs);
+                recordsProcessedBeforeCommit = adjustRecordsProcessedBeforeCommit(recordsProcessedBeforeCommit, totalProcessed,
+                                                                                  processLatency, commitTimeMs);
+            }
+        }
+
+        maybeCommit(timerStartedMs);
+        maybeUpdateStandbyTasks(timerStartedMs);
+        return recordsProcessedBeforeCommit;
     }
 
     /**
@@ -645,7 +600,7 @@ public class StreamThread extends Thread {
             int numAddedRecords = 0;
 
             for (final TopicPartition partition : records.partitions()) {
-                final StreamTask task = activeTasksByPartition.get(partition);
+                final StreamTask task = active.runningTaskFor(partition);
                 numAddedRecords += task.addRecords(partition, records.records(partition));
             }
             streamsMetrics.skippedRecordsSensor.record(records.count() - numAddedRecords, timerStartedMs);
@@ -655,38 +610,19 @@ public class StreamThread extends Thread {
     /**
      * Schedule the records processing by selecting which record is processed next. Commits may
      * happen as records are processed.
-     * @param tasks The tasks that have records.
      * @param recordsProcessedBeforeCommit number of records to be processed before commit is called.
      *                                     if UNLIMITED_RECORDS, then commit is never called
      * @return Number of records processed since last commit.
      */
-    private long processAndPunctuate(final Map<TaskId, StreamTask> tasks,
-                                     final long recordsProcessedBeforeCommit) {
+    private long processAndPunctuate(final long recordsProcessedBeforeCommit) {
 
-        long totalProcessedEachRound;
+        int processed;
         long totalProcessedSinceLastMaybeCommit = 0;
         // Round-robin scheduling by taking one record from each task repeatedly
         // until no task has any records left
         do {
-            totalProcessedEachRound = 0;
-            final Iterator<Map.Entry<TaskId, StreamTask>> it = tasks.entrySet().iterator();
-            while (it.hasNext()) {
-                final StreamTask task = it.next().getValue();
-                try {
-                    // we processed one record,
-                    // if more are buffered waiting for the next round
-
-                    // TODO: We should check for stream time punctuation right after each process call
-                    //       of the task instead of only calling it after all records being processed
-                    if (task.process()) {
-                        totalProcessedEachRound++;
-                        totalProcessedSinceLastMaybeCommit++;
-                    }
-                } catch (final ProducerFencedException e) {
-                    closeZombieTask(task);
-                    it.remove();
-                }
-            }
+            processed = active.process();
+            totalProcessedSinceLastMaybeCommit += processed;
 
             if (recordsProcessedBeforeCommit != UNLIMITED_RECORDS &&
                 totalProcessedSinceLastMaybeCommit >= recordsProcessedBeforeCommit) {
@@ -696,55 +632,13 @@ public class StreamThread extends Thread {
                     timerStartedMs);
                 maybeCommit(timerStartedMs);
             }
-        } while (totalProcessedEachRound != 0);
+        } while (processed != 0);
 
         // go over the tasks again to punctuate or commit
-        final RuntimeException e = performOnStreamTasks(new StreamTaskAction() {
-            private String name;
-            @Override
-            public String name() {
-                return name;
-            }
-
-            @Override
-            public void apply(final StreamTask task) {
-                name = "punctuate";
-                maybePunctuate(task);
-                if (task.commitNeeded()) {
-                    name = "commit";
-
-                    long beforeCommitMs = time.milliseconds();
-
-                    commitOne(task);
-
-                    if (log.isDebugEnabled()) {
-                        log.debug("{} Committed active task {} per user request in {}ms",
-                                logPrefix, task.id(), timerStartedMs - beforeCommitMs);
-                    }
-                }
-            }
-        });
-
-        if (e != null) {
-            throw e;
-        }
-
+        active.punctuateAndCommit(streamsMetrics.commitTimeSensor, streamsMetrics.punctuateTimeSensor);
         return totalProcessedSinceLastMaybeCommit;
     }
 
-    private void maybePunctuate(final StreamTask task) {
-        try {
-            // check whether we should punctuate based on the task's partition group timestamp;
-            // which are essentially based on record timestamp.
-            if (task.maybePunctuate()) {
-                streamsMetrics.punctuateTimeSensor.record(computeLatency(), timerStartedMs);
-            }
-        } catch (final KafkaException e) {
-            log.error("{} Failed to punctuate active task {} due to the following error:", logPrefix, task.id(), e);
-            throw e;
-        }
-    }
-
     /**
      * Adjust the number of records that should be processed by scheduler. This avoids
      * scenarios where the processing time is higher than the commit time.
@@ -783,14 +677,14 @@ public class StreamThread extends Thread {
         if (commitTimeMs >= 0 && lastCommitMs + commitTimeMs < now) {
             if (log.isTraceEnabled()) {
                 log.trace("{} Committing all active tasks {} and standby tasks {} since {}ms has elapsed (commit interval is {}ms)",
-                        logPrefix, activeTasks.keySet(), standbyTasks.keySet(), now - lastCommitMs, commitTimeMs);
+                          logPrefix, active.runningTaskIds(), standby.runningTaskIds(), now - lastCommitMs, commitTimeMs);
             }
 
             commitAll();
 
             if (log.isDebugEnabled()) {
                 log.info("{} Committed all active tasks {} and standby tasks {} in {}ms",
-                        logPrefix, activeTasks.keySet(), standbyTasks.keySet(), timerStartedMs - now);
+                         logPrefix, active.runningTaskIds(), standby.runningTaskIds(), timerStartedMs - now);
             }
 
             lastCommitMs = now;
@@ -803,46 +697,33 @@ public class StreamThread extends Thread {
      * Commit the states of all its tasks
      */
     private void commitAll() {
-        final RuntimeException e = performOnStreamTasks(new StreamTaskAction() {
-            @Override
-            public String name() {
-                return "commit";
-            }
-
-            @Override
-            public void apply(final StreamTask task) {
-                commitOne(task);
-            }
-        });
-        if (e != null) {
-            throw e;
-        }
-
-        for (final StandbyTask task : standbyTasks.values()) {
-            commitOne(task);
-        }
+        active.commit();
+        standby.commit();
     }
 
-    /**
-     * Commit the state of a task
-     */
-    private void commitOne(final AbstractTask task) {
-        try {
-            task.commit();
-        } catch (final CommitFailedException e) {
-            // commit failed. This is already logged inside the task as WARN and we can just log it again here.
-            log.warn("{} Failed to commit {} {} state due to CommitFailedException; this task may be no longer owned by the thread", logPrefix, task.getClass().getSimpleName(), task.id());
-        } catch (final KafkaException e) {
-            // commit failed due to an unexpected exception. Log it and rethrow the exception.
-            log.error("{} Failed to commit {} {} state due to the following error:", logPrefix, task.getClass().getSimpleName(), task.id(), e);
-            throw e;
+    private void assignStandbyPartitions() {
+        final Collection<StandbyTask> running = standby.runningTasks();
+        final Map<TopicPartition, Long> checkpointedOffsets = new HashMap<>();
+        for (StandbyTask standbyTask : running) {
+            checkpointedOffsets.putAll(standbyTask.checkpointedOffsets());
         }
 
-        streamsMetrics.commitTimeSensor.record(computeLatency(), timerStartedMs);
+        final List<TopicPartition> assignment = new ArrayList<>(checkpointedOffsets.keySet());
+        restoreConsumer.assign(assignment);
+        for (final Map.Entry<TopicPartition, Long> entry : checkpointedOffsets.entrySet()) {
+            final TopicPartition partition = entry.getKey();
+            final long offset = entry.getValue();
+            if (offset >= 0) {
+                restoreConsumer.seek(partition, offset);
+            } else {
+                restoreConsumer.seekToBeginning(singleton(partition));
+            }
+        }
+        log.trace("{} assigned {} partitions to restore consumer for standby tasks {}", logPrefix, assignment, standby.runningTaskIds());
     }
 
     private void maybeUpdateStandbyTasks(final long now) {
-        if (!standbyTasks.isEmpty()) {
+        if (state == State.RUNNING && standby.hasRunningTasks()) {
             if (processStandbyRecords) {
                 if (!standbyRecords.isEmpty()) {
                     final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> remainingStandbyRecords = new HashMap<>();
@@ -851,7 +732,7 @@ public class StreamThread extends Thread {
                         final TopicPartition partition = entry.getKey();
                         List<ConsumerRecord<byte[], byte[]>> remaining = entry.getValue();
                         if (remaining != null) {
-                            final StandbyTask task = standbyTasksByPartition.get(partition);
+                            final StandbyTask task = standby.runningTaskFor(partition);
                             remaining = task.update(partition, remaining);
                             if (remaining != null) {
                                 remainingStandbyRecords.put(partition, remaining);
@@ -863,7 +744,7 @@ public class StreamThread extends Thread {
 
                     standbyRecords = remainingStandbyRecords;
 
-                    log.debug("{} Updated standby tasks {} in {}ms", logPrefix, standbyTasks.keySet(), time.milliseconds() - now);
+                    log.debug("{} Updated standby tasks {} in {}ms", logPrefix, standby.runningTaskIds(), time.milliseconds() - now);
                 }
                 processStandbyRecords = false;
             }
@@ -872,7 +753,7 @@ public class StreamThread extends Thread {
 
             if (!records.isEmpty()) {
                 for (final TopicPartition partition : records.partitions()) {
-                    final StandbyTask task = standbyTasksByPartition.get(partition);
+                    final StandbyTask task = standby.runningTaskFor(partition);
 
                     if (task == null) {
                         throw new StreamsException(logPrefix + " Missing standby task for partition " + partition);
@@ -920,14 +801,14 @@ public class StreamThread extends Thread {
     }
 
     public Map<TaskId, StreamTask> tasks() {
-        return Collections.unmodifiableMap(activeTasks);
+        return active.runningTaskMap();
     }
 
     /**
      * Returns ids of tasks that were being executed before the rebalance.
      */
     public Set<TaskId> prevActiveTasks() {
-        return Collections.unmodifiableSet(prevActiveTasks);
+        return Collections.unmodifiableSet(active.previousTaskIds());
     }
 
     /**
@@ -975,11 +856,13 @@ public class StreamThread extends Thread {
         return state;
     }
 
+
+
     /**
      * Sets the state
      * @param newState New state
      */
-    void setState(final State newState) {
+    boolean setState(final State newState) {
         State oldState;
         synchronized (stateLock) {
             oldState = state;
@@ -996,7 +879,7 @@ public class StreamThread extends Thread {
             // transition, hence the check newState != DEAD.
             if (newState != State.DEAD &&
                     (state == State.PENDING_SHUTDOWN || state == State.DEAD)) {
-                return;
+                return false;
             }
             if (!state.isValidTransition(newState)) {
                 log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
@@ -1010,6 +893,7 @@ public class StreamThread extends Thread {
         if (stateListener != null) {
             stateListener.onChange(this, state, oldState);
         }
+        return true;
     }
 
     /**
@@ -1033,23 +917,11 @@ public class StreamThread extends Thread {
             .append(indent).append("\tStreamsThread clientId: ").append(clientId).append("\n")
             .append(indent).append("\tStreamsThread threadId: ").append(getName()).append("\n");
 
-        // iterate and print active tasks
-        if (activeTasks != null) {
-            sb.append(indent).append("\tActive tasks:\n");
-            for (final Map.Entry<TaskId, StreamTask> entry : activeTasks.entrySet()) {
-                final StreamTask task = entry.getValue();
-                sb.append(indent).append(task.toString(indent + "\t\t"));
-            }
-        }
-
-        // iterate and print standby tasks
-        if (standbyTasks != null) {
-            sb.append(indent).append("\tStandby tasks:\n");
-            for (final StandbyTask task : standbyTasks.values()) {
-                sb.append(indent).append(task.toString(indent + "\t\t"));
-            }
-            sb.append("\n");
-        }
+        sb.append(indent).append("\tActive tasks:\n");
+        sb.append(active.toString(indent + "\t\t"));
+        sb.append(indent).append("\tStandby tasks:\n");
+        sb.append(standby.toString(indent + "\t\t"));
+        sb.append("\n");
 
         return sb.toString();
     }
@@ -1062,7 +934,8 @@ public class StreamThread extends Thread {
         this.partitionAssignor = partitionAssignor;
     }
 
-    private void shutdown(final boolean cleanRun) {
+    // Visible for testing
+    void shutdown(final boolean cleanRun) {
         log.info("{} Shutting down", logPrefix);
         setState(State.PENDING_SHUTDOWN);
         shutdownTasksAndState(cleanRun);
@@ -1091,8 +964,8 @@ public class StreamThread extends Thread {
             log.error("{} Failed to close KafkaStreamClient due to the following error:", logPrefix, e);
         }
 
-        removeStreamTasks();
-        removeStandbyTasks();
+        active.clear();
+        standby.clear();
 
         // clean up global tasks
 
@@ -1104,18 +977,18 @@ public class StreamThread extends Thread {
     @SuppressWarnings("ThrowableNotThrown")
     private void shutdownTasksAndState(final boolean cleanRun) {
         log.debug("{} Shutting down all active tasks {}, standby tasks {}, suspended tasks {}, and suspended standby tasks {}",
-            logPrefix, activeTasks.keySet(), standbyTasks.keySet(),
-            suspendedTasks.keySet(), suspendedStandbyTasks.keySet());
+                  logPrefix, active.runningTaskIds(), standby.runningTaskIds(),
+                  active.previousTaskIds(), standby.previousTaskIds());
 
         for (final AbstractTask task : allTasks()) {
             try {
                 task.close(cleanRun);
             } catch (final RuntimeException e) {
                 log.error("{} Failed while closing {} {} due to the following error:",
-                    logPrefix,
-                    task.getClass().getSimpleName(),
-                    task.id(),
-                    e);
+                          logPrefix,
+                          task.getClass().getSimpleName(),
+                          task.id(),
+                          e);
             }
         }
 
@@ -1129,58 +1002,15 @@ public class StreamThread extends Thread {
      */
     private void suspendTasksAndState()  {
         log.debug("{} Suspending all active tasks {} and standby tasks {}",
-            logPrefix, activeTasks.keySet(), standbyTasks.keySet());
+                  logPrefix, active.runningTaskIds(), standby.runningTaskIds());
 
         final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
 
-        firstException.compareAndSet(null, performOnStreamTasks(new StreamTaskAction() {
-            @Override
-            public String name() {
-                return "suspend";
-            }
-
-            @Override
-            public void apply(final StreamTask task) {
-                try {
-                    task.suspend();
-                } catch (final CommitFailedException e) {
-                    // commit failed during suspension. Just log it.
-                    log.warn("{} Failed to commit task {} state when suspending due to CommitFailedException", logPrefix, task.id);
-                } catch (final Exception e) {
-                    log.error("{} Suspending task {} failed due to the following error:", logPrefix, task.id, e);
-                    try {
-                        task.close(false);
-                    } catch (final Exception f) {
-                        log.error("{} After suspending failed, closing the same task {} failed again due to the following error:", logPrefix, task.id, f);
-                    }
-                    throw e;
-                }
-            }
-        }));
-
-        for (final StandbyTask task : standbyTasks.values()) {
-            try {
-                try {
-                    task.suspend();
-                } catch (final Exception e) {
-                    log.error("{} Suspending standby task {} failed due to the following error:", logPrefix, task.id, e);
-                    try {
-                        task.close(false);
-                    } catch (final Exception f) {
-                        log.error("{} After suspending failed, closing the same standby task {} failed again due to the following error:", logPrefix, task.id, f);
-                    }
-                    throw e;
-                }
-            } catch (final RuntimeException e) {
-                firstException.compareAndSet(null, e);
-            }
-        }
-
+        firstException.compareAndSet(null, active.suspend());
+        firstException.compareAndSet(null, standby.suspend());
         // remove the changelog partitions from restore consumer
         firstException.compareAndSet(null, unAssignChangeLogPartitions());
 
-        updateSuspendedTasks();
-
         if (firstException.get() != null) {
             throw new StreamsException(logPrefix + " failed to suspend stream tasks", firstException.get());
         }
@@ -1198,56 +1028,23 @@ public class StreamThread extends Thread {
     }
 
     private List<AbstractTask> allTasks() {
-        final List<AbstractTask> tasks = activeAndStandbytasks();
-        tasks.addAll(suspendedAndSuspendedStandbytasks());
-        return tasks;
-    }
-
-    private List<AbstractTask> activeAndStandbytasks() {
-        final List<AbstractTask> tasks = new ArrayList<AbstractTask>(activeTasks.values());
-        tasks.addAll(standbyTasks.values());
-        return tasks;
-    }
-
-    private List<AbstractTask> suspendedAndSuspendedStandbytasks() {
-        final List<AbstractTask> tasks = new ArrayList<AbstractTask>(suspendedTasks.values());
-        tasks.addAll(suspendedStandbyTasks.values());
+        final List<AbstractTask> tasks = active.allInitializedTasks();
+        tasks.addAll(standby.allInitializedTasks());
         return tasks;
     }
 
-    private StreamTask findMatchingSuspendedTask(final TaskId taskId, final Set<TopicPartition> partitions) {
-        if (suspendedTasks.containsKey(taskId)) {
-            final StreamTask task = suspendedTasks.get(taskId);
-            if (task.partitions().equals(partitions)) {
-                return task;
-            }
-        }
-        return null;
-    }
-
-    private StandbyTask findMatchingSuspendedStandbyTask(final TaskId taskId, final Set<TopicPartition> partitions) {
-        if (suspendedStandbyTasks.containsKey(taskId)) {
-            final StandbyTask task = suspendedStandbyTasks.get(taskId);
-            if (task.partitions().equals(partitions)) {
-                return task;
-            }
-        }
-        return null;
-    }
-
     private void closeNonAssignedSuspendedTasks() {
         final Map<TaskId, Set<TopicPartition>> newTaskAssignment = partitionAssignor.activeTasks();
-        final Iterator<Map.Entry<TaskId, StreamTask>> suspendedTaskIterator = suspendedTasks.entrySet().iterator();
+        final Iterator<StreamTask> suspendedTaskIterator = active.suspendedTasks().iterator();
         while (suspendedTaskIterator.hasNext()) {
-            final Map.Entry<TaskId, StreamTask> next = suspendedTaskIterator.next();
-            final StreamTask task = next.getValue();
-            final Set<TopicPartition> assignedPartitionsForTask = newTaskAssignment.get(next.getKey());
+            final StreamTask task = suspendedTaskIterator.next();
+            final Set<TopicPartition> assignedPartitionsForTask = newTaskAssignment.get(task.id);
             if (!task.partitions().equals(assignedPartitionsForTask)) {
                 log.debug("{} Closing suspended and not re-assigned task {}", logPrefix, task.id());
                 try {
                     task.closeSuspended(true, null);
                 } catch (final Exception e) {
-                    log.error("{} Failed to close suspended task {} due to the following error:", logPrefix, next.getKey(), e);
+                    log.error("{} Failed to close suspended task {} due to the following error:", logPrefix, task.id, e);
                 } finally {
                     suspendedTaskIterator.remove();
                 }
@@ -1256,12 +1053,11 @@ public class StreamThread extends Thread {
     }
 
     private void closeNonAssignedSuspendedStandbyTasks() {
-        final Set<TaskId> currentSuspendedTaskIds = partitionAssignor.standbyTasks().keySet();
-        final Iterator<Map.Entry<TaskId, StandbyTask>> standByTaskIterator = suspendedStandbyTasks.entrySet().iterator();
+        final Set<TaskId> newStandbyTaskIds = partitionAssignor.standbyTasks().keySet();
+        final Iterator<StandbyTask> standByTaskIterator = standby.suspendedTasks().iterator();
         while (standByTaskIterator.hasNext()) {
-            final Map.Entry<TaskId, StandbyTask> suspendedTask = standByTaskIterator.next();
-            if (!currentSuspendedTaskIds.contains(suspendedTask.getKey())) {
-                final StandbyTask task = suspendedTask.getValue();
+            final StandbyTask task = standByTaskIterator.next();
+            if (!newStandbyTaskIds.contains(task.id)) {
                 log.debug("{} Closing suspended and not re-assigned standby task {}", logPrefix, task.id());
                 try {
                     task.close(true);
@@ -1280,18 +1076,18 @@ public class StreamThread extends Thread {
 
         try {
             return new StreamTask(
-                id,
-                applicationId,
-                partitions,
-                builder.build(id.topicGroupId),
-                consumer,
-                storeChangelogReader,
-                config,
-                streamsMetrics,
-                stateDirectory,
-                cache,
-                time,
-                createProducer(id));
+                    id,
+                    applicationId,
+                    partitions,
+                    builder.build(id.topicGroupId),
+                    consumer,
+                    storeChangelogReader,
+                    config,
+                    streamsMetrics,
+                    stateDirectory,
+                    cache,
+                    time,
+                    createProducer(id));
         } finally {
             log.trace("{} Created active task {} with assigned partitions {}", logPrefix, id, partitions);
         }
@@ -1317,7 +1113,7 @@ public class StreamThread extends Thread {
         return producer;
     }
 
-    private void addStreamTasks(final Collection<TopicPartition> assignment, final long start) {
+    private void addStreamTasks(final Collection<TopicPartition> assignment) {
         if (partitionAssignor == null) {
             throw new IllegalStateException(logPrefix + " Partition assignor has not been initialized while adding stream tasks: this should not happen.");
         }
@@ -1332,17 +1128,7 @@ public class StreamThread extends Thread {
 
             if (assignment.containsAll(partitions)) {
                 try {
-                    final StreamTask task = findMatchingSuspendedTask(taskId, partitions);
-                    if (task != null) {
-                        suspendedTasks.remove(taskId);
-                        task.resume();
-
-                        activeTasks.put(taskId, task);
-
-                        for (final TopicPartition partition : partitions) {
-                            activeTasksByPartition.put(partition, task);
-                        }
-                    } else {
+                    if (!active.maybeResumeSuspendedTask(taskId, partitions)) {
                         newTasks.put(taskId, partitions);
                     }
                 } catch (final StreamsException e) {
@@ -1358,7 +1144,7 @@ public class StreamThread extends Thread {
         // -> other thread will call removeSuspendedTasks(); eventually
         log.trace("{} New active tasks to be created: {}", logPrefix, newTasks);
 
-        taskCreator.retryWithBackoff(newTasks, start);
+        taskCreator.createTasks(newTasks);
     }
 
     // visible for testing
@@ -1380,13 +1166,11 @@ public class StreamThread extends Thread {
         }
     }
 
-    private void addStandbyTasks(final long start) {
+    private void addStandbyTasks() {
         if (partitionAssignor == null) {
             throw new IllegalStateException(logPrefix + " Partition assignor has not been initialized while adding standby tasks: this should not happen.");
         }
 
-        final Map<TopicPartition, Long> checkpointedOffsets = new HashMap<>();
-
         final Map<TaskId, Set<TopicPartition>> newStandbyTasks = new HashMap<>();
 
         log.debug("{} Adding assigned standby tasks {}", logPrefix, partitionAssignor.standbyTasks());
@@ -1394,116 +1178,18 @@ public class StreamThread extends Thread {
         for (final Map.Entry<TaskId, Set<TopicPartition>> entry : partitionAssignor.standbyTasks().entrySet()) {
             final TaskId taskId = entry.getKey();
             final Set<TopicPartition> partitions = entry.getValue();
-            final StandbyTask task = findMatchingSuspendedStandbyTask(taskId, partitions);
-
-            if (task != null) {
-                suspendedStandbyTasks.remove(taskId);
-                task.resume();
-            } else {
+            if (!standby.maybeResumeSuspendedTask(taskId, partitions)) {
                 newStandbyTasks.put(taskId, partitions);
             }
 
-            updateStandByTaskMaps(checkpointedOffsets, taskId, partitions, task);
         }
 
         // create all newly assigned standby tasks (guard against race condition with other thread via backoff and retry)
         // -> other thread will call removeSuspendedStandbyTasks(); eventually
         log.trace("{} New standby tasks to be created: {}", logPrefix, newStandbyTasks);
 
-        new StandbyTaskCreator(checkpointedOffsets).retryWithBackoff(newStandbyTasks, start);
-
-        restoreConsumer.assign(new ArrayList<>(checkpointedOffsets.keySet()));
-
-        for (final Map.Entry<TopicPartition, Long> entry : checkpointedOffsets.entrySet()) {
-            final TopicPartition partition = entry.getKey();
-            final long offset = entry.getValue();
-            if (offset >= 0) {
-                restoreConsumer.seek(partition, offset);
-            } else {
-                restoreConsumer.seekToBeginning(singleton(partition));
-            }
-        }
-    }
-
-    private void updateStandByTaskMaps(final Map<TopicPartition, Long> checkpointedOffsets,
-                                       final TaskId taskId,
-                                       final Set<TopicPartition> partitions,
-                                       final StandbyTask task) {
-        if (task != null) {
-            standbyTasks.put(taskId, task);
-            for (final TopicPartition partition : partitions) {
-                standbyTasksByPartition.put(partition, task);
-            }
-            // collect checked pointed offsets to position the restore consumer
-            // this include all partitions from which we restore states
-            for (final TopicPartition partition : task.checkpointedOffsets().keySet()) {
-                standbyTasksByPartition.put(partition, task);
-            }
-            checkpointedOffsets.putAll(task.checkpointedOffsets());
-        }
-    }
-
-    private void updateSuspendedTasks() {
-        suspendedTasks.clear();
-        suspendedTasks.putAll(activeTasks);
-        suspendedStandbyTasks.putAll(standbyTasks);
-    }
-
-    private void removeStreamTasks() {
-        log.debug("{} Removing all active tasks {}", logPrefix, activeTasks.keySet());
-
-        try {
-            prevActiveTasks.clear();
-            prevActiveTasks.addAll(activeTasks.keySet());
-
-            activeTasks.clear();
-            activeTasksByPartition.clear();
-        } catch (final Exception e) {
-            log.error("{} Failed to remove stream tasks due to the following error:", logPrefix, e);
-        }
+        new StandbyTaskCreator().createTasks(newStandbyTasks);
     }
 
-    private void removeStandbyTasks() {
-        log.debug("{} Removing all standby tasks {}", logPrefix, standbyTasks.keySet());
 
-        standbyTasks.clear();
-        standbyTasksByPartition.clear();
-        standbyRecords.clear();
-    }
-
-    private void closeZombieTask(final StreamTask task) {
-        log.warn("{} Producer of task {} fenced; closing zombie task", logPrefix, task.id);
-        try {
-            task.close(false);
-        } catch (final Exception e) {
-            log.warn("{} Failed to close zombie task due to {}, ignore and proceed", logPrefix, e);
-        }
-        activeTasks.remove(task.id);
-    }
-
-
-    private RuntimeException performOnStreamTasks(final StreamTaskAction action) {
-        RuntimeException firstException = null;
-        final Iterator<Map.Entry<TaskId, StreamTask>> it = activeTasks.entrySet().iterator();
-        while (it.hasNext()) {
-            final StreamTask task = it.next().getValue();
-            try {
-                action.apply(task);
-            } catch (final ProducerFencedException e) {
-                closeZombieTask(task);
-                it.remove();
-            } catch (final RuntimeException t) {
-                log.error("{} Failed to {} stream task {} due to the following error:",
-                    logPrefix,
-                    action.name(),
-                    task.id(),
-                    t);
-                if (firstException == null) {
-                    firstException = t;
-                }
-            }
-        }
-
-        return firstException;
-    }
 }
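
Two state-machine details in the diff above are easy to miss: ASSIGNING_PARTITIONS is replaced by PARTITIONS_ASSIGNED, which can now transition straight back to PARTITIONS_REVOKED (a rebalance can strike while restoration is still in progress in the poll loop), and setState now returns a boolean so callers such as onPartitionsAssigned can bail out when the thread is already shutting down. A stand-alone toy replica of the transition check, mirroring the ordinals used in the patch (CREATED=0, RUNNING=1, PARTITIONS_REVOKED=2, PARTITIONS_ASSIGNED=3, PENDING_SHUTDOWN=4, DEAD=5):

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    // Toy version of StreamThread.State for illustration; the constants and
    // their transition lists are copied from the diff above.
    enum ToyState {
        CREATED(1, 4), RUNNING(2, 4), PARTITIONS_REVOKED(2, 3, 4),
        PARTITIONS_ASSIGNED(1, 2, 4), PENDING_SHUTDOWN(5), DEAD;

        private final Set<Integer> validTransitions = new HashSet<>();

        ToyState(final Integer... validTransitions) {
            this.validTransitions.addAll(Arrays.asList(validTransitions));
        }

        boolean isValidTransition(final ToyState newState) {
            return validTransitions.contains(newState.ordinal());
        }

        public static void main(final String[] args) {
            // restoration finished: PARTITIONS_ASSIGNED -> RUNNING
            System.out.println(PARTITIONS_ASSIGNED.isValidTransition(RUNNING));            // true
            // rebalance struck mid-restoration: back to PARTITIONS_REVOKED
            System.out.println(PARTITIONS_ASSIGNED.isValidTransition(PARTITIONS_REVOKED)); // true
            // DEAD is terminal
            System.out.println(DEAD.isValidTransition(RUNNING));                           // false
        }
    }
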

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index 985dc93..d2eaca7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -437,7 +437,7 @@ public class KafkaStreamsTest {
         CLUSTER.createTopic(topic);
         final KStreamBuilder builder = new KStreamBuilder();
 
-        builder.stream(Serdes.String(), Serdes.String(), topic);
+        builder.table(Serdes.String(), Serdes.String(), topic, topic);
 
         final KafkaStreams streams = new KafkaStreams(builder, props);
         final CountDownLatch latch = new CountDownLatch(1);
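
The one-line topology change above is presumably deliberate rather than cosmetic: a KTable source materializes a state store whose changelog must be restored, so this test now exercises the restoration path the patch moves into the poll loop, whereas a plain KStream source has no state to restore. The two source types side by side, paraphrasing the change:

    // no local state, nothing to restore on rebalance
    builder.stream(Serdes.String(), Serdes.String(), topic);

    // materializes a store (here named after the topic) that is restored
    // from its changelog via the new poll-loop path
    builder.table(Serdes.String(), Serdes.String(), topic, topic);
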

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
index 3cff7f7..7868981 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
@@ -132,7 +132,7 @@ public class ResetIntegrationTest {
     public void testReprocessingFromScratchAfterResetWithIntermediateUserTopic() throws Exception {
         CLUSTER.createTopic(INTERMEDIATE_USER_TOPIC);
 
-        final Properties streamsConfiguration = prepareTest();
+        final Properties streamsConfiguration = prepareTest(4);
         final Properties resultTopicConsumerConfig = TestUtils.consumerConfig(
             CLUSTER.bootstrapServers(),
             APP_ID + "-standard-consumer-" + OUTPUT_TOPIC,
@@ -198,7 +198,7 @@ public class ResetIntegrationTest {
 
     @Test
     public void testReprocessingFromScratchAfterResetWithoutIntermediateUserTopic() throws Exception {
-        final Properties streamsConfiguration = prepareTest();
+        final Properties streamsConfiguration = prepareTest(1);
         final Properties resultTopicConsumerConfig = TestUtils.consumerConfig(
                 CLUSTER.bootstrapServers(),
                 APP_ID + "-standard-consumer-" + OUTPUT_TOPIC,
@@ -241,14 +241,14 @@ public class ResetIntegrationTest {
         cleanGlobal(null);
     }
 
-    private Properties prepareTest() throws Exception {
+    private Properties prepareTest(final int threads) throws Exception {
         final Properties streamsConfiguration = new Properties();
         streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID + testNo);
         streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
         streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass());
         streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass());
-        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 4);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, threads);
         streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0);
         streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100);
         streamsConfiguration.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100);

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
index 123cbf0..5e71f31 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
@@ -24,60 +24,104 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthorizationException;
 import org.apache.kafka.common.errors.WakeupException;
-import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.test.TestUtils;
+import org.easymock.EasyMock;
+import org.junit.Before;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Properties;
 
+import static org.junit.Assert.fail;
+
 public class AbstractTaskTest {
 
+    private final TaskId id = new TaskId(0, 0);
+    private final StateDirectory stateDirectory = EasyMock.createMock(StateDirectory.class);
+
+    @Before
+    public void before() {
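+        // each test resolves the task's state directory to a throwaway temp dir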
+        EasyMock.expect(stateDirectory.directoryForTask(id)).andReturn(TestUtils.tempDirectory());
+    }
+
     @Test(expected = ProcessorStateException.class)
     public void shouldThrowProcessorStateExceptionOnInitializeOffsetsWhenAuthorizationException() throws Exception {
         final Consumer consumer = mockConsumer(new AuthorizationException("blah"));
-        final AbstractTask task = createTask(consumer);
+        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
         task.updateOffsetLimits();
     }
 
     @Test(expected = ProcessorStateException.class)
     public void shouldThrowProcessorStateExceptionOnInitializeOffsetsWhenKafkaException() throws Exception {
         final Consumer consumer = mockConsumer(new KafkaException("blah"));
-        final AbstractTask task = createTask(consumer);
+        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
         task.updateOffsetLimits();
     }
 
     @Test(expected = WakeupException.class)
     public void shouldThrowWakeupExceptionOnInitializeOffsetsWhenWakeupException() throws Exception {
         final Consumer consumer = mockConsumer(new WakeupException());
-        final AbstractTask task = createTask(consumer);
+        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
         task.updateOffsetLimits();
     }
 
-    private AbstractTask createTask(final Consumer consumer) {
-        final MockTime time = new MockTime();
+    @Test
+    public void shouldThrowLockExceptionIfFailedToLockStateDirectoryWhenTopologyHasStores() throws IOException {
+        final Consumer consumer = EasyMock.createNiceMock(Consumer.class);
+        final StateStore store = EasyMock.createNiceMock(StateStore.class);
+        EasyMock.expect(stateDirectory.lock(id, 5)).andReturn(false);
+        EasyMock.replay(stateDirectory);
+
+        final AbstractTask task = createTask(consumer, Collections.singletonList(store));
+
+        try {
+            task.initializeStateStores();
+            fail("Should have thrown LockException");
+        } catch (final LockException e) {
+            // ok
+        }
+    }
+
+    @Test
+    public void shouldNotAttemptToLockIfNoStores() throws IOException {
+        final Consumer consumer = EasyMock.createNiceMock(Consumer.class);
+        EasyMock.replay(stateDirectory);
+
+        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
+
+        task.initializeStateStores();
+
+        // should fail if lock is called
+        EasyMock.verify(stateDirectory);
+    }
+
+    private AbstractTask createTask(final Consumer consumer, final List<StateStore> stateStores) {
         final Properties properties = new Properties();
         properties.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-id");
         properties.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummyhost:9092");
         final StreamsConfig config = new StreamsConfig(properties);
-        return new AbstractTask(new TaskId(0, 0),
+        return new AbstractTask(id,
                                 "app",
                                 Collections.singletonList(new TopicPartition("t", 0)),
                                 new ProcessorTopology(Collections.<ProcessorNode>emptyList(),
                                                       Collections.<String, SourceNode>emptyMap(),
                                                       Collections.<String, SinkNode>emptyMap(),
-                                                      Collections.<StateStore>emptyList(),
+                                                      stateStores,
                                                       Collections.<String, String>emptyMap(),
                                                       Collections.<StateStore>emptyList()),
                                 consumer,
                                 new StoreChangelogReader(consumer, Time.SYSTEM, 5000),
                                 false,
-                                new StateDirectory("app", TestUtils.tempDirectory().getPath(), time),
+                                stateDirectory,
                                 config) {
             @Override
             public void resume() {}
@@ -90,6 +134,26 @@ public class AbstractTaskTest {
 
             @Override
             public void close(final boolean clean) {}
+
+            @Override
+            public boolean initialize() {
+                return false;
+            }
+
+            @Override
+            boolean process() {
+                return false;
+            }
+
+            @Override
+            boolean maybePunctuate() {
+                return false;
+            }
+
+            @Override
+            boolean commitNeeded() {
+                return false;
+            }
         };
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
new file mode 100644
index 0000000..52d0ea8
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.CommitFailedException;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.ProducerFencedException;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.processor.TaskId;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Set;
+
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class AssignedTasksTest {
+
+    private final AssignedTasks<AbstractTask> assignedTasks = new AssignedTasks<>("log", "task", Time.SYSTEM);
+    private final AbstractTask t1 = EasyMock.createMock(AbstractTask.class);
+    private final AbstractTask t2 = EasyMock.createMock(AbstractTask.class);
+    private final TopicPartition tp1 = new TopicPartition("t1", 0);
+    private final TopicPartition tp2 = new TopicPartition("t2", 0);
+    private final TopicPartition changeLog1 = new TopicPartition("cl1", 0);
+    private final TopicPartition changeLog2 = new TopicPartition("cl2", 0);
+    private final TaskId taskId1 = new TaskId(0, 0);
+    private final TaskId taskId2 = new TaskId(1, 0);
+    private final Metrics metrics = new Metrics();
+    private final Sensor punctuateSensor = metrics.sensor("punctuate");
+    private final Sensor commitSensor = metrics.sensor("commit");
+
+    @Before
+    public void before() {
+        EasyMock.expect(t1.id()).andReturn(taskId1).anyTimes();
+        EasyMock.expect(t2.id()).andReturn(taskId2).anyTimes();
+    }
+
+    @Test
+    public void shouldGetPartitionsFromNewTasksThatHaveStateStores() {
+        EasyMock.expect(t1.hasStateStores()).andReturn(true);
+        EasyMock.expect(t2.hasStateStores()).andReturn(true);
+        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1));
+        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
+        EasyMock.replay(t1, t2);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.addNewTask(t2);
+
+        final Set<TopicPartition> partitions = assignedTasks.uninitializedPartitions();
+        assertThat(partitions, equalTo(Utils.mkSet(tp1, tp2)));
+        EasyMock.verify(t1, t2);
+    }
+
+    @Test
+    public void shouldNotGetPartitionsFromNewTasksWithoutStateStores() {
+        EasyMock.expect(t1.hasStateStores()).andReturn(false);
+        EasyMock.expect(t2.hasStateStores()).andReturn(false);
+        EasyMock.replay(t1, t2);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.addNewTask(t2);
+
+        final Set<TopicPartition> partitions = assignedTasks.uninitializedPartitions();
+        assertTrue(partitions.isEmpty());
+        EasyMock.verify(t1, t2);
+    }
+
+    @Test
+    public void shouldInitializeNewTasks() {
+        EasyMock.expect(t1.initialize()).andReturn(false);
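+        // initialize() returning false indicates the task still has state to restore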
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldMoveInitializedTasksNeedingRestoreToRestoring() {
+        EasyMock.expect(t1.initialize()).andReturn(false);
+        EasyMock.expect(t2.initialize()).andReturn(true);
+        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
+        EasyMock.expect(t2.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());
+
+        EasyMock.replay(t1, t2);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.addNewTask(t2);
+
+        assignedTasks.initializeNewTasks();
+
+        final Collection<AbstractTask> restoring = assignedTasks.restoringTasks();
+        assertThat(restoring.size(), equalTo(1));
+        assertSame(t1, restoring.iterator().next());
+    }
+
+    @Test
+    public void shouldMoveInitializedTasksThatDontNeedRestoringToRunning() {
+        EasyMock.expect(t2.initialize()).andReturn(true);
+        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
+        EasyMock.expect(t2.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());
+
+        EasyMock.replay(t2);
+
+        assignedTasks.addNewTask(t2);
+        assignedTasks.initializeNewTasks();
+
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId2)));
+    }
+
+    @Test
+    public void shouldTransitionFullyRestoredTasksToRunning() {
+        final Set<TopicPartition> task1Partitions = Utils.mkSet(tp1);
+        EasyMock.expect(t1.initialize()).andReturn(false);
+        EasyMock.expect(t1.partitions()).andReturn(task1Partitions).anyTimes();
+        EasyMock.expect(t1.changelogPartitions()).andReturn(Utils.mkSet(changeLog1, changeLog2)).anyTimes();
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+
+        assignedTasks.initializeNewTasks();
+
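+        // restoring only one of the two changelog partitions must not complete the task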
+        assertTrue(assignedTasks.updateRestored(Utils.mkSet(changeLog1)).isEmpty());
+        final Set<TopicPartition> partitions = assignedTasks.updateRestored(Utils.mkSet(changeLog2));
+        assertThat(partitions, equalTo(task1Partitions));
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+    }
+
+    @Test
+    public void shouldSuspendRunningTasks() {
+        mockRunningTaskSuspension();
+        EasyMock.replay(t1);
+
+        suspendTask();
+
+        assertThat(assignedTasks.previousTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseUnInitializedTasksOnSuspend() {
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.suspend();
+
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldNotSuspendSuspendedTasks() {
+        mockRunningTaskSuspension();
+        EasyMock.replay(t1);
+
+        suspendTask();
+        assignedTasks.suspend();
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseTaskOnSuspendWhenRuntimeException() {
+        mockInitializedTask();
+        t1.suspend();
+        EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!"));
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assertThat(suspendTask(), not(nullValue()));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseTaskOnSuspendWhenProducerFencedException() {
+        mockInitializedTask();
+        t1.suspend();
+        EasyMock.expectLastCall().andThrow(new ProducerFencedException("KABOOM!"));
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assertThat(suspendTask(), nullValue());
+        assertTrue(assignedTasks.previousTaskIds().isEmpty());
+        EasyMock.verify(t1);
+    }
+
+    private void mockInitializedTask() {
+        EasyMock.expect(t1.initialize()).andReturn(true);
+        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1));
+        EasyMock.expect(t1.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());
+    }
+
+    @Test
+    public void shouldResumeMatchingSuspendedTasks() {
+        mockRunningTaskSuspension();
+        t1.resume();
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        suspendTask();
+
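+        // a suspended task with a matching id and partitions should be resumed, not recreated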
+        assertTrue(assignedTasks.maybeResumeSuspendedTask(taskId1, Collections.singleton(tp1)));
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCommitRunningTasks() {
+        mockInitializedTask();
+        t1.commit();
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.commit();
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseTaskOnCommitIfProduceFencedException() {
+        mockInitializedTask();
+        t1.commit();
+        EasyMock.expectLastCall().andThrow(new ProducerFencedException(""));
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.commit();
+        assertTrue(assignedTasks.runningTasks().isEmpty());
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldNotThrowCommitFailedExceptionOnCommit() {
+        mockInitializedTask();
+        t1.commit();
+        EasyMock.expectLastCall().andThrow(new CommitFailedException());
+        EasyMock.replay(t1);
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.commit();
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldThrowExceptionOnCommitWhenNotCommitFailedOrProducerFenced() {
+        mockInitializedTask();
+        t1.commit();
+        EasyMock.expectLastCall().andThrow(new RuntimeException(""));
+        EasyMock.replay(t1);
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        try {
+            assignedTasks.commit();
+            fail("Should have thrown exception");
+        } catch (Exception e) {
+            // ok
+        }
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldProcessRunningTasks() {
+        mockInitializedTask();
+        EasyMock.expect(t1.process()).andReturn(true);
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assertThat(assignedTasks.process(), equalTo(1));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseTaskOnProcessIfProducerFencedException() {
+        mockInitializedTask();
+        EasyMock.expect(t1.process()).andThrow(new ProducerFencedException(""));
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.process();
+        assertTrue(assignedTasks.runningTasks().isEmpty());
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldThrowExceptionOnProcessWhenNotCommitFailedOrProducerFencedException() {
+        mockInitializedTask();
+        EasyMock.expect(t1.process()).andThrow(new RuntimeException(""));
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        try {
+            assignedTasks.process();
+            fail("should have thrown exception");
+        } catch (Exception e) {
+            // ok
+        }
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldPunctuateRunningTasks() {
+        mockInitializedTask();
+        EasyMock.expect(t1.maybePunctuate()).andReturn(true);
+        EasyMock.expect(t1.commitNeeded()).andReturn(false);
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.punctuateAndCommit(commitSensor, punctuateSensor);
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCommitRunningTasksIfNeeded() {
+        mockInitializedTask();
+        EasyMock.expect(t1.maybePunctuate()).andReturn(true);
+        EasyMock.expect(t1.commitNeeded()).andReturn(true);
+        t1.commit();
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.punctuateAndCommit(commitSensor, punctuateSensor);
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldThrowExceptionOnPunctuateAndCommitWhenNotCommitFailedOrProducerFencedException() {
+        mockInitializedTask();
+        EasyMock.expect(t1.maybePunctuate()).andThrow(new RuntimeException(""));
+        EasyMock.replay(t1);
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        try {
+            assignedTasks.punctuateAndCommit(commitSensor, punctuateSensor);
+            fail("should have thrown exception");
+        } catch (Exception e) {
+            // ok
+        }
+        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
+        EasyMock.verify(t1);
+    }
+
+    @Test
+    public void shouldCloseTaskOnPunctuateAndCommitIfProducerFencedException() {
+        mockInitializedTask();
+        EasyMock.expect(t1.maybePunctuate()).andThrow(new ProducerFencedException(""));
+        t1.close(false);
+        EasyMock.expectLastCall();
+        EasyMock.replay(t1);
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+
+        assignedTasks.punctuateAndCommit(commitSensor, punctuateSensor);
+        assertTrue(assignedTasks.runningTasks().isEmpty());
+        EasyMock.verify(t1);
+    }
+
+    private RuntimeException suspendTask() {
+        assignedTasks.addNewTask(t1);
+        assignedTasks.initializeNewTasks();
+        return assignedTasks.suspend();
+    }
+
+    private void mockRunningTaskSuspension() {
+        EasyMock.expect(t1.initialize()).andReturn(true);
+        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1)).anyTimes();
+        EasyMock.expect(t1.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList()).anyTimes();
+        t1.suspend();
+        EasyMock.expectLastCall();
+    }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index f454216..1e4e0ee 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -21,7 +21,6 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@@ -34,9 +33,7 @@ import org.junit.Test;
 
 import java.io.File;
 import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.nio.channels.FileLock;
-import java.nio.file.StandardOpenOption;
+
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -389,33 +386,6 @@ public class ProcessorStateManagerTest {
         assertThat(read, equalTo(Collections.<TopicPartition, Long>emptyMap()));
     }
 
-    @Test
-    public void shouldThrowLockExceptionIfFailedToLockStateDirectory() throws Exception {
-        final File taskDirectory = stateDirectory.directoryForTask(taskId);
-        final FileChannel channel = FileChannel.open(new File(taskDirectory,
-                                                              StateDirectory.LOCK_FILE_NAME).toPath(),
-                                                     StandardOpenOption.CREATE,
-                                                     StandardOpenOption.WRITE);
-        // lock the task directory
-        final FileLock lock = channel.lock();
-
-        try {
-            new ProcessorStateManager(
-                taskId,
-                noPartitions,
-                false,
-                stateDirectory,
-                Collections.<String, String>emptyMap(),
-                changelogReader,
-                false);
-            fail("Should have thrown LockException");
-        } catch (final LockException e) {
-           // pass
-        } finally {
-            lock.release();
-            channel.close();
-        }
-    }
 
     @Test
     public void shouldThrowIllegalArgumentExceptionIfStoreNameIsSameAsCheckpointFileName() throws Exception {

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index a358be5..8621470 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -158,7 +158,7 @@ public class StandbyTaskTest {
     public void testStorePartitions() throws Exception {
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, topicPartitions, topology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         assertEquals(Utils.mkSet(partition2), new HashSet<>(task.checkpointedOffsets().keySet()));
 
     }
@@ -182,7 +182,7 @@ public class StandbyTaskTest {
     public void testUpdate() throws Exception {
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, topicPartitions, topology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
 
         for (ConsumerRecord<Integer, Integer> record : Arrays.asList(
@@ -240,7 +240,7 @@ public class StandbyTaskTest {
 
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, ktablePartitions, ktableTopology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
 
         for (ConsumerRecord<Integer, Integer> record : Arrays.asList(
@@ -360,6 +360,7 @@ public class StandbyTaskTest {
                                                  null,
                                                  stateDirectory
         );
+        task.initialize();
 
 
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));

http://git-wip-us.apache.org/repos/asf/kafka/blob/b268322e/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
index 2ff6d33..a2254b6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
@@ -31,12 +31,14 @@ import org.apache.kafka.test.MockRestoreCallback;
 import org.hamcrest.CoreMatchers;
 import org.junit.Test;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class StoreChangelogReaderTest {
@@ -253,6 +255,25 @@ public class StoreChangelogReaderTest {
         assertThat(callback.restored, CoreMatchers.equalTo(Utils.mkList(KeyValue.pair(bytes, bytes), KeyValue.pair(bytes, bytes))));
     }
 
+    @Test
+    public void shouldReturnCompletedPartitionsOnEachRestoreCall() {
+        assignPartition(10, topicPartition);
+        final byte[] bytes = new byte[0];
+        for (int i = 0; i < 5; i++) {
+            consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), i, bytes, bytes));
+        }
+        consumer.assign(Collections.singletonList(topicPartition));
+        changelogReader.register(new StateRestorer(topicPartition, callback, null, Long.MAX_VALUE, false));
+
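+        // only 5 of the 10 expected records are available, so the first restore() completes nothing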
+        final Collection<TopicPartition> completedFirstTime = changelogReader.restore();
+        assertTrue(completedFirstTime.isEmpty());
+        for (int i = 5; i < 10; i++) {
+            consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), i, bytes, bytes));
+        }
+        final Collection<TopicPartition> expected = Collections.singleton(topicPartition);
+        assertThat(changelogReader.restore(), equalTo(expected));
+    }
+
     private void setupConsumer(final long messages, final TopicPartition topicPartition) {
         assignPartition(messages, topicPartition);
 

