kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [1/4] kafka git commit: KAFKA-4311: Multi layer cache eviction causes forwarding to incorrect ProcessorNode
Date Thu, 24 Nov 2016 05:10:53 GMT
Repository: kafka
Updated Branches:
  refs/heads/0.10.1 f91d95ac9 -> ecb51680a


KAFKA-4311: Multi layer cache eviction causes forwarding to incorrect ProcessorNode

Given a topology like the one below, if a record arriving in `tableOne` causes a cache eviction,
it will trigger the `leftJoin`, which does a `get` from `reducer-store`. If the key is not
currently cached in `reducer-store` but is in the backing store, it will be put into the
cache, and that may also trigger an eviction. If it does trigger an eviction and the eldest
entry is dirty, the cache will flush its dirty keys. It is at this point that a ClassCastException
is thrown. This occurs because the current `ProcessorNode` on the ProcessorContext is still the `leftJoin`
node, so the flushed records are forwarded to its next child in the topology, `mapValues`.
We need to set the correct `ProcessorNode` on the context in the `ForwardingCacheFlushListener`
prior to calling `context.forward`. We also need to remember to reset the `ProcessorNode`
to the previous node once `context.forward` has completed.

```
       final KTable<String, String> one = builder.table(Serdes.String(), Serdes.String(),
tableOne, tableOne);
        final KTable<Long, String> two = builder.table(Serdes.Long(), Serdes.String(),
tableTwo, tableTwo);
        final KTable<String, Long> reduce = two.groupBy(new KeyValueMapper<Long,
String, KeyValue<String, Long>>() {
            Override
            public KeyValue<String, Long> apply(final Long key, final String value)
{
                return new KeyValue<>(value, key);
            }
        }, Serdes.String(), Serdes.Long())
                .reduce(new Reducer<Long>() {..}, new Reducer<Long>() {..}, "reducer-store");

    one.leftJoin(reduce, new ValueJoiner<String, Long, String>() {..})
        .mapValues(new ValueMapper<String, String>() {..});

```

Author: Damian Guy <damian.guy@gmail.com>

Reviewers: Eno Thereska, Guozhang Wang

Closes #2051 from dguy/kafka-4311


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/c1fb615a
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/c1fb615a
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/c1fb615a

Branch: refs/heads/0.10.1
Commit: c1fb615a6ef6cbd0a725bd70a2a602ca31402f8a
Parents: f91d95a
Author: Damian Guy <damian.guy@gmail.com>
Authored: Wed Nov 9 10:43:27 2016 -0800
Committer: Guozhang Wang <wangguoz@gmail.com>
Committed: Wed Nov 23 07:52:57 2016 -0800

----------------------------------------------------------------------
 .../internals/ForwardingCacheFlushListener.java | 22 ++++--
 .../internals/InternalProcessorContext.java     |  1 +
 .../internals/ProcessorContextImpl.java         | 13 ++--
 .../processor/internals/StandbyContextImpl.java |  5 ++
 .../streams/processor/internals/StreamTask.java | 28 +++-----
 .../streams/state/internals/NamedCache.java     |  8 ++-
 .../kstream/internals/KTableAggregateTest.java  | 72 ++++++++++++++++++++
 .../internals/ProcessorTopologyTest.java        |  1 +
 .../streams/state/KeyValueStoreTestDriver.java  |  7 ++
 .../streams/state/internals/NamedCacheTest.java |  6 ++
 .../apache/kafka/test/KStreamTestDriver.java    | 23 ++++++-
 .../apache/kafka/test/MockProcessorContext.java |  7 ++
 12 files changed, 160 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
index 1796be9..4635fc9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
@@ -17,22 +17,32 @@
 package org.apache.kafka.streams.kstream.internals;
 
 import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
 
 class ForwardingCacheFlushListener<K, V> implements CacheFlushListener<K, V>
{
-    private final ProcessorContext context;
+    private final InternalProcessorContext context;
     private final boolean sendOldValues;
+    private final ProcessorNode myNode;
 
