kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [kafka] branch trunk updated: KAFKA-8029: In memory session store (#6525)
Date Fri, 26 Apr 2019 20:22:51 GMT
This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8299f2a  KAFKA-8029: In memory session store (#6525)
8299f2a is described below

commit 8299f2a397b6033d60c295a689b7a45fcf413f4a
Author: A. Sophie Blee-Goldman <sophie@confluent.io>
AuthorDate: Fri Apr 26 13:22:36 2019 -0700

    KAFKA-8029: In memory session store (#6525)
    
    First pass at an in-memory session store implementation.
    
    Reviewers: Simon Geisler, Damian Guy <damian@confluent.io>, John Roesler <john@confluent.io>, Bill Bejeck <bill@confluent.io>, Guozhang Wang <wangguoz@gmail.com>
---
 docs/streams/developer-guide/processor-api.html    |   2 +-
 docs/streams/upgrade-guide.html                    |   2 +-
 .../apache/kafka/streams/state/SessionStore.java   |   6 +-
 .../org/apache/kafka/streams/state/Stores.java     |  21 ++
 .../InMemorySessionBytesStoreSupplier.java         |  59 ++++
 .../state/internals/InMemorySessionStore.java      | 378 +++++++++++++++++++++
 ...toreTest.java => InMemorySessionStoreTest.java} | 289 +++++++++++++---
 .../state/internals/RocksDBSessionStoreTest.java   |  51 +++
 8 files changed, 759 insertions(+), 49 deletions(-)

diff --git a/docs/streams/developer-guide/processor-api.html b/docs/streams/developer-guide/processor-api.html
index 4a060a6..31c11ed 100644
--- a/docs/streams/developer-guide/processor-api.html
+++ b/docs/streams/developer-guide/processor-api.html
@@ -259,7 +259,7 @@
                                 disk space is either not available or local disk space is wiped
                                 in-between app instance restarts.</li>
                             <li>Available <a class="reference external" href="/{{version}}/javadoc/org/apache/kafka/streams/state/Stores.html#inMemoryKeyValueStore-java.lang.String-">store variants</a>:
-                                time window key-value store</li>
+                                time window key-value store, session window key-value store.</li>
                         </ul>
                             <div class="highlight-java"><div class="highlight"><pre><span></span><span class="c1">// Creating an in-memory key-value store:</span>
 <span class="c1">// here, we create a `KeyValueStore&lt;String, Long&gt;` named &quot;inmemory-counts&quot;.</span>
diff --git a/docs/streams/upgrade-guide.html b/docs/streams/upgrade-guide.html
index 58505e2..e07b5a9 100644
--- a/docs/streams/upgrade-guide.html
+++ b/docs/streams/upgrade-guide.html
@@ -73,7 +73,7 @@
     <h3><a id="streams_api_changes_230" href="#streams_api_changes_230">Streams API changes in 2.3.0</a></h3>
     <p>Version 2.3.0 adds the Suppress operator to the <code>kafka-streams-scala</code> Ktable API.</p>
     <p>
-        As of 2.3.0 Streams now offers an in-memory version of the window store, in addition to the persistent one based on RocksDB. The new public interface <code>inMemoryWindowStore()</code> is added to Stores that provides a built-in in-memory window store.
+        As of 2.3.0 Streams now offers an in-memory version of the window and the session store, in addition to the persistent ones based on RocksDB. The new public interfaces <code>inMemoryWindowStore()</code> and <code>inMemorySessionStore()</code> are added to Stores and provide the built-in in-memory window or session store.
     </p>
 
     <p>
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
index 4f897e3..faaa751 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
@@ -46,7 +46,7 @@ public interface SessionStore<K, AGG> extends StateStore, ReadOnlySessionStore<K
      * @return iterator of sessions with the matching key and aggregated values
      * @throws NullPointerException If null is used for key.
      */
-    KeyValueIterator<Windowed<K>, AGG> findSessions(final K key, long earliestSessionEndTime, final long latestSessionStartTime);
+    KeyValueIterator<Windowed<K>, AGG> findSessions(final K key, final long earliestSessionEndTime, final long latestSessionStartTime);
 
     /**
      * Fetch any sessions in the given range of keys and the sessions end is &ge; earliestSessionEndTime and the sessions
@@ -61,7 +61,7 @@ public interface SessionStore<K, AGG> extends StateStore, ReadOnlySessionStore<K
      * @return iterator of sessions with the matching keys and aggregated values
      * @throws NullPointerException If null is used for any key.
      */
-    KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom, final K keyTo, long earliestSessionEndTime, final long latestSessionStartTime);
+    KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom, final K keyTo, final long earliestSessionEndTime, final long latestSessionStartTime);
 
     /**
      * Get the value of key from a single session.
@@ -72,7 +72,7 @@ public interface SessionStore<K, AGG> extends StateStore, ReadOnlySessionStore<K
      * @return The value or {@code null} if no session associated with the key can be found
      * @throws NullPointerException If {@code null} is used for any key.
      */