     ForwardingCacheFlushListener(final ProcessorContext context, final boolean sendOldValues)
{
-        this.context = context;
+        this.context = (InternalProcessorContext) context;
+        myNode = this.context.currentNode();
         this.sendOldValues = sendOldValues;
     }
 
     @Override
     public void apply(final K key, final V newValue, final V oldValue) {
-        if (sendOldValues) {
-            context.forward(key, new Change<>(newValue, oldValue));
-        } else {
-            context.forward(key, new Change<>(newValue, null));
+        final ProcessorNode prev = context.currentNode();
+        context.setCurrentNode(myNode);
+        try {
+            if (sendOldValues) {
+                context.forward(key, new Change<>(newValue, oldValue));
+            } else {
+                context.forward(key, new Change<>(newValue, null));
+            }
+        } finally {
+            context.setCurrentNode(prev);
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
index 251ff3f..016964b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
@@ -42,6 +42,7 @@ public interface InternalProcessorContext extends ProcessorContext {
      */
     void setCurrentNode(ProcessorNode currentNode);
 
+    ProcessorNode currentNode();
     /**
      * Get the thread-global cache
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 195e5a4..be18593 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -128,13 +128,11 @@ public class ProcessorContextImpl implements InternalProcessorContext,
RecordCol
      */
     @Override
     public StateStore getStateStore(String name) {
-        ProcessorNode node = task.node();
-
-        if (node == null)
+        if (currentNode == null)
             throw new TopologyBuilderException("Accessing from an unknown node");
 
-        if (!node.stateStores.contains(name)) {
-            throw new TopologyBuilderException("Processor " + node.name() + " has no access
to StateStore " + name);
+        if (!currentNode.stateStores.contains(name)) {
+            throw new TopologyBuilderException("Processor " + currentNode.name() + " has
no access to StateStore " + name);
         }
 
         return stateMgr.getStore(name);
@@ -272,4 +270,9 @@ public class ProcessorContextImpl implements InternalProcessorContext,
RecordCol
     public void setCurrentNode(final ProcessorNode currentNode) {
         this.currentNode = currentNode;
     }
+
+    @Override
+    public ProcessorNode currentNode() {
+        return currentNode;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
index 563dbce..80c0026 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
@@ -222,4 +222,9 @@ public class StandbyContextImpl implements InternalProcessorContext, RecordColle
     public void setCurrentNode(final ProcessorNode currentNode) {
         // no-op. can't throw as this is called on commit when the StateStores get flushed.
     }
+
+    @Override
+    public ProcessorNode currentNode() {
+        throw new UnsupportedOperationException("this should not happen: currentNode not
supported in standby tasks.");
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index b993054..9a2f03e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -59,7 +59,6 @@ public class StreamTask extends AbstractTask implements Punctuator {
 
     private boolean commitRequested = false;
     private boolean commitOffsetNeeded = false;
-    private ProcessorNode currNode = null;
 
     private boolean requiresPoll = true;
 
@@ -122,11 +121,11 @@ public class StreamTask extends AbstractTask implements Punctuator {
         // initialize the task by initializing all its processor nodes in the topology
         log.info("{} Initializing processor nodes of the topology", logPrefix);
         for (ProcessorNode node : this.topology.processors()) {
-            this.currNode = node;
+            processorContext.setCurrentNode(node);
             try {
                 node.init(this.processorContext);
             } finally {
-                this.currNode = null;
+                processorContext.setCurrentNode(null);
             }
         }
 
@@ -172,13 +171,13 @@ public class StreamTask extends AbstractTask implements Punctuator {
 
         try {
             // process the record by passing to the source node of the topology
-            this.currNode = recordInfo.node();
+            final ProcessorNode currNode = recordInfo.node();
             TopicPartition partition = recordInfo.partition();
 
             log.trace("{} Start processing one record [{}]", logPrefix, record);
             final ProcessorRecordContext recordContext = createRecordContext(record);
             updateProcessorContext(recordContext, currNode);
-            this.currNode.process(record.key(), record.value());
+            currNode.process(record.key(), record.value());
 
             log.trace("{} Completed processing one record [{}]", logPrefix, record);
 
@@ -199,14 +198,13 @@ public class StreamTask extends AbstractTask implements Punctuator {
         } catch (KafkaException ke) {
             throw new StreamsException(format("Exception caught in process. taskId=%s, processor=%s,
topic=%s, partition=%d, offset=%d",
                                               id.toString(),
-                                              currNode.name(),
+                                              processorContext.currentNode().name(),
                                               record.topic(),
                                               record.partition(),
                                               record.offset()
                                               ), ke);
         } finally {
             processorContext.setCurrentNode(null);
-            this.currNode = null;
         }
 
         return partitionGroup.numBuffered();
@@ -241,10 +239,9 @@ public class StreamTask extends AbstractTask implements Punctuator {
      */
     @Override
     public void punctuate(ProcessorNode node, long timestamp) {
-        if (currNode != null)
+        if (processorContext.currentNode() != null)
             throw new IllegalStateException(String.format("%s Current node is not null",
logPrefix));
 
-        currNode = node;
         final StampedRecord stampedRecord = new StampedRecord(DUMMY_RECORD, timestamp);
         updateProcessorContext(createRecordContext(stampedRecord), node);
 
@@ -256,15 +253,10 @@ public class StreamTask extends AbstractTask implements Punctuator {
             throw new StreamsException(String.format("Exception caught in punctuate. taskId=%s
processor=%s", id,  node.name()), ke);
         } finally {
             processorContext.setCurrentNode(null);
-            currNode = null;
         }
     }
 
 
-    public ProcessorNode node() {
-        return this.currNode;
-    }
-
     /**
      * Commit the current task state
      */
@@ -322,10 +314,10 @@ public class StreamTask extends AbstractTask implements Punctuator {
      * @throws IllegalStateException if the current node is not null
      */
     public void schedule(long interval) {
-        if (currNode == null)
+        if (processorContext.currentNode() == null)
             throw new IllegalStateException(String.format("%s Current node is null", logPrefix));
 
-        punctuationQueue.schedule(new PunctuationSchedule(currNode, interval));
+        punctuationQueue.schedule(new PunctuationSchedule(processorContext.currentNode(),
interval));
     }
 
     /**
@@ -342,13 +334,13 @@ public class StreamTask extends AbstractTask implements Punctuator {
         // make sure close() is called for each node even when there is a RuntimeException
         RuntimeException exception = null;
         for (ProcessorNode node : this.topology.processors()) {
-            currNode = node;
+            processorContext.setCurrentNode(node);
             try {
                 node.close();
             } catch (RuntimeException e) {
                 exception = e;
             } finally {
-                currNode = null;
+                processorContext.setCurrentNode(null);
             }
         }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
index 65a836e..ab771df 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
@@ -109,7 +109,7 @@ class NamedCache {
         for (Bytes key : dirtyKeys) {
             final LRUNode node = getInternal(key);
             if (node == null) {
-                throw new IllegalStateException("Key found in dirty key set, but entry is
null");
+                throw new IllegalStateException("Key = " + key + " found in dirty key set,
but entry is null");
             }
             entries.add(new ThreadCache.DirtyEntry(key, node.entry.value, node.entry));
             node.entry.markClean();
@@ -120,6 +120,12 @@ class NamedCache {
 
 
     synchronized void put(final Bytes key, final LRUCacheEntry value) {
+        if (!value.isDirty && dirtyKeys.contains(key)) {
+            throw new IllegalStateException(String.format("Attempting to put a clean entry
for key [%s] " +
+                                                                  "into NamedCache [%s] when
it already contains " +
+                                                                  "a dirty entry for the
same key",
+                                                          key, name));
+        }
         LRUNode node = cache.get(key);
         if (node != null) {
             numOverwrites++;

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
index ba33d5c..8378a79 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
@@ -22,10 +22,14 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Aggregator;
+import org.apache.kafka.streams.kstream.ForeachAction;
 import org.apache.kafka.streams.kstream.Initializer;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.kstream.KTable;
 import org.apache.kafka.streams.kstream.KeyValueMapper;
+import org.apache.kafka.streams.kstream.Reducer;
+import org.apache.kafka.streams.kstream.ValueJoiner;
+import org.apache.kafka.streams.kstream.ValueMapper;
 import org.apache.kafka.test.KStreamTestDriver;
 import org.apache.kafka.test.MockAggregator;
 import org.apache.kafka.test.MockInitializer;
@@ -39,6 +43,8 @@ import org.junit.Test;
 
 import java.io.File;
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
 
@@ -320,4 +326,70 @@ public class KTableAggregateTest {
                  "1:2"
                  ), proc.processed);
     }
+
+    @Test
+    public void shouldForwardToCorrectProcessorNodeWhenMultiCacheEvictions() throws Exception
{
+        final String tableOne = "tableOne";
+        final String tableTwo = "tableTwo";
+        final KStreamBuilder builder = new KStreamBuilder();
+        final String reduceTopic = "TestDriver-reducer-store-repartition";
+        final Map<String, Long> reduceResults = new HashMap<>();
+
+        final KTable<String, String> one = builder.table(Serdes.String(), Serdes.String(),
tableOne, tableOne);
+        final KTable<Long, String> two = builder.table(Serdes.Long(), Serdes.String(),
tableTwo, tableTwo);
+
+
+        final KTable<String, Long> reduce = two.groupBy(new KeyValueMapper<Long,
String, KeyValue<String, Long>>() {
+            @Override
+            public KeyValue<String, Long> apply(final Long key, final String value)
{
+                return new KeyValue<>(value, key);
+            }
+        }, Serdes.String(), Serdes.Long())
+                .reduce(new Reducer<Long>() {
+                    @Override
+                    public Long apply(final Long value1, final Long value2) {
+                        return value1 + value2;
+                    }
+                }, new Reducer<Long>() {
+                    @Override
+                    public Long apply(final Long value1, final Long value2) {
+                        return value1 - value2;
+                    }
+                }, "reducer-store");
+
+        reduce.foreach(new ForeachAction<String, Long>() {
+            @Override
+            public void apply(final String key, final Long value) {
+                reduceResults.put(key, value);
+            }
+        });
+
+        one.leftJoin(reduce, new ValueJoiner<String, Long, String>() {
+            @Override
+            public String apply(final String value1, final Long value2) {
+                return value1 + ":" + value2;
+            }
+        })
+                .mapValues(new ValueMapper<String, String>() {
+                    @Override
+                    public String apply(final String value) {
+                        return value;
+                    }
+                });
+
+        final KStreamTestDriver driver = new KStreamTestDriver(builder, stateDir, 111);
+        driver.process(reduceTopic, "1", new Change<>(1L, null));
+        driver.process("tableOne", "2", "2");
+        // this should trigger eviction on the reducer-store topic
+        driver.process(reduceTopic, "2", new Change<>(2L, null));
+        // this wont as it is the same value
+        driver.process(reduceTopic, "2", new Change<>(2L, null));
+        assertEquals(Long.valueOf(2L), reduceResults.get("2"));
+
+        // this will trigger eviction on the tableOne topic
+        // that in turn will cause an eviction on reducer-topic. It will flush
+        // key 2 as it is the only dirty entry in the cache
+        driver.process("tableOne", "1", "5");
+        assertEquals(Long.valueOf(4L), reduceResults.get("2"));
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
index 54ee43c..a146316 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
@@ -280,6 +280,7 @@ public class ProcessorTopologyTest {
                 .addSink("sink-2", OUTPUT_TOPIC_2, constantPartitioner(partition), "processor-2");
     }
 
+
     /**
      * A processor that simply forwards all messages to all children.
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index e84e9ba..aca974b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -32,6 +32,7 @@ import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.apache.kafka.test.MockProcessorContext;
@@ -269,6 +270,12 @@ public class KeyValueStoreTestDriver<K, V> {
             public Map<String, Object> appConfigsWithPrefix(String prefix) {
                 return new StreamsConfig(props).originalsWithPrefix(prefix);
             }
+
+            @Override
+            public ProcessorNode currentNode() {
+                return null;
+            }
+
             @Override
             public ThreadCache getCache() {
                 return cache;

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
index 3067256..5c0d511 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
@@ -191,4 +191,10 @@ public class NamedCacheTest {
     public void shouldNotThrowNullPointerWhenCacheIsEmptyAndEvictionCalled() throws Exception
{
         cache.evict();
     }
+
+    @Test(expected = IllegalStateException.class)
+    public void shouldThrowIllegalStateExceptionWhenTryingToOverwriteDirtyEntryWithCleanEntry()
throws Exception {
+        cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, true, 0, 0,
0, ""));
+        cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, false, 0,
0, 0, ""));
+    }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
index ac58f37..05abbc6 100644
--- a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
@@ -56,14 +56,26 @@ public class KStreamTestDriver {
         this(builder, stateDir, Serdes.ByteArray(), Serdes.ByteArray());
     }
 
+    public KStreamTestDriver(KStreamBuilder builder, File stateDir, final long cacheSize)
{
+        this(builder, stateDir, Serdes.ByteArray(), Serdes.ByteArray(), cacheSize);
+    }
+
     public KStreamTestDriver(KStreamBuilder builder,
                              File stateDir,
                              Serde<?> keySerde,
                              Serde<?> valSerde) {
+        this(builder, stateDir, keySerde, valSerde, DEFAULT_CACHE_SIZE_BYTES);
+    }
+
+    public KStreamTestDriver(KStreamBuilder builder,
+                             File stateDir,
+                             Serde<?> keySerde,
+                             Serde<?> valSerde,
+                             long cacheSize) {
         builder.setApplicationId("TestDriver");
         this.topology = builder.build(null);
         this.stateDir = stateDir;
-        this.cache = new ThreadCache(DEFAULT_CACHE_SIZE_BYTES);
+        this.cache = new ThreadCache(cacheSize);
         this.context = new MockProcessorContext(this, stateDir, keySerde, valSerde, new MockRecordCollector(),
cache);
         this.context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "topic"));
 
@@ -73,13 +85,14 @@ public class KStreamTestDriver {
         }
 
         for (ProcessorNode node : topology.processors()) {
-            currNode = node;
+            context.setCurrentNode(node);
             try {
                 node.init(context);
             } finally {
-                currNode = null;
+                context.setCurrentNode(null);
             }
         }
+
     }
 
     public ProcessorContext context() {
@@ -225,6 +238,10 @@ public class KStreamTestDriver {
 
     }
 
+    public void setCurrentNode(final ProcessorNode currentNode) {
+        currNode = currentNode;
+    }
+
 
     private class MockRecordCollector extends RecordCollector {
         public MockRecordCollector() {

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
index 8ad2fa9..cafdd9e 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
@@ -51,6 +51,7 @@ public class MockProcessorContext implements InternalProcessorContext, RecordCol
 
     long timestamp = -1L;
     private RecordContext recordContext;
+    private ProcessorNode currentNode;
 
     public MockProcessorContext(StateSerdes<?, ?> serdes, RecordCollector collector)
{
         this(null, null, serdes.keySerde(), serdes.valueSerde(), collector, null);
@@ -248,7 +249,13 @@ public class MockProcessorContext implements InternalProcessorContext,
RecordCol
 
     @Override
     public void setCurrentNode(final ProcessorNode currentNode) {
+        this.currentNode  = currentNode;
+        driver.setCurrentNode(currentNode);
+    }
 
+    @Override
+    public ProcessorNode currentNode() {
+        return currentNode;
     }
 
 }


Mime
View raw message