-    AGG fetchSession(K key, long startTime, long endTime);
+    AGG fetchSession(final K key, final long startTime, final long endTime);
 
     /**
      * Remove the session aggregated with provided {@link Windowed} key from the store
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/Stores.java b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
index 70bc15a..c85fe03 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
@@ -22,6 +22,7 @@ import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.internals.ApiUtils;
 import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore;
+import org.apache.kafka.streams.state.internals.InMemorySessionBytesStoreSupplier;
 import org.apache.kafka.streams.state.internals.InMemoryWindowBytesStoreSupplier;
 import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder;
 import org.apache.kafka.streams.state.internals.MemoryNavigableLRUCache;
@@ -275,6 +276,26 @@ public class Stores {
     }
 
     /**
+     * Create an in-memory {@link SessionBytesStoreSupplier}.
+     * @param name              name of the store (cannot be {@code null})
+     * @param retentionPeriod   length of time to retain data in the store (cannot be negative)
+     *                          Note that the retention period must be at least long enough to contain the
+     *                          windowed data's entire life cycle, from window-start through window-end,
+     *                          and for the entire grace period.
+     * @return an instance of a {@link  SessionBytesStoreSupplier}
+     */
+    public static SessionBytesStoreSupplier inMemorySessionStore(final String name, final Duration retentionPeriod) {
+        Objects.requireNonNull(name, "name cannot be null");
+
+        final String msgPrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod");
+        final long retentionPeriodMs = ApiUtils.validateMillisecondDuration(retentionPeriod, msgPrefix);
+        if (retentionPeriodMs < 0) {
+            throw new IllegalArgumentException("retentionPeriod cannot be negative");
+        }
+        return new InMemorySessionBytesStoreSupplier(name, retentionPeriodMs);
+    }
+
+    /**
      * Create a persistent {@link SessionBytesStoreSupplier}.
      * @param name              name of the store (cannot be {@code null})
      * @param retentionPeriodMs length ot time to retain data in the store (cannot be negative)
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java
new file mode 100644
index 0000000..18e8971
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
+import org.apache.kafka.streams.state.SessionStore;
+
+public class InMemorySessionBytesStoreSupplier implements SessionBytesStoreSupplier {
+    private final String name;
+    private final long retentionPeriod;
+
+    public InMemorySessionBytesStoreSupplier(final String name,
+                                             final long retentionPeriod) {
+        this.name = name;
+        this.retentionPeriod = retentionPeriod;
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+
+    @Override
+    public SessionStore<Bytes, byte[]> get() {
+        return new InMemorySessionStore(name, retentionPeriod, metricsScope());
+    }
+
+    @Override
+    public String metricsScope() {
+        return "in-memory-session-state";
+    }
+
+    // In-memory store is not *really* segmented, so just say it is 1 (for ordering consistency with caching enabled)
+    @Override
+    public long segmentIntervalMs() {
+        return 1;
+    }
+
+    @Override
+    public long retentionPeriod() {
+        return retentionPeriod;
+    }
+}
+
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
new file mode 100644
index 0000000..c39dd58
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.EXPIRED_WINDOW_RECORD_DROP;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCount;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentNavigableMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.kstream.internals.SessionWindow;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.SessionStore;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(InMemorySessionStore.class);
+
+    private final String name;
+    private final String metricScope;
+    private Sensor expiredRecordSensor;
+    private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP;
+
+    private final long retentionPeriod;
+
+    private final ConcurrentNavigableMap<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>> endTimeMap = new ConcurrentSkipListMap<>();
+    private final Set<InMemorySessionStoreIterator> openIterators  = ConcurrentHashMap.newKeySet();
+
+    private volatile boolean open = false;
+
+    InMemorySessionStore(final String name,
+                         final long retentionPeriod,
+                         final String metricScope) {
+        this.name = name;
+        this.retentionPeriod = retentionPeriod;
+        this.metricScope = metricScope;
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+
+    @Override
+    public void init(final ProcessorContext context, final StateStore root) {
+        final StreamsMetricsImpl metrics = ((InternalProcessorContext) context).metrics();
+        final String taskName = context.taskId().toString();
+        expiredRecordSensor = metrics.storeLevelSensor(
+            taskName,
+            name(),
+            EXPIRED_WINDOW_RECORD_DROP,
+            Sensor.RecordingLevel.INFO
+        );
+        addInvocationRateAndCount(
+            expiredRecordSensor,
+            "stream-" + metricScope + "-metrics",
+            metrics.tagMap("task-id", taskName, metricScope + "-id", name()),
+            EXPIRED_WINDOW_RECORD_DROP
+        );
+
+        if (root != null) {
+            context.register(root, (key, value) -> put(SessionKeySchema.from(Bytes.wrap(key)), value));
+        }
+        open = true;
+    }
+
+    @Override
+    public void put(final Windowed<Bytes> sessionKey, final byte[] aggregate) {
+        removeExpiredSegments();
+
+        final long windowEndTimestamp = sessionKey.window().end();
+        observedStreamTime = Math.max(observedStreamTime, windowEndTimestamp);
+
+        if (windowEndTimestamp <= observedStreamTime - retentionPeriod) {
+            expiredRecordSensor.record();
+            LOG.debug("Skipping record for expired segment.");
+        } else {
+            if (aggregate != null) {
+                endTimeMap.computeIfAbsent(windowEndTimestamp, t -> new ConcurrentSkipListMap<>());
+                final ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>> keyMap = endTimeMap.get(windowEndTimestamp);
+                keyMap.computeIfAbsent(sessionKey.key(), t -> new ConcurrentSkipListMap<>());
+                keyMap.get(sessionKey.key()).put(sessionKey.window().start(), aggregate);
+            } else {
+                remove(sessionKey);
+            }
+        }
+    }
+
+    @Override
+    public void remove(final Windowed<Bytes> sessionKey) {
+        final ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>> keyMap = endTimeMap.get(sessionKey.window().end());
+        final ConcurrentNavigableMap<Long, byte[]> startTimeMap = keyMap.get(sessionKey.key());
+        startTimeMap.remove(sessionKey.window().start());
+
+        if (startTimeMap.isEmpty()) {
+            keyMap.remove(sessionKey.key());
+            if (keyMap.isEmpty()) {
+                endTimeMap.remove(sessionKey.window().end());
+            }
+        }
+    }
+
+    @Override
+    public byte[] fetchSession(final Bytes key, final long startTime, final long endTime) {
+        removeExpiredSegments();
+
+        // Only need to search if the record hasn't expired yet
+        if (endTime > observedStreamTime - retentionPeriod) {
+            final ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>> keyMap = endTimeMap.get(endTime);
+            if (keyMap != null) {
+                final ConcurrentNavigableMap<Long, byte[]> startTimeMap = keyMap.get(key);
+                if (startTimeMap != null) {
+                    return startTimeMap.get(startTime);
+                }
+            }
+        }
+        return null;
+    }
+
+    @Deprecated
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes key,
+                                                                  final long earliestSessionEndTime,
+                                                                  final long latestSessionStartTime) {
+        removeExpiredSegments();
+
+        return registerNewIterator(key,
+                                   key,
+                                   latestSessionStartTime,
+                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator());
+    }
+
+    @Deprecated
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes keyFrom,
+                                                                  final Bytes keyTo,
+                                                                  final long earliestSessionEndTime,
+                                                                  final long latestSessionStartTime) {
+        removeExpiredSegments();
+
+        if (keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. "
+                + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+                "Note that the built-in numerical serdes do not follow this for negative numbers");
+            return KeyValueIterators.emptyIterator();
+        }
+
+        return registerNewIterator(keyFrom,
+                                   keyTo,
+                                   latestSessionStartTime,
+                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator());
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes key) {
+        removeExpiredSegments();
+
+        return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.entrySet().iterator());
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes from, final Bytes to) {
+        removeExpiredSegments();
+
+        return registerNewIterator(from, to, Long.MAX_VALUE, endTimeMap.entrySet().iterator());
+    }
+
+    @Override
+    public boolean persistent() {
+        return false;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return open;
+    }
+
+    @Override
+    public void flush() {
+        // do-nothing since it is in-memory
+    }
+
+    @Override
+    public void close() {
+        endTimeMap.clear();
+        openIterators.clear();
+        open = false;
+    }
+
+    private void removeExpiredSegments() {
+        long minLiveTime = Math.max(0L, observedStreamTime - retentionPeriod + 1);
+
+        for (final InMemorySessionStoreIterator it : openIterators) {
+            minLiveTime = Math.min(minLiveTime, it.minTime());
+        }
+
+        endTimeMap.headMap(minLiveTime, false).clear();
+    }
+
+    private InMemorySessionStoreIterator registerNewIterator(final Bytes keyFrom,
+                                                             final Bytes keyTo,
+                                                             final long latestSessionStartTime,
+                                                             final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator) {
+        final InMemorySessionStoreIterator iterator = new InMemorySessionStoreIterator(keyFrom, keyTo, latestSessionStartTime, endTimeIterator, it -> openIterators.remove(it));
+        openIterators.add(iterator);
+        return iterator;
+    }
+
+    interface ClosingCallback {
+        void deregisterIterator(final InMemorySessionStoreIterator iterator);
+    }
+
+    private static class InMemorySessionStoreIterator implements KeyValueIterator<Windowed<Bytes>, byte[]> {
+
+        private final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator;
+        private Iterator<Entry<Bytes, ConcurrentNavigableMap<Long, byte[]>>> keyIterator;
+        private Iterator<Entry<Long, byte[]>> recordIterator;
+
+        private KeyValue<Windowed<Bytes>, byte[]> next;
+        private Bytes currentKey;
+        private long currentEndTime;
+
+        private final Bytes keyFrom;
+        private final Bytes keyTo;
+        private final long latestSessionStartTime;
+
+        private final ClosingCallback callback;
+
+        InMemorySessionStoreIterator(final Bytes keyFrom,
+                                     final Bytes keyTo,
+                                     final long latestSessionStartTime,
+                                     final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator,
+                                     final ClosingCallback callback) {
+            this.keyFrom = keyFrom;
+            this.keyTo = keyTo;
+            this.latestSessionStartTime = latestSessionStartTime;
+
+            this.endTimeIterator = endTimeIterator;
+            this.callback = callback;
+            setAllIterators();
+        }
+
+        @Override
+        public boolean hasNext() {
+            if (next != null) {
+                return true;
+            } else if (recordIterator == null) {
+                return false;
+            } else {
+                next = getNext();
+                return next != null;
+            }
+        }
+
+        @Override
+        public Windowed<Bytes> peekNextKey() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+            return next.key;
+        }
+
+        @Override
+        public KeyValue<Windowed<Bytes>, byte[]> next() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+
+            final KeyValue<Windowed<Bytes>, byte[]> ret = next;
+            next = null;
+            return ret;
+        }
+
+        @Override
+        public void close() {
+            callback.deregisterIterator(this);
+        }
+
+        Long minTime() {
+            return currentEndTime;
+        }
+
+        // getNext is only called from hasNext() when recordIterator is non-null
+        // Note this does not guarantee a next record exists as the next segments may not contain any keys in range
+        private KeyValue<Windowed<Bytes>, byte[]> getNext() {
+            if (!recordIterator.hasNext()) {
+                getNextIterators();
+            }
+
+            if (recordIterator == null) {
+                return null;
+            }
+
+            final Map.Entry<Long, byte[]> nextRecord = recordIterator.next();
+            final SessionWindow sessionWindow = new SessionWindow(nextRecord.getKey(), currentEndTime);
+            final Windowed<Bytes> windowedKey = new Windowed<>(currentKey, sessionWindow);
+
+            return new KeyValue<>(windowedKey, nextRecord.getValue());
+        }
+
+        // Called when the inner two (key and starttime) iterators are empty to roll to the next endTimestamp
+        // Rolls all three iterators forward until recordIterator has a next entry
+        // Sets recordIterator to null if there are no records to return
+        private void setAllIterators() {
+            while (endTimeIterator.hasNext()) {
+                final Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>> nextEndTimeEntry = endTimeIterator.next();
+                currentEndTime = nextEndTimeEntry.getKey();
+                keyIterator = nextEndTimeEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet().iterator();
+
+                if (setInnerIterators()) {
+                    return;
+                }
+            }
+            recordIterator = null;
+        }
+
+        // Rolls the inner two iterators (key and record) forward until recordIterator has a next entry
+        // Returns false if no more records are found (for the current end time)
+        private boolean setInnerIterators() {
+            while (keyIterator.hasNext()) {
+                final Entry<Bytes, ConcurrentNavigableMap<Long, byte[]>> nextKeyEntry = keyIterator.next();
+                currentKey = nextKeyEntry.getKey();
+
+                if (latestSessionStartTime == Long.MAX_VALUE) {
+                    recordIterator = nextKeyEntry.getValue().entrySet().iterator();
+                } else {
+                    recordIterator = nextKeyEntry.getValue().headMap(latestSessionStartTime, true).entrySet().iterator();
+                }
+
+                if (recordIterator.hasNext()) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        // Called when the current recordIterator has no entries left to roll it to the next valid entry
+        // When there are no more records to return, recordIterator will be set to null
+        private void getNextIterators() {
+            if (setInnerIterators()) {
+                return;
+            }
+
+            setAllIterators();
+        }
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
similarity index 53%
copy from streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
copy to streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
index 80ea4ba..bbe8d21 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
@@ -16,67 +16,118 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import static java.time.Duration.ofMillis;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.test.StreamsTestUtils.toList;
+import static org.apache.kafka.test.StreamsTestUtils.valuesToList;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.hasItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import java.util.Map;
+import org.apache.kafka.clients.producer.MockProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.NoOpRecordCollector;
 import org.apache.kafka.test.TestUtils;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
+public class InMemorySessionStoreTest {
 
-import static java.time.Duration.ofMillis;
-import static org.apache.kafka.test.StreamsTestUtils.toList;
-import static org.apache.kafka.test.StreamsTestUtils.valuesToList;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-public class RocksDBSessionStoreTest {
+    private static final String STORE_NAME = "InMemorySessionStore";
+    private static final long RETENTION_PERIOD = 10_000L;
 
     private SessionStore<String, Long> sessionStore;
     private InternalMockProcessorContext context;
 
-    @Before
-    public void before() {
-        sessionStore = Stores.sessionStoreBuilder(
-            Stores.persistentSessionStore(
-                "session-store",
-                ofMillis(10_000L)),
+    private final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
+
+    private final Producer<byte[], byte[]> producer = new MockProducer<>(true,
+        Serdes.ByteArray().serializer(),
+        Serdes.ByteArray().serializer());
+
+    private final RecordCollector recordCollector = new RecordCollectorImpl(
+        STORE_NAME,
+        new LogContext(STORE_NAME),
+        new DefaultProductionExceptionHandler(),
+        new Metrics().sensor("skipped-records")) {
+
+        @Override
+        public <K1, V1> void send(final String topic,
+            final K1 key,
+            final V1 value,
+            final Headers headers,
+            final Integer partition,
+            final Long timestamp,
+            final Serializer<K1> keySerializer,
+            final Serializer<V1> valueSerializer) {
+            changeLog.add(new KeyValue<>(
+                keySerializer.serialize(topic, headers, key),
+                valueSerializer.serialize(topic, headers, value))
+            );
+        }
+    };
+
+    private SessionStore<String, Long> buildSessionStore(final long retentionPeriod) {
+        return Stores.sessionStoreBuilder(
+            Stores.inMemorySessionStore(
+                STORE_NAME,
+                ofMillis(retentionPeriod)),
             Serdes.String(),
             Serdes.Long()).build();
+    }
 
+    @Before
+    public void before() {
         context = new InternalMockProcessorContext(
             TestUtils.tempDirectory(),
             Serdes.String(),
             Serdes.Long(),
-            new NoOpRecordCollector(),
+            recordCollector,
             new ThreadCache(
-                new LogContext("testCache "),
+                new LogContext("testCache"),
                 0,
                 new MockStreamsMetrics(new Metrics())));
 
+        sessionStore = buildSessionStore(RETENTION_PERIOD);
+
         sessionStore.init(context, sessionStore);
+        recordCollector.init(producer);
     }
 
     @After
-    public void close() {
+    public void after() {
         sessionStore.close();
     }
 
@@ -94,7 +145,7 @@ public class RocksDBSessionStoreTest {
             Arrays.asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> values =
-                 sessionStore.findSessions(key, 0, 1000L)
+            sessionStore.findSessions(key, 0, 1000L)
         ) {
             assertEquals(expected, toList(values));
         }
@@ -102,7 +153,7 @@ public class RocksDBSessionStoreTest {
         final List<KeyValue<Windowed<String>, Long>> expected2 = Collections.singletonList(KeyValue.pair(a2, 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> values2 =
-                 sessionStore.findSessions(key, 400L, 600L)
+            sessionStore.findSessions(key, 400L, 600L)
         ) {
             assertEquals(expected2, toList(values2));
         }
@@ -129,6 +180,39 @@ public class RocksDBSessionStoreTest {
     }
 
     @Test
+    public void shouldFetchAllSessionsWithinKeyRange() {
+        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
+            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
+            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
+            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
+
+        for (final KeyValue<Windowed<String>, Long> kv : expected) {
+            sessionStore.put(kv.key, kv.value);
+        }
+
+        // add some that shouldn't appear in the results
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldFetchExactSession() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 4)), 3L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(1, 4)), 4L);
+        sessionStore.put(new Windowed<>("aaa", new SessionWindow(0, 4)), 5L);
+
+        final long result = sessionStore.fetchSession("aa", 0, 4);
+        assertEquals(3L, result);
+    }
+
+    @Test
     public void shouldFindValuesWithinMergingSessionWindowRange() {
         final String key = "a";
         sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
@@ -139,7 +223,7 @@ public class RocksDBSessionStoreTest {
             KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions(key, -1, 1000L)) {
+            sessionStore.findSessions(key, -1, 1000L)) {
             assertEquals(expected, toList(results));
         }
     }
@@ -152,12 +236,30 @@ public class RocksDBSessionStoreTest {
         sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1000)));
 
         try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 0L, 1000L)) {
+            sessionStore.findSessions("a", 0L, 1000L)) {
+            assertFalse(results.hasNext());
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> results =
+            sessionStore.findSessions("a", 1500L, 2500L)) {
+            assertTrue(results.hasNext());
+        }
+    }
+
+    @Test
+    public void shouldRemoveOnNullAggValue() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
+
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null);
+
+        try (final KeyValueIterator<Windowed<String>, Long> results =
+            sessionStore.findSessions("a", 0L, 1000L)) {
             assertFalse(results.hasNext());
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 1500L, 2500L)) {
+            sessionStore.findSessions("a", 1500L, 2500L)) {
             assertTrue(results.hasNext());
         }
     }
@@ -176,7 +278,7 @@ public class RocksDBSessionStoreTest {
         sessionStore.put(session5, 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 150, 300)
+            sessionStore.findSessions("a", 150, 300)
         ) {
             assertEquals(session2, results.next().key);
             assertEquals(session3, results.next().key);
@@ -186,13 +288,7 @@ public class RocksDBSessionStoreTest {
 
     @Test
     public void shouldFetchExactKeys() {
-        sessionStore = Stores.sessionStoreBuilder(
-            Stores.persistentSessionStore(
-                "session-store",
-                ofMillis(0x7a00000000000000L)),
-            Serdes.String(),
-            Serdes.Long()).build();
-
+        sessionStore = buildSessionStore(0x7a00000000000000L);
         sessionStore.init(context, sessionStore);
 
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
@@ -202,31 +298,91 @@ public class RocksDBSessionStoreTest {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", 0, Long.MAX_VALUE)
+            sessionStore.findSessions("a", 0, Long.MAX_VALUE)
         ) {
             assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 5L)));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("aa", 0, Long.MAX_VALUE)
+            sessionStore.findSessions("aa", 0, Long.MAX_VALUE)
         ) {
             assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 4L)));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE)
+            sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 2L, 4L, 5L)));
+            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 2L, 3L, 4L, 5L)));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", "aa", 10, 0)
+            sessionStore.findSessions("a", "aa", 10, 0)
         ) {
             assertThat(valuesToList(iterator), equalTo(Collections.singletonList(2L)));
         }
     }
 
     @Test
+    public void testIteratorPeek() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
+
+        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions("a", 0L, 20);
+
+        assertEquals(iterator.peekNextKey(), new Windowed<>("a", new SessionWindow(0L, 0L)));
+        assertEquals(iterator.peekNextKey(), iterator.next().key);
+        assertEquals(iterator.peekNextKey(), iterator.next().key);
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void shouldRemoveExpired() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
+
+        // Advance stream time to expire the first record
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, RETENTION_PERIOD)), 4L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+            sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 3L, 4L)));
+        }
+    }
+
+    @Test
+    public void shouldRestore() {
+        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
+
+        for (final KeyValue<Windowed<String>, Long> kv : expected) {
+            sessionStore.put(kv.key, kv.value);
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
+            assertEquals(expected, toList(values));
+        }
+
+        sessionStore.close();
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
+            assertEquals(Collections.emptyList(), toList(values));
+        }
+
+        context.restore(STORE_NAME, changeLog);
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
     public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1)), 0L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(2, 3)), 1L);
@@ -242,6 +398,51 @@ public class RocksDBSessionStoreTest {
         assertFalse(keyRangeIterator.hasNext());
     }
 
+    @Test
+    public void shouldLogAndMeasureExpiredRecords() {
+        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
+        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
+
+
+        // Advance stream time by inserting a record with a timestamp large enough that records with timestamp 0 are expired
+        sessionStore.put(new Windowed<>("initial record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
+
+        // Try inserting a record with timestamp 0 -- should be dropped
+        sessionStore.put(new Windowed<>("late record", new SessionWindow(0, 0)), 0L);
+        sessionStore.put(new Windowed<>("another on-time record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
+
+        LogCaptureAppender.unregister(appender);
+
+        final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
+
+        final Metric dropTotal = metrics.get(new MetricName(
+            "expired-window-record-drop-total",
+            "stream-in-memory-session-state-metrics",
+            "The total number of occurrence of expired-window-record-drop operations.",
+            mkMap(
+                mkEntry("client-id", "mock"),
+                mkEntry("task-id", "0_0"),
+                mkEntry("in-memory-session-state-id", STORE_NAME)
+            )
+        ));
+
+        final Metric dropRate = metrics.get(new MetricName(
+            "expired-window-record-drop-rate",
+            "stream-in-memory-session-state-metrics",
+            "The average number of occurrence of expired-window-record-drop operation per second.",
+            mkMap(
+                mkEntry("client-id", "mock"),
+                mkEntry("task-id", "0_0"),
+                mkEntry("in-memory-session-state-id", STORE_NAME)
+            )
+        ));
+
+        assertEquals(1.0, dropTotal.metricValue());
+        assertNotEquals(0.0, dropRate.metricValue());
+        final List<String> messages = appender.getMessages();
+        assertThat(messages, hasItem("Skipping record for expired segment."));
+    }
+
     @Test(expected = NullPointerException.class)
     public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
         sessionStore.findSessions(null, 1L, 2L);
@@ -284,7 +485,7 @@ public class RocksDBSessionStoreTest {
 
     @Test
     public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
+        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
         final String keyFrom = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", -1));
@@ -298,4 +499,4 @@ public class RocksDBSessionStoreTest {
             + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
             + "Note that the built-in numerical serdes do not follow this for negative numbers"));
     }
-}
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
index 80ea4ba..41abdad 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
@@ -129,6 +129,39 @@ public class RocksDBSessionStoreTest {
     }
 
     @Test
+    public void shouldFetchAllSessionsWithinKeyRange() {
+        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
+            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
+            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
+            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
+
+        for (final KeyValue<Windowed<String>, Long> kv : expected) {
+            sessionStore.put(kv.key, kv.value);
+        }
+
+        // add some that shouldn't appear in the results
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldFetchExactSession() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 4)), 3L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(1, 4)), 4L);
+        sessionStore.put(new Windowed<>("aaa", new SessionWindow(0, 4)), 5L);
+
+        final long result = sessionStore.fetchSession("aa", 0, 4);
+        assertEquals(3L, result);
+    }
+
+    @Test
     public void shouldFindValuesWithinMergingSessionWindowRange() {
         final String key = "a";
         sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
@@ -163,6 +196,24 @@ public class RocksDBSessionStoreTest {
     }
 
     @Test
+    public void shouldRemoveOnNullAggValue() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
+
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null);
+
+        try (final KeyValueIterator<Windowed<String>, Long> results =
+            sessionStore.findSessions("a", 0L, 1000L)) {
+            assertFalse(results.hasNext());
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> results =
+            sessionStore.findSessions("a", 1500L, 2500L)) {
+            assertTrue(results.hasNext());
+        }
+    }
+
+    @Test
     public void shouldFindSessionsToMerge() {
         final Windowed<String> session1 = new Windowed<>("a", new SessionWindow(0, 100));
         final Windowed<String> session2 = new Windowed<>("a", new SessionWindow(101, 200));


Mime
View raw message