kafka-commits mailing list archives

From: j...@apache.org
Subject: [kafka] branch trunk updated: MINOR: Add Timer to simplify timeout bookkeeping and use it in the consumer (#5087)
Date: Sat, 04 Aug 2018 00:25:12 GMT
This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new fc5f6b0  MINOR: Add Timer to simplify timeout bookkeeping and use it in the consumer (#5087)
fc5f6b0 is described below

commit fc5f6b0e46ff81302b3e445fed0cdf454c942792
Author: Jason Gustafson <jason@confluent.io>
AuthorDate: Fri Aug 3 17:25:07 2018 -0700

    MINOR: Add Timer to simplify timeout bookkeeping and use it in the consumer (#5087)
    
    We currently do a lot of bookkeeping for timeouts, which is both error-prone and distracting. This patch adds a new `Timer` class to simplify this logic and avoid unnecessary calls to system time. In particular, this helps with nested timeout operations. The consumer has been updated to use the new class. (A sketch of the resulting pattern follows the diffstat below.)
    
    Reviewers: Ismael Juma <ismael@juma.me.uk>, Guozhang Wang <wangguoz@gmail.com>
---
 .../kafka/clients/consumer/KafkaConsumer.java      | 109 ++++----
 .../consumer/internals/AbstractCoordinator.java    |  86 +++----
 .../consumer/internals/ConsumerCoordinator.java    | 145 ++++-------
 .../consumer/internals/ConsumerNetworkClient.java  |  89 +++----
 .../kafka/clients/consumer/internals/Fetcher.java  |  66 ++---
 .../clients/consumer/internals/Heartbeat.java      |  78 +++---
 .../org/apache/kafka/common/utils/SystemTime.java  |   7 -
 .../java/org/apache/kafka/common/utils/Time.java   |  20 +-
 .../java/org/apache/kafka/common/utils/Timer.java  | 180 +++++++++++++
 .../kafka/clients/consumer/KafkaConsumerTest.java  |  82 +++---
 .../internals/AbstractCoordinatorTest.java         |  40 ++-
 .../internals/ConsumerCoordinatorTest.java         | 281 +++++++++++----------
 .../internals/ConsumerNetworkClientTest.java       |  16 +-
 .../clients/consumer/internals/FetcherTest.java    | 196 +++++++-------
 .../clients/consumer/internals/HeartbeatTest.java  |  51 +++-
 .../org/apache/kafka/common/utils/MockTime.java    |   5 -
 .../org/apache/kafka/common/utils/TimerTest.java   | 127 ++++++++++
 .../runtime/distributed/WorkerCoordinator.java     |  10 +-
 .../runtime/distributed/WorkerCoordinatorTest.java |   8 +-
 core/src/main/scala/kafka/admin/AdminClient.scala  |   2 +-
 20 files changed, 916 insertions(+), 682 deletions(-)
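
The essence of the change, repeated throughout the diff below, is the replacement of
hand-rolled elapsed-time arithmetic with a Timer created from Time. A minimal
before/after sketch, using only Timer methods that appear in this patch (done() and
doWork() are hypothetical placeholders):

    // Before (the removed pattern): every caller tracks elapsed time by hand.
    long startMs = time.milliseconds();
    long elapsed = 0L;
    while (!done() && elapsed < timeoutMs) {
        doWork(Math.max(0, timeoutMs - elapsed));   // remaining time, clamped at zero
        elapsed = time.milliseconds() - startMs;
    }

    // After: a Timer owns the deadline and caches its last clock reading.
    Timer timer = time.timer(timeoutMs);
    while (!done() && timer.notExpired()) {
        doWork(timer);   // callee reads timer.remainingMs() and calls timer.update()
    }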

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 651bd79..9071c9d 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -51,6 +51,7 @@ import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.utils.AppInfoParser;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 
@@ -754,7 +755,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     groupId,
                     maxPollIntervalMs,
                     sessionTimeoutMs,
-                    new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs, maxPollIntervalMs, retryBackoffMs),
+                    new Heartbeat(time, sessionTimeoutMs, heartbeatIntervalMs, maxPollIntervalMs, retryBackoffMs),
                     assignors,
                     this.metadata,
                     this.subscriptions,
@@ -1085,7 +1086,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      * offset for the subscribed list of partitions
      *
      *
-     * @param timeout The time, in milliseconds, spent waiting in poll if data is not available in the buffer.
+     * @param timeoutMs The time, in milliseconds, spent waiting in poll if data is not available in the buffer.
      *            If 0, returns immediately with any records currently available in the buffer; otherwise returns an empty record set.
      *            Must not be negative.
      * @return map of topic to records since the last fetch for the subscribed list of topics and partitions
@@ -1111,8 +1112,8 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      */
     @Deprecated
     @Override
-    public ConsumerRecords<K, V> poll(final long timeout) {
-        return poll(timeout, false);
+    public ConsumerRecords<K, V> poll(final long timeoutMs) {
+        return poll(time.timer(timeoutMs), false);
     }
 
     /**
@@ -1153,41 +1154,31 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      */
     @Override
     public ConsumerRecords<K, V> poll(final Duration timeout) {
-        return poll(timeout.toMillis(), true);
+        return poll(time.timer(timeout), true);
     }
 
-    private ConsumerRecords<K, V> poll(final long timeoutMs, final boolean includeMetadataInTimeout) {
+    private ConsumerRecords<K, V> poll(final Timer timer, final boolean includeMetadataInTimeout) {
         acquireAndEnsureOpen();
         try {
-            if (timeoutMs < 0) throw new IllegalArgumentException("Timeout must not be negative");
-
             if (this.subscriptions.hasNoSubscriptionOrUserAssignment()) {
                 throw new IllegalStateException("Consumer is not subscribed to any topics or assigned any partitions");
             }
 
             // poll for new data until the timeout expires
-            long elapsedTime = 0L;
             do {
-
                 client.maybeTriggerWakeup();
 
-                final long metadataEnd;
                 if (includeMetadataInTimeout) {
-                    final long metadataStart = time.milliseconds();
-                    if (!updateAssignmentMetadataIfNeeded(remainingTimeAtLeastZero(timeoutMs, elapsedTime))) {
+                    if (!updateAssignmentMetadataIfNeeded(timer)) {
                         return ConsumerRecords.empty();
                     }
-                    metadataEnd = time.milliseconds();
-                    elapsedTime += metadataEnd - metadataStart;
                 } else {
-                    while (!updateAssignmentMetadataIfNeeded(Long.MAX_VALUE)) {
+                    while (!updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE))) {
                         log.warn("Still waiting for metadata");
                     }
-                    metadataEnd = time.milliseconds();
                 }
 
-                final Map<TopicPartition, List<ConsumerRecord<K, V>>> records = pollForFetches(remainingTimeAtLeastZero(timeoutMs, elapsedTime));
-
+                final Map<TopicPartition, List<ConsumerRecord<K, V>>> records = pollForFetches(timer);
                 if (!records.isEmpty()) {
                     // before returning the fetched records, we can send off the next round of fetches
                     // and avoid block waiting for their responses to enable pipelining while the user
@@ -1201,10 +1192,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
 
                     return this.interceptors.onConsume(new ConsumerRecords<>(records));
                 }
-                final long fetchEnd = time.milliseconds();
-                elapsedTime += fetchEnd - metadataEnd;
-
-            } while (elapsedTime < timeoutMs);
+            } while (timer.notExpired());
 
             return ConsumerRecords.empty();
         } finally {
@@ -1215,18 +1203,16 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     /**
      * Visible for testing
      */
-    boolean updateAssignmentMetadataIfNeeded(final long timeoutMs) {
-        final long startMs = time.milliseconds();
-        if (!coordinator.poll(timeoutMs)) {
+    boolean updateAssignmentMetadataIfNeeded(final Timer timer) {
+        if (!coordinator.poll(timer)) {
             return false;
         }
 
-        return updateFetchPositions(remainingTimeAtLeastZero(timeoutMs, time.milliseconds() - startMs));
+        return updateFetchPositions(timer);
     }
 
-    private Map<TopicPartition, List<ConsumerRecord<K, V>>> pollForFetches(final long timeoutMs) {
-        final long startMs = time.milliseconds();
-        long pollTimeout = Math.min(coordinator.timeToNextPoll(startMs), timeoutMs);
+    private Map<TopicPartition, List<ConsumerRecord<K, V>>> pollForFetches(Timer timer) {
+        long pollTimeout = Math.min(coordinator.timeToNextPoll(timer.currentTimeMs()), timer.remainingMs());
 
         // if data is available already, return it immediately
         final Map<TopicPartition, List<ConsumerRecord<K, V>>> records = fetcher.fetchedRecords();
@@ -1246,11 +1232,13 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             pollTimeout = retryBackoffMs;
         }
 
-        client.poll(pollTimeout, startMs, () -> {
+        Timer pollTimer = time.timer(pollTimeout);
+        client.poll(pollTimer, () -> {
             // since a fetch might be completed by the background thread, we need this poll condition
             // to ensure that we do not block unnecessarily in poll()
             return !fetcher.hasCompletedFetches();
         });
+        timer.update(pollTimer.currentTimeMs());
 
         // after the long poll, we should check whether the group needs to rebalance
         // prior to returning data so that the group can stabilize faster
@@ -1261,10 +1249,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         return fetcher.fetchedRecords();
     }
 
-    private long remainingTimeAtLeastZero(final long timeoutMs, final long elapsedTime) {
-        return Math.max(0, timeoutMs - elapsedTime);
-    }
-
     /**
      * Commit offsets returned on the last {@link #poll(Duration) poll()} for all the subscribed list of topics and
      * partitions.
@@ -1333,7 +1317,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public void commitSync(Duration timeout) {
         acquireAndEnsureOpen();
         try {
-            if (!coordinator.commitOffsetsSync(subscriptions.allConsumed(), timeout.toMillis())) {
+            if (!coordinator.commitOffsetsSync(subscriptions.allConsumed(), time.timer(timeout))) {
                 throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before successfully " +
                         "committing the current consumed offsets");
             }
@@ -1415,7 +1399,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets, final Duration timeout) {
         acquireAndEnsureOpen();
         try {
-            if (!coordinator.commitOffsetsSync(new HashMap<>(offsets), timeout.toMillis())) {
+            if (!coordinator.commitOffsetsSync(new HashMap<>(offsets), time.timer(timeout))) {
                 throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before successfully " +
                         "committing offsets " + offsets);
             }
@@ -1622,30 +1606,23 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      */
     @Override
     public long position(TopicPartition partition, final Duration timeout) {
-        final long timeoutMs = timeout.toMillis();
         acquireAndEnsureOpen();
         try {
             if (!this.subscriptions.isAssigned(partition))
                 throw new IllegalStateException("You can only check the position for partitions assigned to this consumer.");
-            Long offset = this.subscriptions.position(partition);
-            final long startMs = time.milliseconds();
-            long finishMs = startMs;
-
-            while (offset == null && finishMs - startMs < timeoutMs) {
-                // batch update fetch positions for any partitions without a valid position
-                if (!updateFetchPositions(remainingTimeAtLeastZero(timeoutMs, time.milliseconds() - startMs))) {
-                    break;
-                }
-                finishMs = time.milliseconds();
 
-                client.poll(remainingTimeAtLeastZero(timeoutMs, finishMs - startMs));
-                offset = this.subscriptions.position(partition);
-                finishMs = time.milliseconds();
-            }
-            if (offset == null)
-                throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before the position " +
-                        "for partition " + partition + " could be determined");
-            return offset;
+            Timer timer = time.timer(timeout);
+            do {
+                Long offset = this.subscriptions.position(partition);
+                if (offset != null)
+                    return offset;
+
+                updateFetchPositions(timer);
+                client.poll(timer);
+            } while (timer.notExpired());
+
+            throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before the position " +
+                    "for partition " + partition + " could be determined");
         } finally {
             release();
         }
@@ -1703,7 +1680,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         acquireAndEnsureOpen();
         try {
             Map<TopicPartition, OffsetAndMetadata> offsets = coordinator.fetchCommittedOffsets(
-                    Collections.singleton(partition), timeout.toMillis());
+                    Collections.singleton(partition), time.timer(timeout));
             if (offsets == null) {
                 throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before the last " +
                         "committed offset for partition " + partition + " could be determined");
@@ -1766,15 +1743,15 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     @Override
     public List<PartitionInfo> partitionsFor(String topic, Duration timeout) {
         acquireAndEnsureOpen();
-        long timeoutMs = timeout.toMillis();
         try {
             Cluster cluster = this.metadata.fetch();
             List<PartitionInfo> parts = cluster.partitionsForTopic(topic);
             if (!parts.isEmpty())
                 return parts;
 
+            Timer timer = time.timer(requestTimeoutMs);
             Map<String, List<PartitionInfo>> topicMetadata = fetcher.getTopicMetadata(
-                    new MetadataRequest.Builder(Collections.singletonList(topic), true), timeoutMs);
+                    new MetadataRequest.Builder(Collections.singletonList(topic), true), timer);
             return topicMetadata.get(topic);
         } finally {
             release();
@@ -1819,7 +1796,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public Map<String, List<PartitionInfo>> listTopics(Duration timeout) {
         acquireAndEnsureOpen();
         try {
-            return fetcher.getAllTopicMetadata(timeout.toMillis());
+            return fetcher.getAllTopicMetadata(time.timer(timeout));
         } finally {
             release();
         }
@@ -1940,7 +1917,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     throw new IllegalArgumentException("The target time for partition " + entry.getKey() + " is " +
                             entry.getValue() + ". The target time cannot be negative.");
             }
-            return fetcher.offsetsByTimes(timestampsToSearch, timeout.toMillis());
+            return fetcher.offsetsByTimes(timestampsToSearch, time.timer(timeout));
         } finally {
             release();
         }
@@ -1985,7 +1962,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public Map<TopicPartition, Long> beginningOffsets(Collection<TopicPartition> partitions, Duration timeout) {
         acquireAndEnsureOpen();
         try {
-            return fetcher.beginningOffsets(partitions, timeout.toMillis());
+            return fetcher.beginningOffsets(partitions, time.timer(timeout));
         } finally {
             release();
         }
@@ -2040,7 +2017,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public Map<TopicPartition, Long> endOffsets(Collection<TopicPartition> partitions, Duration timeout) {
         acquireAndEnsureOpen();
         try {
-            return fetcher.endOffsets(partitions, timeout.toMillis());
+            return fetcher.endOffsets(partitions, time.timer(timeout));
         } finally {
             release();
         }
@@ -2139,7 +2116,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         AtomicReference<Throwable> firstException = new AtomicReference<>();
         try {
             if (coordinator != null)
-                coordinator.close(Math.min(timeoutMs, requestTimeoutMs));
+                coordinator.close(time.timer(Math.min(timeoutMs, requestTimeoutMs)));
         } catch (Throwable t) {
             firstException.compareAndSet(null, t);
             log.error("Failed to close coordinator", t);
@@ -2170,7 +2147,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      *             defined
      * @return true iff the operation completed without timing out
      */
-    private boolean updateFetchPositions(final long timeoutMs) {
+    private boolean updateFetchPositions(final Timer timer) {
         cachedSubscriptionHashAllFetchPositions = subscriptions.hasAllFetchPositions();
         if (cachedSubscriptionHashAllFetchPositions) return true;
 
@@ -2179,7 +2156,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         // coordinator lookup if there are partitions which have missing positions, so
         // a consumer with manually assigned partitions can avoid a coordinator dependence
         // by always ensuring that assigned partitions have an initial position.
-        if (!coordinator.refreshCommittedOffsetsIfNeeded(timeoutMs)) return false;
+        if (!coordinator.refreshCommittedOffsetsIfNeeded(timer)) return false;
 
         // If there are partitions still needing a position and a reset policy is defined,
         // request reset using the default policy. If no reset strategy is defined and there
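
The pollForFetches() hunk above shows how nested timeouts compose: a child timer bounds
the inner blocking call, and the parent is then advanced from the child's cached clock
reading instead of a fresh system-time call. Condensed from the diff (pollCondition
stands in for the lambda):

    Timer pollTimer = time.timer(pollTimeout);   // child timer bounds the inner call
    client.poll(pollTimer, pollCondition);
    timer.update(pollTimer.currentTimeMs());     // parent reuses the child's cached clock
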
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index f9e1c18..d983087 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -51,6 +51,7 @@ import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.KafkaThread;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.slf4j.Logger;
 
 import java.io.Closeable;
@@ -159,7 +160,7 @@ public abstract class AbstractCoordinator implements Closeable {
                                long retryBackoffMs,
                                boolean leaveGroupOnClose) {
         this(logContext, client, groupId, rebalanceTimeoutMs, sessionTimeoutMs,
-                new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs, rebalanceTimeoutMs, retryBackoffMs),
+                new Heartbeat(time, sessionTimeoutMs, heartbeatIntervalMs, rebalanceTimeoutMs, retryBackoffMs),
                 metrics, metricGrpPrefix, time, retryBackoffMs, leaveGroupOnClose);
     }
 
@@ -218,16 +219,17 @@ public abstract class AbstractCoordinator implements Closeable {
      *
      * Ensure that the coordinator is ready to receive requests.
      *
-     * @param timeoutMs Maximum time to wait to discover the coordinator
+     * @param timer Timer bounding how long this method can block
      * @return true If coordinator discovery and initial connection succeeded, false otherwise
      */
-    protected synchronized boolean ensureCoordinatorReady(final long timeoutMs) {
-        final long startTimeMs = time.milliseconds();
-        long elapsedTime = 0L;
+    protected synchronized boolean ensureCoordinatorReady(final Timer timer) {
+        if (!coordinatorUnknown())
+            return true;
 
-        while (coordinatorUnknown()) {
+        do {
             final RequestFuture<Void> future = lookupCoordinator();
-            client.poll(future, remainingTimeAtLeastZero(timeoutMs, elapsedTime));
+            client.poll(future, timer);
+
             if (!future.isDone()) {
                 // ran out of time
                 break;
@@ -235,24 +237,17 @@ public abstract class AbstractCoordinator implements Closeable {
 
             if (future.failed()) {
                 if (future.isRetriable()) {
-                    elapsedTime = time.milliseconds() - startTimeMs;
-
-                    if (elapsedTime >= timeoutMs) break;
-
                     log.debug("Coordinator discovery failed, refreshing metadata");
-                    client.awaitMetadataUpdate(remainingTimeAtLeastZero(timeoutMs, elapsedTime));
-                    elapsedTime = time.milliseconds() - startTimeMs;
+                    client.awaitMetadataUpdate(timer);
                 } else
                     throw future.exception();
             } else if (coordinator != null && client.isUnavailable(coordinator)) {
                 // we found the coordinator, but the connection has failed, so mark
                 // it dead and backoff before retrying discovery
                 markCoordinatorUnknown();
-                final long sleepTime = Math.min(retryBackoffMs, remainingTimeAtLeastZero(timeoutMs, elapsedTime));
-                time.sleep(sleepTime);
-                elapsedTime += sleepTime;
+                timer.sleep(retryBackoffMs);
             }
-        }
+        } while (coordinatorUnknown() && timer.notExpired());
 
         return !coordinatorUnknown();
     }
@@ -291,6 +286,7 @@ public abstract class AbstractCoordinator implements Closeable {
      * to ensure that the member stays in the group. If an interval of time longer than the
      * provided rebalance timeout expires without calling this method, then the client will proactively
      * leave the group.
+     *
      * @param now current time in milliseconds
      * @throws RuntimeException for unexpected errors raised from the heartbeat thread
      */
@@ -322,7 +318,7 @@ public abstract class AbstractCoordinator implements Closeable {
      * Ensure that the group is active (i.e. joined and synced)
      */
     public void ensureActiveGroup() {
-        while (!ensureActiveGroup(Long.MAX_VALUE)) {
+        while (!ensureActiveGroup(time.timer(Long.MAX_VALUE))) {
             log.warn("still waiting to ensure active group");
         }
     }
@@ -330,26 +326,18 @@ public abstract class AbstractCoordinator implements Closeable {
     /**
      * Ensure the group is active (i.e., joined and synced)
      *
-     * @param timeoutMs A time budget for ensuring the group is active
+     * @param timer Timer bounding how long this method can block
      * @return true iff the group is active
      */
-    boolean ensureActiveGroup(final long timeoutMs) {
-        return ensureActiveGroup(timeoutMs, time.milliseconds());
-    }
-
-    // Visible for testing
-    boolean ensureActiveGroup(long timeoutMs, long startMs) {
+    boolean ensureActiveGroup(final Timer timer) {
        // always ensure that the coordinator is ready, because we may have been disconnected
        // while sending heartbeats; a disconnect does not necessarily require us to rejoin the group.
-        if (!ensureCoordinatorReady(timeoutMs)) {
+        if (!ensureCoordinatorReady(timer)) {
             return false;
         }
 
         startHeartbeatThreadIfNeeded();
-
-        long joinStartMs = time.milliseconds();
-        long joinTimeoutMs = remainingTimeAtLeastZero(timeoutMs, joinStartMs - startMs);
-        return joinGroupIfNeeded(joinTimeoutMs, joinStartMs);
+        return joinGroupIfNeeded(timer);
     }
 
     private synchronized void startHeartbeatThreadIfNeeded() {
@@ -386,18 +374,14 @@ public abstract class AbstractCoordinator implements Closeable {
      *
      * Visible for testing.
      *
-     * @param timeoutMs Time to complete this action
-     * @param startTimeMs Current time when invoked
+     * @param timer Timer bounding how long this method can block
      * @return true iff the operation succeeded
      */
-    boolean joinGroupIfNeeded(final long timeoutMs, final long startTimeMs) {
-        long elapsedTime = 0L;
-
+    boolean joinGroupIfNeeded(final Timer timer) {
         while (rejoinNeededOrPending()) {
-            if (!ensureCoordinatorReady(remainingTimeAtLeastZero(timeoutMs, elapsedTime))) {
+            if (!ensureCoordinatorReady(timer)) {
                 return false;
             }
-            elapsedTime = time.milliseconds() - startTimeMs;
 
             // call onJoinPrepare if needed. We set a flag to make sure that we do not call it a second
             // time if the client is woken up before a pending rebalance completes. This must be called
@@ -410,7 +394,7 @@ public abstract class AbstractCoordinator implements Closeable {
             }
 
             final RequestFuture<ByteBuffer> future = initiateJoinGroup();
-            client.poll(future, remainingTimeAtLeastZero(timeoutMs, elapsedTime));
+            client.poll(future, timer);
             if (!future.isDone()) {
                 // we ran out of time
                 return false;
@@ -434,20 +418,13 @@ public abstract class AbstractCoordinator implements Closeable {
                     continue;
                 else if (!future.isRetriable())
                     throw exception;
-                time.sleep(retryBackoffMs);
-            }
 
-            if (rejoinNeededOrPending()) {
-                elapsedTime = time.milliseconds() - startTimeMs;
+                timer.sleep(retryBackoffMs);
             }
         }
         return true;
     }
 
-    private long remainingTimeAtLeastZero(final long timeout, final long elapsedTime) {
-        return Math.max(0, timeout - elapsedTime);
-    }
-
     private synchronized void resetJoinGroupFuture() {
         this.joinFuture = null;
     }
@@ -676,7 +653,7 @@ public abstract class AbstractCoordinator implements Closeable {
                             findCoordinatorResponse.node().port());
                     log.info("Discovered group coordinator {}", coordinator);
                     client.tryConnect(coordinator);
-                    heartbeat.resetTimeouts(time.milliseconds());
+                    heartbeat.resetSessionTimeout();
                 }
                 future.complete(null);
             } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) {
@@ -769,14 +746,13 @@ public abstract class AbstractCoordinator implements Closeable {
      */
     @Override
     public final void close() {
-        close(0);
+        close(time.timer(0));
     }
 
-    protected void close(long timeoutMs) {
+    protected void close(Timer timer) {
         try {
             closeHeartbeatThread();
         } finally {
-
             // Synchronize after closing the heartbeat thread since heartbeat thread
             // needs this lock to complete and terminate after close flag is set.
             synchronized (this) {
@@ -789,7 +765,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 // yet sent to the broker. Wait up to close timeout for these pending requests to be processed.
                 // If coordinator is not known, requests are aborted.
                 Node coordinator = checkAndGetCoordinator();
-                if (coordinator != null && !client.awaitPendingRequests(coordinator, timeoutMs))
+                if (coordinator != null && !client.awaitPendingRequests(coordinator, timer))
                     log.warn("Close timed out with {} pending requests to coordinator, terminating client connections",
                             client.pendingRequestCount(coordinator));
             }
@@ -951,7 +927,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 };
             metrics.addMetric(metrics.metricName("last-heartbeat-seconds-ago",
                 this.metricGrpName,
-                "The number of seconds since the last controller heartbeat was sent"),
+                "The number of seconds since the last coordinator heartbeat was sent"),
                 lastHeartbeat);
         }
     }
@@ -969,7 +945,7 @@ public abstract class AbstractCoordinator implements Closeable {
             synchronized (AbstractCoordinator.this) {
                 log.debug("Enabling heartbeat thread");
                 this.enabled = true;
-                heartbeat.resetTimeouts(time.milliseconds());
+                heartbeat.resetTimeouts();
                 AbstractCoordinator.this.notify();
             }
         }
@@ -1050,7 +1026,7 @@ public abstract class AbstractCoordinator implements Closeable {
                                 @Override
                                 public void onSuccess(Void value) {
                                     synchronized (AbstractCoordinator.this) {
-                                        heartbeat.receiveHeartbeat(time.milliseconds());
+                                        heartbeat.receiveHeartbeat();
                                     }
                                 }
 
@@ -1062,7 +1038,7 @@ public abstract class AbstractCoordinator implements Closeable {
                                             // ensures that the coordinator keeps the member in the group for as long
                                             // as the duration of the rebalance timeout. If we stop sending heartbeats,
                                             // however, then the session timeout may expire before we can rejoin.
-                                            heartbeat.receiveHeartbeat(time.milliseconds());
+                                            heartbeat.receiveHeartbeat();
                                         } else {
                                             heartbeat.failHeartbeat();
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 8f25d6e..8762480 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -49,6 +49,7 @@ import org.apache.kafka.common.requests.OffsetFetchRequest;
 import org.apache.kafka.common.requests.OffsetFetchResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.slf4j.Logger;
 
 import java.nio.ByteBuffer;
@@ -87,7 +88,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     private Set<String> joinedSubscription;
     private MetadataSnapshot metadataSnapshot;
     private MetadataSnapshot assignmentSnapshot;
-    private long nextAutoCommitDeadline;
+    private Timer nextAutoCommitTimer;
 
     // hold onto request&future for committed offset requests to enable async calls.
     private PendingCommittedOffsetRequest pendingCommittedOffsetRequest = null;
@@ -158,7 +159,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         this.pendingAsyncCommits = new AtomicInteger();
 
         if (autoCommitEnabled)
-            this.nextAutoCommitDeadline = time.milliseconds() + autoCommitIntervalMs;
+            this.nextAutoCommitTimer = time.timer(autoCommitIntervalMs);
 
         this.metadata.requestUpdate();
         addMetadataListener();
@@ -278,7 +279,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         assignor.onAssignment(assignment);
 
         // reschedule the auto commit starting from now
-        this.nextAutoCommitDeadline = time.milliseconds() + autoCommitIntervalMs;
+        if (autoCommitEnabled)
+            this.nextAutoCommitTimer.updateAndReset(autoCommitIntervalMs);
 
         // execute the user's callback after rebalance
         ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
@@ -300,27 +302,18 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * <p>
      * Returns early if the timeout expires
      *
-     * @param timeoutMs The amount of time, in ms, allotted for this operation.
+     * @param timer Timer bounding how long this method can block
      * @return true iff the operation succeeded
      */
-    public boolean poll(final long timeoutMs) {
-        final long startTime = time.milliseconds();
-        long currentTime = startTime;
-        long elapsed = 0L;
-
+    public boolean poll(Timer timer) {
         invokeCompletedOffsetCommitCallbacks();
 
         if (subscriptions.partitionsAutoAssigned()) {
             // Always update the heartbeat last poll time so that the heartbeat thread does not leave the
             // group proactively due to application inactivity even if (say) the coordinator cannot be found.
-            pollHeartbeat(currentTime);
-
-            if (coordinatorUnknown()) {
-                if (!ensureCoordinatorReady(remainingTimeAtLeastZero(timeoutMs, elapsed))) {
-                    return false;
-                }
-                currentTime = time.milliseconds();
-                elapsed = currentTime - startTime;
+            pollHeartbeat(timer.currentTimeMs());
+            if (coordinatorUnknown() && !ensureCoordinatorReady(timer)) {
+                return false;
             }
 
             if (rejoinNeededOrPending()) {
@@ -335,21 +328,18 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                     // reduce the number of rebalances caused by single topic creation by asking consumer to
                     // refresh metadata before re-joining the group as long as the refresh backoff time has
                     // passed.
-                    if (this.metadata.timeToAllowUpdate(currentTime) == 0) {
+                    if (this.metadata.timeToAllowUpdate(time.milliseconds()) == 0) {
                         this.metadata.requestUpdate();
                     }
-                    if (!client.ensureFreshMetadata(remainingTimeAtLeastZero(timeoutMs, elapsed))) {
+
+                    if (!client.ensureFreshMetadata(timer)) {
                         return false;
                     }
-                    currentTime = time.milliseconds();
-                    elapsed = currentTime - startTime;
                 }
 
-                if (!ensureActiveGroup(remainingTimeAtLeastZero(timeoutMs, elapsed))) {
+                if (!ensureActiveGroup(timer)) {
                     return false;
                 }
-
-                currentTime = time.milliseconds();
             }
         } else {
             // For manually assigned partitions, if there are no ready nodes, await metadata.
@@ -359,26 +349,17 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             // awaitMetadataUpdate() initiates new connections with configured backoff and avoids the busy loop.
             // When group management is used, metadata wait is already performed for this scenario as
             // coordinator is unknown, hence this check is not required.
-            if (metadata.updateRequested() && !client.hasReadyNodes(startTime)) {
-                final boolean metadataUpdated = client.awaitMetadataUpdate(remainingTimeAtLeastZero(timeoutMs, elapsed));
-                if (!metadataUpdated && !client.hasReadyNodes(time.milliseconds())) {
-                    return false;
-                }
-
-                currentTime = time.milliseconds();
+            if (metadata.updateRequested() && !client.hasReadyNodes(timer.currentTimeMs())) {
+                client.awaitMetadataUpdate(timer);
             }
         }
 
-        maybeAutoCommitOffsetsAsync(currentTime);
+        maybeAutoCommitOffsetsAsync(timer.currentTimeMs());
         return true;
     }
 
-    private long remainingTimeAtLeastZero(final long timeoutMs, final long elapsed) {
-        return Math.max(0, timeoutMs - elapsed);
-    }
-
     /**
-     * Return the time to the next needed invocation of {@link #poll(long)}.
+     * Return the time to the next needed invocation of {@link #poll(Timer)}.
      * @param now current time in milliseconds
      * @return the maximum time in milliseconds the caller should wait before the next invocation of poll()
      */
@@ -386,10 +367,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         if (!autoCommitEnabled)
             return timeToNextHeartbeat(now);
 
-        if (now > nextAutoCommitDeadline)
-            return 0;
-
-        return Math.min(nextAutoCommitDeadline - now, timeToNextHeartbeat(now));
+        return Math.min(nextAutoCommitTimer.remainingMs(), timeToNextHeartbeat(now));
     }
 
     @Override
@@ -415,7 +393,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         // update metadata (if needed) and keep track of the metadata used for assignment so that
         // we can check after rebalance completion whether anything has changed
-        if (!client.ensureFreshMetadata(Long.MAX_VALUE)) throw new TimeoutException();
+        if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE))) throw new TimeoutException();
 
         isLeader = true;
 
@@ -451,7 +429,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             allSubscribedTopics.addAll(assignedTopics);
             this.subscriptions.groupSubscribe(allSubscribedTopics);
             metadata.setTopics(this.subscriptions.groupSubscription());
-            if (!client.ensureFreshMetadata(Long.MAX_VALUE)) throw new TimeoutException();
+            if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE))) throw new TimeoutException();
         }
 
         assignmentSnapshot = metadataSnapshot;
@@ -470,7 +448,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     @Override
     protected void onJoinPrepare(int generation, String memberId) {
         // commit offsets prior to rebalance if auto-commit enabled
-        maybeAutoCommitOffsetsSync(rebalanceTimeoutMs);
+        maybeAutoCommitOffsetsSync(time.timer(rebalanceTimeoutMs));
 
         // execute the user's callback before rebalance
         ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
@@ -507,13 +485,13 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     /**
      * Refresh the committed offsets for provided partitions.
      *
-     * @param timeoutMs A time limit for this operation
+     * @param timer Timer bounding how long this method can block
      * @return true iff the operation completed within the timeout
      */
-    public boolean refreshCommittedOffsetsIfNeeded(final long timeoutMs) {
+    public boolean refreshCommittedOffsetsIfNeeded(Timer timer) {
         final Set<TopicPartition> missingFetchPositions = subscriptions.missingFetchPositions();
 
-        final Map<TopicPartition, OffsetAndMetadata> offsets = fetchCommittedOffsets(missingFetchPositions, timeoutMs);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = fetchCommittedOffsets(missingFetchPositions, timer);
         if (offsets == null) return false;
 
         for (final Map.Entry<TopicPartition, OffsetAndMetadata> entry : offsets.entrySet()) {
@@ -532,7 +510,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @return A map from partition to the committed offset or null if the operation timed out
      */
     public Map<TopicPartition, OffsetAndMetadata> fetchCommittedOffsets(final Set<TopicPartition> partitions,
-                                                                        final long timeoutMs) {
+                                                                        final Timer timer) {
         if (partitions.isEmpty()) return Collections.emptyMap();
 
         final Generation generation = generation();
@@ -541,12 +519,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             pendingCommittedOffsetRequest = null;
         }
 
-        final long startMs = time.milliseconds();
-        long elapsedTime = 0L;
-
-        while (true) {
-            if (!ensureCoordinatorReady(remainingTimeAtLeastZero(timeoutMs, elapsedTime))) return null;
-            elapsedTime = time.milliseconds() - startMs;
+        do {
+            if (!ensureCoordinatorReady(timer)) return null;
 
             // contact coordinator to fetch committed offsets
             final RequestFuture<Map<TopicPartition, OffsetAndMetadata>> future;
@@ -557,7 +531,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 pendingCommittedOffsetRequest = new PendingCommittedOffsetRequest(partitions, generation, future);
 
             }
-            client.poll(future, remainingTimeAtLeastZero(timeoutMs, elapsedTime));
+            client.poll(future, timer);
 
             if (future.isDone()) {
                 pendingCommittedOffsetRequest = null;
@@ -567,32 +541,27 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 } else if (!future.isRetriable()) {
                     throw future.exception();
                 } else {
-                    elapsedTime = time.milliseconds() - startMs;
-                    final long sleepTime = Math.min(retryBackoffMs, remainingTimeAtLeastZero(startMs, elapsedTime));
-                    time.sleep(sleepTime);
-                    elapsedTime += sleepTime;
+                    timer.sleep(retryBackoffMs);
                 }
             } else {
                 return null;
             }
-        }
+        } while (timer.notExpired());
+        return null;
     }
 
-    public void close(final long timeoutMs) {
+    public void close(final Timer timer) {
         // we do not need to re-enable wakeups since we are closing already
         client.disableWakeups();
-
-        long now = time.milliseconds();
-        final long endTimeMs = now + timeoutMs;
         try {
-            maybeAutoCommitOffsetsSync(timeoutMs);
-            now = time.milliseconds();
-            if (pendingAsyncCommits.get() > 0 && endTimeMs > now) {
-                ensureCoordinatorReady(endTimeMs - now);
-                now = time.milliseconds();
+            maybeAutoCommitOffsetsSync(timer);
+            while (pendingAsyncCommits.get() > 0 && timer.notExpired()) {
+                ensureCoordinatorReady(timer);
+                client.poll(timer);
+                invokeCompletedOffsetCommitCallbacks();
             }
         } finally {
-            super.close(Math.max(0, endTimeMs - now));
+            super.close(timer);
         }
     }
 
@@ -676,25 +645,19 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @return If the offset commit was successfully sent and a successful response was received from
      *         the coordinator
      */
-    public boolean commitOffsetsSync(Map<TopicPartition, OffsetAndMetadata> offsets, long timeoutMs) {
+    public boolean commitOffsetsSync(Map<TopicPartition, OffsetAndMetadata> offsets, Timer timer) {
         invokeCompletedOffsetCommitCallbacks();
 
         if (offsets.isEmpty())
             return true;
 
-        long now = time.milliseconds();
-        long startMs = now;
-        long remainingMs = timeoutMs;
         do {
-            if (coordinatorUnknown()) {
-                if (!ensureCoordinatorReady(remainingMs))
-                    return false;
-
-                remainingMs = timeoutMs - (time.milliseconds() - startMs);
+            if (coordinatorUnknown() && !ensureCoordinatorReady(timer)) {
+                return false;
             }
 
             RequestFuture<Void> future = sendOffsetCommitRequest(offsets);
-            client.poll(future, remainingMs);
+            client.poll(future, timer);
 
             // We may have had in-flight offset commits when the synchronous commit began. If so, ensure that
             // the corresponding callbacks are invoked prior to returning in order to preserve the order that
@@ -710,19 +673,19 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             if (future.failed() && !future.isRetriable())
                 throw future.exception();
 
-            time.sleep(retryBackoffMs);
-
-            now = time.milliseconds();
-            remainingMs = timeoutMs - (now - startMs);
-        } while (remainingMs > 0);
+            timer.sleep(retryBackoffMs);
+        } while (timer.notExpired());
 
         return false;
     }
 
     public void maybeAutoCommitOffsetsAsync(long now) {
-        if (autoCommitEnabled && now >= nextAutoCommitDeadline) {
-            this.nextAutoCommitDeadline = now + autoCommitIntervalMs;
-            doAutoCommitOffsetsAsync();
+        if (autoCommitEnabled) {
+            nextAutoCommitTimer.update(now);
+            if (nextAutoCommitTimer.isExpired()) {
+                nextAutoCommitTimer.reset(autoCommitIntervalMs);
+                doAutoCommitOffsetsAsync();
+            }
         }
     }
 
@@ -737,7 +700,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                     if (exception instanceof RetriableException) {
                         log.debug("Asynchronous auto-commit of offsets {} failed due to retriable error: {}", offsets,
                                 exception);
-                        nextAutoCommitDeadline = Math.min(time.milliseconds() + retryBackoffMs, nextAutoCommitDeadline);
+                        nextAutoCommitTimer.updateAndReset(retryBackoffMs);
                     } else {
                         log.warn("Asynchronous auto-commit of offsets {} failed: {}", offsets, exception.getMessage());
                     }
@@ -748,12 +711,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         });
     }
 
-    private void maybeAutoCommitOffsetsSync(long timeoutMs) {
+    private void maybeAutoCommitOffsetsSync(Timer timer) {
         if (autoCommitEnabled) {
             Map<TopicPartition, OffsetAndMetadata> allConsumedOffsets = subscriptions.allConsumed();
             try {
                 log.debug("Sending synchronous auto-commit of offsets {}", allConsumedOffsets);
-                if (!commitOffsetsSync(allConsumedOffsets, timeoutMs))
+                if (!commitOffsetsSync(allConsumedOffsets, timer))
                     log.debug("Auto-commit of offsets {} timed out before completion", allConsumedOffsets);
             } catch (WakeupException | InterruptException e) {
                 log.debug("Auto-commit of offsets {} was interrupted before completion", allConsumedOffsets);
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
index 0bf0aad..924b3ef 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
@@ -30,6 +30,7 @@ import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.slf4j.Logger;
 
 import java.io.Closeable;
@@ -104,7 +105,7 @@ public class ConsumerNetworkClient implements Closeable {
 
     /**
      * Send a new request. Note that the request is not actually transmitted on the
-     * network until one of the {@link #poll(long)} variants is invoked. At this
+     * network until one of the {@link #poll(Timer)} variants is invoked. At this
      * point the request will either be transmitted successfully or will fail.
      * Use the returned future to obtain the result of the send. Note that there is no
      * need to check for disconnects explicitly on the {@link ClientResponse} object;
@@ -154,15 +155,14 @@ public class ConsumerNetworkClient implements Closeable {
      *
      * @return true if update succeeded, false otherwise.
      */
-    public boolean awaitMetadataUpdate(long timeout) {
-        long startMs = time.milliseconds();
+    public boolean awaitMetadataUpdate(Timer timer) {
         int version = this.metadata.requestUpdate();
         do {
-            poll(timeout);
+            poll(timer);
             AuthenticationException ex = this.metadata.getAndClearAuthenticationException();
             if (ex != null)
                 throw ex;
-        } while (this.metadata.version() == version && time.milliseconds() - startMs < timeout);
+        } while (this.metadata.version() == version && timer.notExpired());
         return this.metadata.version() > version;
     }
 
@@ -170,9 +170,9 @@ public class ConsumerNetworkClient implements Closeable {
      * Ensure our metadata is fresh (if an update is expected, this will block
      * until it has completed).
      */
-    boolean ensureFreshMetadata(final long timeout) {
-        if (this.metadata.updateRequested() || this.metadata.timeToNextUpdate(time.milliseconds()) == 0) {
-            return awaitMetadataUpdate(timeout);
+    boolean ensureFreshMetadata(Timer timer) {
+        if (this.metadata.updateRequested() || this.metadata.timeToNextUpdate(timer.currentTimeMs()) == 0) {
+            return awaitMetadataUpdate(timer);
         } else {
             // the metadata is already fresh
             return true;
@@ -185,7 +185,7 @@ public class ConsumerNetworkClient implements Closeable {
      */
     public void wakeup() {
         // wakeup should be safe without holding the client lock since it simply delegates to
-        // Selector's wakeup, which is threadsafe
+        // Selector's wakeup, which is thread-safe
         log.debug("Received user wakeup");
         this.wakeup.set(true);
         this.client.wakeup();
@@ -199,56 +199,50 @@ public class ConsumerNetworkClient implements Closeable {
      */
     public void poll(RequestFuture<?> future) {
         while (!future.isDone())
-            poll(Long.MAX_VALUE, time.milliseconds(), future);
+            poll(time.timer(Long.MAX_VALUE), future);
     }
 
     /**
     * Block until the provided request future has finished or the timeout has expired.
      * @param future The request future to wait for
-     * @param timeout The maximum duration (in ms) to wait for the request
+     * @param timer Timer bounding how long this method can block
      * @return true if the future is done, false otherwise
      * @throws WakeupException if {@link #wakeup()} is called from another thread
      * @throws InterruptException if the calling thread is interrupted
      */
-    public boolean poll(RequestFuture<?> future, long timeout) {
-        long begin = time.milliseconds();
-        long remaining = timeout;
-        long now = begin;
+    public boolean poll(RequestFuture<?> future, Timer timer) {
         do {
-            poll(remaining, now, future);
-            now = time.milliseconds();
-            long elapsed = now - begin;
-            remaining = timeout - elapsed;
-        } while (!future.isDone() && remaining > 0);
+            poll(timer, future);
+        } while (!future.isDone() && timer.notExpired());
         return future.isDone();
     }
 
     /**
      * Poll for any network IO.
-     * @param timeout The maximum time to wait for an IO event.
+     * @param timer Timer bounding how long this method can block
      * @throws WakeupException if {@link #wakeup()} is called from another thread
      * @throws InterruptException if the calling thread is interrupted
      */
-    public void poll(long timeout) {
-        poll(timeout, time.milliseconds(), null);
+    public void poll(Timer timer) {
+        poll(timer, null);
     }
 
     /**
      * Poll for any network IO.
-     * @param timeout timeout in milliseconds
-     * @param now current time in milliseconds
+     * @param timer Timer bounding how long this method can block
+     * @param pollCondition Nullable blocking condition
      */
-    public void poll(long timeout, long now, PollCondition pollCondition) {
-        poll(timeout, now, pollCondition, false);
+    public void poll(Timer timer, PollCondition pollCondition) {
+        poll(timer, pollCondition, false);
     }
 
     /**
      * Poll for any network IO.
-     * @param timeout timeout in milliseconds
-     * @param now current time in milliseconds
+     * @param timer Timer bounding how long this method can block
+     * @param pollCondition Nullable blocking condition
      * @param disableWakeup If TRUE disable triggering wake-ups
      */
-    public void poll(long timeout, long now, PollCondition pollCondition, boolean disableWakeup) {
+    public void poll(Timer timer, PollCondition pollCondition, boolean disableWakeup) {
         // there may be handlers which need to be invoked if we woke up the previous call to poll
         firePendingCompletedRequests();
 
@@ -258,26 +252,26 @@ public class ConsumerNetworkClient implements Closeable {
             handlePendingDisconnects();
 
             // send all the requests we can send now
-            long pollDelayMs = trySend(now);
-            timeout = Math.min(timeout, pollDelayMs);
+            long pollDelayMs = trySend(timer.currentTimeMs());
 
             // check whether the poll is still needed by the caller. Note that if the expected completion
             // condition becomes satisfied after the call to shouldBlock() (because of a fired completion
             // handler), the client will be woken up.
             if (pendingCompletion.isEmpty() && (pollCondition == null || pollCondition.shouldBlock())) {
                 // if there are no requests in flight, do not block longer than the retry backoff
+                long pollTimeout = Math.min(timer.remainingMs(), pollDelayMs);
                 if (client.inFlightRequestCount() == 0)
-                    timeout = Math.min(timeout, retryBackoffMs);
-                client.poll(Math.min(maxPollTimeoutMs, timeout), now);
-                now = time.milliseconds();
+                    pollTimeout = Math.min(pollTimeout, retryBackoffMs);
+                client.poll(pollTimeout, timer.currentTimeMs());
             } else {
-                client.poll(0, now);
+                client.poll(0, timer.currentTimeMs());
             }
+            timer.update();
 
             // handle any disconnects by failing the active requests. note that disconnects must
             // be checked immediately following poll since any subsequent call to client.ready()
             // will reset the disconnect status
-            checkDisconnects(now);
+            checkDisconnects(timer.currentTimeMs());
             if (!disableWakeup) {
                 // trigger wakeups after checking for disconnects so that the callbacks will be ready
                 // to be fired on the next call to poll()
@@ -288,10 +282,10 @@ public class ConsumerNetworkClient implements Closeable {
 
             // try again to send requests since buffer space may have been
             // cleared or a connect finished in the poll
-            trySend(now);
+            trySend(timer.currentTimeMs());
 
             // fail requests that couldn't be sent if they have expired
-            failExpiredRequests(now);
+            failExpiredRequests(timer.currentTimeMs());
 
             // clean unsent requests collection to keep the map from growing indefinitely
             unsent.clean();
@@ -307,24 +301,19 @@ public class ConsumerNetworkClient implements Closeable {
      * Poll for network IO and return immediately. This will not trigger wakeups.
      */
     public void pollNoWakeup() {
-        poll(0, time.milliseconds(), null, true);
+        poll(time.timer(0), null, true);
     }
 
     /**
      * Block until all pending requests from the given node have finished.
      * @param node The node to await requests from
-     * @param timeoutMs The maximum time in milliseconds to block
+     * @param timer Timer bounding how long this method can block
      * @return true if all requests finished, false if the timeout expired first
      */
-    public boolean awaitPendingRequests(Node node, long timeoutMs) {
-        long startMs = time.milliseconds();
-        long remainingMs = timeoutMs;
-
-        while (hasPendingRequests(node) && remainingMs > 0) {
-            poll(remainingMs);
-            remainingMs = timeoutMs - (time.milliseconds() - startMs);
+    public boolean awaitPendingRequests(Node node, Timer timer) {
+        while (hasPendingRequests(node) && timer.notExpired()) {
+            poll(timer);
         }
-
         return !hasPendingRequests(node);
     }
 
@@ -472,7 +461,7 @@ public class ConsumerNetworkClient implements Closeable {
     }
 
     private long trySend(long now) {
-        long pollDelayMs = Long.MAX_VALUE;
+        long pollDelayMs = maxPollTimeoutMs;
 
         // send any requests that can be sent now
         for (Node node : unsent.nodes()) {
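
The hunks above replace manual start/remaining arithmetic (compare the old
awaitPendingRequests) with a Timer-driven loop. The pattern can be sketched in
isolation; the snippet below is illustrative only, and the hasWork/doPoll
names are hypothetical stand-ins rather than part of this patch:

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class TimeoutLoopSketch {
        static boolean hasWork() { return false; }   // hypothetical: pending work?
        static void doPoll(long timeoutMs) { }       // hypothetical: block for IO

        // Before: elapsed/remaining arithmetic recomputed on every iteration.
        static void manualLoop(Time time, long timeoutMs) {
            long startMs = time.milliseconds();
            long remainingMs = timeoutMs;
            while (hasWork() && remainingMs > 0) {
                doPoll(remainingMs);
                remainingMs = timeoutMs - (time.milliseconds() - startMs);
            }
        }

        // After: the Timer caches the clock and tracks the deadline.
        static void timerLoop(Time time, long timeoutMs) {
            Timer timer = time.timer(timeoutMs);
            while (hasWork() && timer.notExpired()) {
                doPoll(timer.remainingMs());
                timer.update(); // refresh the cached clock after blocking
            }
        }
    }
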
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index dd412ab..36a3314 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -66,6 +66,7 @@ import org.apache.kafka.common.serialization.ExtendedDeserializer;
 import org.apache.kafka.common.utils.CloseableIterator;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 
@@ -247,31 +248,28 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
 
     /**
      * Get topic metadata for all topics in the cluster
-     * @param timeout time for which getting topic metadata is attempted
+     * @param timer Timer bounding how long this method can block
      * @return The map of topics with their partition information
      */
-    public Map<String, List<PartitionInfo>> getAllTopicMetadata(long timeout) {
-        return getTopicMetadata(MetadataRequest.Builder.allTopics(), timeout);
+    public Map<String, List<PartitionInfo>> getAllTopicMetadata(Timer timer) {
+        return getTopicMetadata(MetadataRequest.Builder.allTopics(), timer);
     }
 
     /**
      * Get metadata for all topics present in the Kafka cluster
      *
      * @param request The MetadataRequest to send
-     * @param timeout time for which getting topic metadata is attempted
+     * @param timer Timer bounding how long this method can block
      * @return The map of topics with their partition information
      */
-    public Map<String, List<PartitionInfo>> getTopicMetadata(MetadataRequest.Builder request, long timeout) {
+    public Map<String, List<PartitionInfo>> getTopicMetadata(MetadataRequest.Builder request, Timer timer) {
         // Save the round trip if no topics are requested.
         if (!request.isAllTopics() && request.topics().isEmpty())
             return Collections.emptyMap();
 
-        long start = time.milliseconds();
-        long remaining = timeout;
-
         do {
             RequestFuture<ClientResponse> future = sendMetadataRequest(request);
-            client.poll(future, remaining);
+            client.poll(future, timer);
 
             if (future.failed() && !future.isRetriable())
                 throw future.exception();
@@ -318,15 +316,8 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                 }
             }
 
-            long elapsed = time.milliseconds() - start;
-            remaining = timeout - elapsed;
-
-            if (remaining > 0) {
-                long backoff = Math.min(remaining, retryBackoffMs);
-                time.sleep(backoff);
-                remaining -= backoff;
-            }
-        } while (remaining > 0);
+            timer.sleep(retryBackoffMs);
+        } while (timer.notExpired());
 
         throw new TimeoutException("Timeout expired while fetching topic metadata");
     }
@@ -380,9 +371,9 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     }
 
     public Map<TopicPartition, OffsetAndTimestamp> offsetsByTimes(Map<TopicPartition, Long> timestampsToSearch,
-                                                                  long timeout) {
+                                                                  Timer timer) {
         Map<TopicPartition, OffsetData> fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch,
-                timeout, true).fetchedOffsets;
+                timer, true).fetchedOffsets;
 
         HashMap<TopicPartition, OffsetAndTimestamp> offsetsByTimes = new HashMap<>(timestampsToSearch.size());
         for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet())
@@ -399,19 +390,16 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     }
 
     private ListOffsetResult fetchOffsetsByTimes(Map<TopicPartition, Long> timestampsToSearch,
-                                                 long timeout,
+                                                 Timer timer,
                                                  boolean requireTimestamps) {
         ListOffsetResult result = new ListOffsetResult();
         if (timestampsToSearch.isEmpty())
             return result;
 
         Map<TopicPartition, Long> remainingToSearch = new HashMap<>(timestampsToSearch);
-
-        long startMs = time.milliseconds();
-        long remaining = timeout;
         do {
             RequestFuture<ListOffsetResult> future = sendListOffsetsRequests(remainingToSearch, requireTimestamps);
-            client.poll(future, remaining);
+            client.poll(future, timer);
 
             if (!future.isDone())
                 break;
@@ -427,39 +415,31 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                 throw future.exception();
             }
 
-            long elapsed = time.milliseconds() - startMs;
-            remaining = timeout - elapsed;
-            if (remaining <= 0)
-                break;
-
             if (metadata.updateRequested())
-                client.awaitMetadataUpdate(remaining);
+                client.awaitMetadataUpdate(timer);
             else
-                time.sleep(Math.min(remaining, retryBackoffMs));
-
-            elapsed = time.milliseconds() - startMs;
-            remaining = timeout - elapsed;
-        } while (remaining > 0);
+                timer.sleep(retryBackoffMs);
+        } while (timer.notExpired());
 
-        throw new TimeoutException("Failed to get offsets by times in " + timeout + "ms");
+        throw new TimeoutException("Failed to get offsets by times in " + timer.elapsedMs() + "ms");
     }
 
-    public Map<TopicPartition, Long> beginningOffsets(Collection<TopicPartition> partitions, long timeout) {
-        return beginningOrEndOffset(partitions, ListOffsetRequest.EARLIEST_TIMESTAMP, timeout);
+    public Map<TopicPartition, Long> beginningOffsets(Collection<TopicPartition> partitions, Timer timer) {
+        return beginningOrEndOffset(partitions, ListOffsetRequest.EARLIEST_TIMESTAMP, timer);
     }
 
-    public Map<TopicPartition, Long> endOffsets(Collection<TopicPartition> partitions, long timeout) {
-        return beginningOrEndOffset(partitions, ListOffsetRequest.LATEST_TIMESTAMP, timeout);
+    public Map<TopicPartition, Long> endOffsets(Collection<TopicPartition> partitions, Timer timer) {
+        return beginningOrEndOffset(partitions, ListOffsetRequest.LATEST_TIMESTAMP, timer);
     }
 
     private Map<TopicPartition, Long> beginningOrEndOffset(Collection<TopicPartition> partitions,
                                                            long timestamp,
-                                                           long timeout) {
+                                                           Timer timer) {
         Map<TopicPartition, Long> timestampsToSearch = new HashMap<>();
         for (TopicPartition tp : partitions)
             timestampsToSearch.put(tp, timestamp);
         Map<TopicPartition, Long> offsets = new HashMap<>();
-        ListOffsetResult result = fetchOffsetsByTimes(timestampsToSearch, timeout, false);
+        ListOffsetResult result = fetchOffsetsByTimes(timestampsToSearch, timer, false);
         for (Map.Entry<TopicPartition, OffsetData> entry : result.fetchedOffsets.entrySet()) {
             offsets.put(entry.getKey(), entry.getValue().offset);
         }
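
The Fetcher loops above share one shape: attempt, back off with timer.sleep,
then retry while the timer is unexpired. Since Timer.sleep caps the sleep at
the timer's remaining time, the final backoff cannot overshoot the overall
deadline. A minimal sketch of that shape, assuming a hypothetical attempt()
that returns true on success:

    import org.apache.kafka.common.errors.TimeoutException;
    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class BoundedRetrySketch {
        static boolean attempt() { return false; } // hypothetical unit of work

        static void retryUntilExpired(Time time, long timeoutMs, long retryBackoffMs) {
            Timer timer = time.timer(timeoutMs);
            do {
                if (attempt())
                    return;
                // Sleeps min(retryBackoffMs, timer.remainingMs()) and updates the timer.
                timer.sleep(retryBackoffMs);
            } while (timer.notExpired());
            throw new TimeoutException("Gave up after " + timer.elapsedMs() + "ms");
        }
    }
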
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
index 01d7810..8a67f31 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
@@ -16,6 +16,9 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
+
 /**
  * A helper class for managing the heartbeat to the coordinator
  */
@@ -24,45 +27,61 @@ public final class Heartbeat {
     private final int heartbeatIntervalMs;
     private final int maxPollIntervalMs;
     private final long retryBackoffMs;
+    private final Time time;
+    private final Timer heartbeatTimer;
+    private final Timer sessionTimer;
+    private final Timer pollTimer;
 
-    private volatile long lastHeartbeatSend; // volatile since it is read by metrics
-    private long lastHeartbeatReceive;
-    private long lastSessionReset;
-    private long lastPoll;
-    private boolean heartbeatFailed;
+    private volatile long lastHeartbeatSend;
 
-    public Heartbeat(int sessionTimeoutMs,
+    public Heartbeat(Time time,
+                     int sessionTimeoutMs,
                      int heartbeatIntervalMs,
                      int maxPollIntervalMs,
                      long retryBackoffMs) {
         if (heartbeatIntervalMs >= sessionTimeoutMs)
             throw new IllegalArgumentException("Heartbeat must be set lower than the session timeout");
 
+        this.time = time;
         this.sessionTimeoutMs = sessionTimeoutMs;
         this.heartbeatIntervalMs = heartbeatIntervalMs;
         this.maxPollIntervalMs = maxPollIntervalMs;
         this.retryBackoffMs = retryBackoffMs;
+        this.heartbeatTimer = time.timer(heartbeatIntervalMs);
+        this.sessionTimer = time.timer(sessionTimeoutMs);
+        this.pollTimer = time.timer(maxPollIntervalMs);
+    }
+
+    private void update(long now) {
+        heartbeatTimer.update(now);
+        sessionTimer.update(now);
+        pollTimer.update(now);
     }
 
     public void poll(long now) {
-        this.lastPoll = now;
+        update(now);
+        pollTimer.reset(maxPollIntervalMs);
     }
 
     public void sentHeartbeat(long now) {
         this.lastHeartbeatSend = now;
-        this.heartbeatFailed = false;
+        update(now);
+        heartbeatTimer.reset(heartbeatIntervalMs);
     }
 
     public void failHeartbeat() {
-        this.heartbeatFailed = true;
+        update(time.milliseconds());
+        heartbeatTimer.reset(retryBackoffMs);
     }
 
-    public void receiveHeartbeat(long now) {
-        this.lastHeartbeatReceive = now;
+    public void receiveHeartbeat() {
+        update(time.milliseconds());
+        sessionTimer.reset(sessionTimeoutMs);
     }
 
     public boolean shouldHeartbeat(long now) {
-        return timeToNextHeartbeat(now) == 0;
+        update(now);
+        return heartbeatTimer.isExpired();
     }
     
     public long lastHeartbeatSend() {
@@ -70,39 +89,34 @@ public final class Heartbeat {
     }
 
     public long timeToNextHeartbeat(long now) {
-        long timeSinceLastHeartbeat = now - Math.max(lastHeartbeatSend, lastSessionReset);
-        final long delayToNextHeartbeat;
-        if (heartbeatFailed)
-            delayToNextHeartbeat = retryBackoffMs;
-        else
-            delayToNextHeartbeat = heartbeatIntervalMs;
-
-        if (timeSinceLastHeartbeat > delayToNextHeartbeat)
-            return 0;
-        else
-            return delayToNextHeartbeat - timeSinceLastHeartbeat;
+        update(now);
+        return heartbeatTimer.remainingMs();
     }
 
     public boolean sessionTimeoutExpired(long now) {
-        return now - Math.max(lastSessionReset, lastHeartbeatReceive) > sessionTimeoutMs;
+        update(now);
+        return sessionTimer.isExpired();
     }
 
-    public long interval() {
-        return heartbeatIntervalMs;
+    public void resetTimeouts() {
+        update(time.milliseconds());
+        sessionTimer.reset(sessionTimeoutMs);
+        pollTimer.reset(maxPollIntervalMs);
+        heartbeatTimer.reset(heartbeatIntervalMs);
     }
 
-    public void resetTimeouts(long now) {
-        this.lastSessionReset = now;
-        this.lastPoll = now;
-        this.heartbeatFailed = false;
+    public void resetSessionTimeout() {
+        update(time.milliseconds());
+        sessionTimer.reset(sessionTimeoutMs);
     }
 
     public boolean pollTimeoutExpired(long now) {
-        return now - lastPoll > maxPollIntervalMs;
+        update(now);
+        return pollTimer.isExpired();
     }
 
     public long lastPollTime() {
-        return lastPoll;
+        return pollTimer.currentTimeMs();
     }
 
 }
\ No newline at end of file
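
With this change each Heartbeat deadline (heartbeat interval, session timeout,
poll interval) becomes an independent Timer that is reset on the corresponding
event. A sketch of the intended lifecycle, using the MockTime test clock so
time can be advanced deterministically (the numeric values are illustrative):

    import org.apache.kafka.clients.consumer.internals.Heartbeat;
    import org.apache.kafka.common.utils.MockTime;

    public class HeartbeatSketch {
        public static void main(String[] args) {
            MockTime time = new MockTime();
            // sessionTimeoutMs=10000, heartbeatIntervalMs=3000,
            // maxPollIntervalMs=30000, retryBackoffMs=100
            Heartbeat heartbeat = new Heartbeat(time, 10000, 3000, 30000, 100);

            heartbeat.sentHeartbeat(time.milliseconds()); // resets the heartbeat timer
            time.sleep(3000);                             // one full interval elapses
            System.out.println(heartbeat.shouldHeartbeat(time.milliseconds())); // true

            heartbeat.failHeartbeat();                    // retry after the backoff only
            time.sleep(100);
            System.out.println(heartbeat.shouldHeartbeat(time.milliseconds())); // true
        }
    }
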
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
index 60da064..c8b79ab 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
@@ -16,8 +16,6 @@
  */
 package org.apache.kafka.common.utils;
 
-import java.util.concurrent.TimeUnit;
-
 /**
  * A time implementation that uses the system clock and sleep call. Use `Time.SYSTEM` instead of creating an instance
  * of this class.
@@ -30,11 +28,6 @@ public class SystemTime implements Time {
     }
 
     @Override
-    public long hiResClockMs() {
-        return TimeUnit.NANOSECONDS.toMillis(nanoseconds());
-    }
-
-    @Override
     public long nanoseconds() {
         return System.nanoTime();
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Time.java b/clients/src/main/java/org/apache/kafka/common/utils/Time.java
index c288bd3..90190cb 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/Time.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Time.java
@@ -16,6 +16,9 @@
  */
 package org.apache.kafka.common.utils;
 
+import java.time.Duration;
+import java.util.concurrent.TimeUnit;
+
 /**
  * An interface abstracting the clock to use in unit testing classes that make use of clock time.
  *
@@ -33,7 +36,9 @@ public interface Time {
     /**
      * Returns the value returned by `nanoseconds` converted into milliseconds.
      */
-    long hiResClockMs();
+    default long hiResClockMs() {
+        return TimeUnit.NANOSECONDS.toMillis(nanoseconds());
+    }
 
     /**
      * Returns the current value of the running JVM's high-resolution time source, in nanoseconds.
@@ -53,4 +58,17 @@ public interface Time {
      */
     void sleep(long ms);
 
+    /**
+     * Get a timer which is bound to this time instance and expires after the given timeout
+     */
+    default Timer timer(long timeoutMs) {
+        return new Timer(this, timeoutMs);
+    }
+
+    /**
+     * Get a timer which is bound to this time instance and expires after the given timeout
+     */
+    default Timer timer(Duration timeout) {
+        return timer(timeout.toMillis());
+    }
 }
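
The two default timer(...) factories mean any Time implementation, including
the test clocks, can hand out timers bound to its own notion of "now". A
minimal usage sketch:

    import java.time.Duration;

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class TimerFactorySketch {
        public static void main(String[] args) {
            Time time = Time.SYSTEM;
            Timer fromMillis = time.timer(500);                      // long overload
            Timer fromDuration = time.timer(Duration.ofMillis(500)); // Duration overload
            System.out.println(fromMillis.remainingMs());
            System.out.println(fromDuration.remainingMs());
        }
    }
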
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Timer.java b/clients/src/main/java/org/apache/kafka/common/utils/Timer.java
new file mode 100644
index 0000000..ba734b6
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Timer.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.utils;
+
+/**
+ * This is a helper class which makes blocking methods with a timeout easier to implement.
+ * In particular it enables use cases where a high-level blocking call with a timeout is
+ * composed of several lower level calls, each of which has its own timeout. The idea
+ * is to create a single timer object for the high level timeout and carry it along to
+ * all of the lower level methods. This class handles common problems such as integer overflow
+ * and ensures monotonic updates to the timer even if the underlying clock is subject
+ * to non-monotonic behavior. For example, the remaining time returned by {@link #remainingMs()} is
+ * guaranteed to decrease monotonically until it hits zero.
+ *
+ * Note that it is up to the caller to ensure progress of the timer using one of the
+ * {@link #update()} methods or {@link #sleep(long)}. The timer will cache the current time and
+ * return it indefinitely until the timer has been updated. This allows the caller to limit
+ * unnecessary system calls and update the timer only when needed. For example, a timer which is
+ * waiting for a request sent through the {@link org.apache.kafka.clients.NetworkClient} should call
+ * {@link #update()} following each blocking call to
+ * {@link org.apache.kafka.clients.NetworkClient#poll(long, long)}.
+ *
+ * A typical usage might look something like this:
+ *
+ * <pre>
+ *     Time time = Time.SYSTEM;
+ *     Timer timer = time.timer(500);
+ *
+ *     while (!conditionSatisfied() && timer.notExpired()) {
+ *         client.poll(timer.remainingMs(), timer.currentTimeMs());
+ *         timer.update();
+ *     }
+ * </pre>
+ */
+public class Timer {
+    private final Time time;
+    private long startMs;
+    private long currentTimeMs;
+    private long deadlineMs;
+
+    Timer(Time time, long timeoutMs) {
+        this.time = time;
+        update();
+        reset(timeoutMs);
+    }
+
+    /**
+     * Check timer expiration. Like {@link #remainingMs()}, this depends on the current cached
+     * time in milliseconds, which is only updated through one of the {@link #update()} methods
+     * or with {@link #sleep(long)}.
+     *
+     * @return true if the timer has expired, false otherwise
+     */
+    public boolean isExpired() {
+        return currentTimeMs >= deadlineMs;
+    }
+
+    /**
+     * Check whether the timer has not yet expired.
+     * @return true if there is still time remaining before expiration
+     */
+    public boolean notExpired() {
+        return !isExpired();
+    }
+
+    /**
+     * Reset the timer to the specified timeout. This will use the underlying {@link Time}
+     * implementation to update the current cached time in milliseconds and it will set a new timer
+     * deadline.
+     *
+     * @param timeoutMs The new timeout in milliseconds
+     */
+    public void updateAndReset(long timeoutMs) {
+        update();
+        reset(timeoutMs);
+    }
+
+    /**
+     * Reset the timer using a new timeout. Note that this does not update the cached current time
+     * in milliseconds, so it typically must be accompanied by a separate call to {@link #update()};
+     * in most cases you can just use {@link #updateAndReset(long)}.
+     *
+     * @param timeoutMs The new timeout in milliseconds
+     */
+    public void reset(long timeoutMs) {
+        if (timeoutMs < 0)
+            throw new IllegalArgumentException("Invalid negative timeout " + timeoutMs);
+
+        this.startMs = this.currentTimeMs;
+
+        if (currentTimeMs > Long.MAX_VALUE - timeoutMs)
+            this.deadlineMs = Long.MAX_VALUE;
+        else
+            this.deadlineMs = currentTimeMs + timeoutMs;
+    }
+
+    /**
+     * Use the underlying {@link Time} implementation to update the current cached time. If
+     * the underlying time returns a value which is smaller than the current cached time,
+     * the update will be ignored.
+     */
+    public void update() {
+        update(time.milliseconds());
+    }
+
+    /**
+     * Update the cached current time to a specific value. In some contexts, the caller may already
+     * have an accurate time, so this avoids unnecessary calls to system time.
+     *
+     * Note that if the updated current time is smaller than the cached time, then the update
+     * is ignored.
+     *
+     * @param currentTimeMs The current time in milliseconds to cache
+     */
+    public void update(long currentTimeMs) {
+        this.currentTimeMs = Math.max(currentTimeMs, this.currentTimeMs);
+    }
+
+    /**
+     * Get the remaining time in milliseconds until the timer expires. Like {@link #currentTimeMs},
+     * this depends on the cached current time, so the returned value will not change until the timer
+     * has been updated using one of the {@link #update()} methods or {@link #sleep(long)}.
+     *
+     * @return The cached remaining time in milliseconds until timer expiration
+     */
+    public long remainingMs() {
+        return Math.max(0, deadlineMs - currentTimeMs);
+    }
+
+    /**
+     * Get the current time in milliseconds. This will return the same cached value until the timer
+     * has been updated using one of the {@link #update()} methods or {@link #sleep(long)}.
+     *
+     * Note that the value returned is guaranteed to increase monotonically even if the underlying
+     * {@link Time} implementation goes backwards. Effectively, the timer will just wait for the
+     * time to catch up.
+     *
+     * @return The current cached time in milliseconds
+     */
+    public long currentTimeMs() {
+        return currentTimeMs;
+    }
+
+    /**
+     * Get the amount of time that has elapsed since the timer began. If the timer was reset, this
+     * will be the amount of time since the last reset.
+     *
+     * @return The elapsed time since construction or the last reset
+     */
+    public long elapsedMs() {
+        return currentTimeMs - startMs;
+    }
+
+    /**
+     * Sleep for the requested duration and update the timer. Return when either the duration has
+     * elapsed or the timer has expired.
+     *
+     * @param durationMs The duration in milliseconds to sleep
+     */
+    public void sleep(long durationMs) {
+        long sleepDurationMs = Math.min(durationMs, remainingMs());
+        time.sleep(sleepDurationMs);
+        update();
+    }
+}
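
Composing nested timeouts is the motivating use case for the class: a single
Timer created at the public API boundary is threaded through each lower level
call, so the inner steps collectively honor one outer deadline. A sketch under
that assumption (the step names are hypothetical, not taken from the patch):

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class NestedTimeoutSketch {
        // Each step may block for at most the remaining time on the shared timer.
        static void step(String name, Timer timer) {
            long budgetMs = timer.remainingMs();
            // ... block for up to budgetMs doing real work ...
            timer.update(); // record elapsed time before the next step runs
            System.out.println(name + ": " + timer.remainingMs() + "ms left");
        }

        public static void main(String[] args) {
            Timer timer = Time.SYSTEM.timer(1000); // one deadline for the whole call
            step("lookupCoordinator", timer);
            if (timer.notExpired())
                step("joinGroup", timer);
            if (timer.notExpired())
                step("syncGroup", timer);
        }
    }
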
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index 634a1ab..a0f95c4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -110,6 +110,7 @@ import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -136,7 +137,7 @@ public class KafkaConsumerTest {
     public ExpectedException expectedException = ExpectedException.none();
 
     @Test
-    public void testConstructorClose() throws Exception {
+    public void testConstructorClose() {
         Properties props = new Properties();
         props.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testConstructorClose");
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "invalid-23-8409-adsfsdj");
@@ -155,7 +156,7 @@ public class KafkaConsumerTest {
     }
 
     @Test
-    public void testOsDefaultSocketBufferSizes() throws Exception {
+    public void testOsDefaultSocketBufferSizes() {
         Map<String, Object> config = new HashMap<>();
         config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         config.put(ConsumerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE);
@@ -166,7 +167,7 @@ public class KafkaConsumerTest {
     }
 
     @Test(expected = KafkaException.class)
-    public void testInvalidSocketSendBufferSize() throws Exception {
+    public void testInvalidSocketSendBufferSize() {
         Map<String, Object> config = new HashMap<>();
         config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         config.put(ConsumerConfig.SEND_BUFFER_CONFIG, -2);
@@ -174,7 +175,7 @@ public class KafkaConsumerTest {
     }
 
     @Test(expected = KafkaException.class)
-    public void testInvalidSocketReceiveBufferSize() throws Exception {
+    public void testInvalidSocketReceiveBufferSize() {
         Map<String, Object> config = new HashMap<>();
         config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, -2);
@@ -357,12 +358,7 @@ public class KafkaConsumerTest {
 
         // initial fetch
         client.prepareResponseFrom(fetchResponse(tp0, 0, 0), node);
-
-        // We need two update calls:
-        // 1. the first call "sends" the metadata update requests
-        // 2. the second one gets the response we already queued up
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(singleton(tp0), consumer.assignment());
 
@@ -372,8 +368,7 @@ public class KafkaConsumerTest {
         time.sleep(heartbeatIntervalMs);
         Thread.sleep(heartbeatIntervalMs);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertTrue(heartbeatReceived.get());
         consumer.close(0, TimeUnit.MILLISECONDS);
@@ -396,8 +391,7 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // respond to the outstanding fetch so that we have data available on the next poll
@@ -681,8 +675,7 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // respond to the outstanding fetch so that we have data available on the next poll
@@ -726,8 +719,7 @@ public class KafkaConsumerTest {
 
         client.prepareMetadataUpdate(cluster, Collections.<String>emptySet());
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(singleton(topic), consumer.subscription());
         assertEquals(singleton(tp0), consumer.assignment());
@@ -762,8 +754,7 @@ public class KafkaConsumerTest {
         Node coordinator = prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null);
         consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer));
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         assertEquals(singleton(topic), consumer.subscription());
@@ -794,8 +785,7 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // respond to the outstanding fetch so that we have data available on the next poll
@@ -846,8 +836,7 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // interrupt the thread and call poll
@@ -885,8 +874,7 @@ public class KafkaConsumerTest {
         fetches1.put(t2p0, new FetchInfo(0, 10)); // not assigned and not fetched
         client.prepareResponseFrom(fetchResponse(fetches1), node);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
         assertEquals(0, records.count());
@@ -930,14 +918,13 @@ public class KafkaConsumerTest {
         // mock rebalance responses
         Node coordinator = prepareRebalance(client, node, assignor, Arrays.asList(tp0, t2p0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // verify that subscription is still the same, and now assignment has caught up
-        assertTrue(consumer.subscription().size() == 2);
+        assertEquals(2, consumer.subscription().size());
         assertTrue(consumer.subscription().contains(topic) && consumer.subscription().contains(topic2));
-        assertTrue(consumer.assignment().size() == 2);
+        assertEquals(2, consumer.assignment().size());
         assertTrue(consumer.assignment().contains(tp0) && consumer.assignment().contains(t2p0));
 
         // mock a response to the outstanding fetch so that we have data available on the next poll
@@ -1039,19 +1026,18 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
 
         // verify that subscription has changed but assignment is still unchanged
-        assertTrue(consumer.subscription().equals(singleton(topic)));
-        assertTrue(consumer.assignment().isEmpty());
+        assertEquals(singleton(topic), consumer.subscription());
+        assertEquals(Collections.emptySet(), consumer.assignment());
 
         // mock rebalance responses
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // verify that subscription is still the same, and now assignment has caught up
-        assertTrue(consumer.subscription().equals(singleton(topic)));
-        assertTrue(consumer.assignment().equals(singleton(tp0)));
+        assertEquals(singleton(topic), consumer.subscription());
+        assertEquals(singleton(tp0), consumer.assignment());
 
         consumer.poll(Duration.ZERO);
 
@@ -1059,23 +1045,23 @@ public class KafkaConsumerTest {
         consumer.subscribe(singleton(topic2), getConsumerRebalanceListener(consumer));
 
         // verify that subscription has changed but assignment is still unchanged
-        assertTrue(consumer.subscription().equals(singleton(topic2)));
-        assertTrue(consumer.assignment().equals(singleton(tp0)));
+        assertEquals(singleton(topic2), consumer.subscription());
+        assertEquals(singleton(tp0), consumer.assignment());
 
         // the auto commit is disabled, so no offset commit request should be sent
         for (ClientRequest req: client.requests())
-            assertTrue(req.requestBuilder().apiKey() != ApiKeys.OFFSET_COMMIT);
+            assertNotSame(ApiKeys.OFFSET_COMMIT, req.requestBuilder().apiKey());
 
         // subscription change
         consumer.unsubscribe();
 
         // verify that subscription and assignment are both updated
-        assertTrue(consumer.subscription().isEmpty());
-        assertTrue(consumer.assignment().isEmpty());
+        assertEquals(Collections.emptySet(), consumer.subscription());
+        assertEquals(Collections.emptySet(), consumer.assignment());
 
         // the auto commit is disabled, so no offset commit request should be sent
         for (ClientRequest req: client.requests())
-            assertTrue(req.requestBuilder().apiKey() != ApiKeys.OFFSET_COMMIT);
+            assertNotSame(ApiKeys.OFFSET_COMMIT, req.requestBuilder().apiKey());
 
         client.requests().clear();
         consumer.close();
@@ -1352,8 +1338,7 @@ public class KafkaConsumerTest {
         client.prepareResponseFrom(fetchResponse(tp0, 0, 1), node);
         client.prepareResponseFrom(fetchResponse(tp0, 1, 0), node);
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         consumer.poll(Duration.ZERO);
 
         // heartbeat fails due to rebalance in progress
@@ -1390,8 +1375,7 @@ public class KafkaConsumerTest {
         }, fetchResponse(tp0, 1, 1), node);
         time.sleep(heartbeatIntervalMs);
         Thread.sleep(heartbeatIntervalMs);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
         final ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
         assertFalse(records.isEmpty());
         consumer.close(0, TimeUnit.MILLISECONDS);
@@ -1401,7 +1385,6 @@ public class KafkaConsumerTest {
                                    List<? extends AbstractResponse> responses,
                                    long waitMs,
                                    boolean interrupt) throws Exception {
-
         Time time = new MockTime();
         Cluster cluster = TestUtils.singletonCluster(topic, 1);
         Node node = cluster.nodes().get(0);
@@ -1419,8 +1402,7 @@ public class KafkaConsumerTest {
 
         client.prepareMetadataUpdate(cluster, Collections.<String>emptySet());
 
-        consumer.updateAssignmentMetadataIfNeeded(0L);
-        consumer.updateAssignmentMetadataIfNeeded(0L);
+        consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         // Poll with responses
         client.prepareResponseFrom(fetchResponse(tp0, 0, 1), node);
@@ -1777,7 +1759,7 @@ public class KafkaConsumerTest {
         ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(loggerFactory, client, metadata, time,
                 retryBackoffMs, requestTimeoutMs, heartbeatIntervalMs);
 
-        Heartbeat heartbeat = new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs, rebalanceTimeoutMs, retryBackoffMs);
+        Heartbeat heartbeat = new Heartbeat(time, sessionTimeoutMs, heartbeatIntervalMs, rebalanceTimeoutMs, retryBackoffMs);
         ConsumerCoordinator consumerCoordinator = new ConsumerCoordinator(
                 loggerFactory,
                 consumerClient,
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index f88e725..004445f 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -35,6 +35,7 @@ import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.junit.Test;
@@ -59,7 +60,6 @@ import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class AbstractCoordinatorTest {
-
     private static final ByteBuffer EMPTY_DATA = ByteBuffer.wrap(new byte[0]);
     private static final int REBALANCE_TIMEOUT_MS = 60000;
     private static final int SESSION_TIMEOUT_MS = 10000;
@@ -115,7 +115,7 @@ public class AbstractCoordinatorTest {
         mockClient.blackout(coordinatorNode, 10L);
 
         long initialTime = mockTime.milliseconds();
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE));
         long endTime = mockTime.milliseconds();
 
         assertTrue(endTime - initialTime >= RETRY_BACKOFF_MS);
@@ -125,13 +125,12 @@ public class AbstractCoordinatorTest {
     public void testTimeoutAndRetryJoinGroupIfNeeded() throws Exception {
         setupCoordinator();
         mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(0);
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
 
         ExecutorService executor = Executors.newFixedThreadPool(1);
         try {
-            long firstAttemptStartMs = mockTime.milliseconds();
-            Future<Boolean> firstAttempt = executor.submit(() ->
-                    coordinator.joinGroupIfNeeded(REQUEST_TIMEOUT_MS, firstAttemptStartMs));
+            Timer firstAttemptTimer = mockTime.timer(REQUEST_TIMEOUT_MS);
+            Future<Boolean> firstAttempt = executor.submit(() -> coordinator.joinGroupIfNeeded(firstAttemptTimer));
 
             mockTime.sleep(REQUEST_TIMEOUT_MS);
             assertFalse(firstAttempt.get());
@@ -140,9 +139,8 @@ public class AbstractCoordinatorTest {
             mockClient.respond(joinGroupFollowerResponse(1, "memberId", "leaderId", Errors.NONE));
             mockClient.prepareResponse(syncGroupResponse(Errors.NONE));
 
-            long secondAttemptMs = mockTime.milliseconds();
-            Future<Boolean> secondAttempt = executor.submit(() ->
-                    coordinator.joinGroupIfNeeded(REQUEST_TIMEOUT_MS, secondAttemptMs));
+            Timer secondAttemptTimer = mockTime.timer(REQUEST_TIMEOUT_MS);
+            Future<Boolean> secondAttempt = executor.submit(() -> coordinator.joinGroupIfNeeded(secondAttemptTimer));
 
             assertTrue(secondAttempt.get());
         } finally {
@@ -155,15 +153,15 @@ public class AbstractCoordinatorTest {
     public void testJoinGroupRequestTimeout() {
         setupCoordinator(RETRY_BACKOFF_MS, REBALANCE_TIMEOUT_MS);
         mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(0);
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
 
         RequestFuture<ByteBuffer> future = coordinator.sendJoinGroupRequest();
 
         mockTime.sleep(REQUEST_TIMEOUT_MS + 1);
-        assertFalse(consumerClient.poll(future, 0));
+        assertFalse(consumerClient.poll(future, mockTime.timer(0)));
 
         mockTime.sleep(REBALANCE_TIMEOUT_MS - REQUEST_TIMEOUT_MS + 5000);
-        assertTrue(consumerClient.poll(future, 0));
+        assertTrue(consumerClient.poll(future, mockTime.timer(0)));
     }
 
     @Test
@@ -172,13 +170,13 @@ public class AbstractCoordinatorTest {
 
         setupCoordinator(RETRY_BACKOFF_MS, Integer.MAX_VALUE);
         mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(0);
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
 
         RequestFuture<ByteBuffer> future = coordinator.sendJoinGroupRequest();
-        assertFalse(consumerClient.poll(future, 0));
+        assertFalse(consumerClient.poll(future, mockTime.timer(0)));
 
         mockTime.sleep(Integer.MAX_VALUE + 1L);
-        assertTrue(consumerClient.poll(future, 0));
+        assertTrue(consumerClient.poll(future, mockTime.timer(0)));
     }
 
     @Test
@@ -256,7 +254,7 @@ public class AbstractCoordinatorTest {
         assertSame("New request sent while one is in progress", future, coordinator.lookupCoordinator());
 
         mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE));
         assertNotSame("New request not sent after previous completed", future, coordinator.lookupCoordinator());
     }
 
@@ -329,7 +327,7 @@ public class AbstractCoordinatorTest {
         assertFalse(heartbeatReceived.get());
 
         // the join group completes in this poll()
-        consumerClient.poll(0);
+        consumerClient.poll(mockTime.timer(0));
         coordinator.ensureActiveGroup();
 
         assertEquals(1, coordinator.onJoinPrepareInvokes);
@@ -403,7 +401,7 @@ public class AbstractCoordinatorTest {
         assertFalse(heartbeatReceived.get());
 
         // the join group completes in this poll()
-        consumerClient.poll(0);
+        consumerClient.poll(mockTime.timer(0));
         coordinator.ensureActiveGroup();
 
         assertEquals(1, coordinator.onJoinPrepareInvokes);
@@ -481,7 +479,7 @@ public class AbstractCoordinatorTest {
         assertFalse(heartbeatReceived.get());
 
         // the join group completes in this poll()
-        consumerClient.poll(0);
+        consumerClient.poll(mockTime.timer(0));
         coordinator.ensureActiveGroup();
 
         assertEquals(1, coordinator.onJoinPrepareInvokes);
@@ -584,7 +582,7 @@ public class AbstractCoordinatorTest {
 
         // the join group completes in this poll()
         coordinator.wakeupOnJoinComplete = false;
-        consumerClient.poll(0);
+        consumerClient.poll(mockTime.timer(0));
         coordinator.ensureActiveGroup();
 
         assertEquals(1, coordinator.onJoinPrepareInvokes);
@@ -600,7 +598,7 @@ public class AbstractCoordinatorTest {
         mockClient.createPendingAuthenticationError(node, 300);
 
         try {
-            coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+            coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE));
             fail("Expected an authentication error.");
         } catch (AuthenticationException e) {
             // OK
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index cec56b0..62c70a0 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -89,7 +89,6 @@ import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class ConsumerCoordinatorTest {
-
     private final String topic1 = "test1";
     private final String topic2 = "test2";
     private final TopicPartition t1p = new TopicPartition(topic1, 0);
@@ -101,12 +100,12 @@ public class ConsumerCoordinatorTest {
     private final long retryBackoffMs = 100;
     private final int autoCommitIntervalMs = 2000;
     private final int requestTimeoutMs = 30000;
-    private final Heartbeat heartbeat = new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs,
+    private final MockTime time = new MockTime();
+    private final Heartbeat heartbeat = new Heartbeat(time, sessionTimeoutMs, heartbeatIntervalMs,
             rebalanceTimeoutMs, retryBackoffMs);
 
     private MockPartitionAssignor partitionAssignor = new MockPartitionAssignor();
-    private List<PartitionAssignor> assignors = Collections.<PartitionAssignor>singletonList(partitionAssignor);
-    private MockTime time;
+    private List<PartitionAssignor> assignors = Collections.singletonList(partitionAssignor);
     private MockClient client;
     private Cluster cluster = TestUtils.clusterWith(1, new HashMap<String, Integer>() {
         {
@@ -125,7 +124,6 @@ public class ConsumerCoordinatorTest {
 
     @Before
     public void setup() {
-        this.time = new MockTime();
         this.subscriptions = new SubscriptionState(OffsetResetStrategy.EARLIEST);
         this.metadata = new Metadata(0, Long.MAX_VALUE, true);
         this.metadata.update(cluster, Collections.<String>emptySet(), time.milliseconds());
@@ -144,13 +142,13 @@ public class ConsumerCoordinatorTest {
     @After
     public void teardown() {
         this.metrics.close();
-        this.coordinator.close(0);
+        this.coordinator.close(time.timer(0));
     }
 
     @Test
     public void testNormalHeartbeat() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal heartbeat
         time.sleep(sessionTimeoutMs);
@@ -159,7 +157,7 @@ public class ConsumerCoordinatorTest {
         assertFalse(future.isDone());
 
         client.prepareResponse(heartbeatResponse(Errors.NONE));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.succeeded());
@@ -168,7 +166,7 @@ public class ConsumerCoordinatorTest {
     @Test(expected = GroupAuthorizationException.class)
     public void testGroupDescribeUnauthorized() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.GROUP_AUTHORIZATION_FAILED));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = GroupAuthorizationException.class)
@@ -176,17 +174,17 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         client.prepareResponse(joinGroupLeaderResponse(0, "memberId", Collections.<String, List<String>>emptyMap(),
                 Errors.GROUP_AUTHORIZATION_FAILED));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
     }
 
     @Test
     public void testCoordinatorNotAvailable() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // COORDINATOR_NOT_AVAILABLE will mark coordinator as unknown
         time.sleep(sessionTimeoutMs);
@@ -196,7 +194,7 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(heartbeatResponse(Errors.COORDINATOR_NOT_AVAILABLE));
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.failed());
@@ -207,7 +205,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testManyInFlightAsyncCommitsWithCoordinatorDisconnect() throws Exception {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         int numRequests = 1000;
         TopicPartition tp = new TopicPartition("foo", 0);
@@ -239,7 +237,7 @@ public class ConsumerCoordinatorTest {
         // the coordinator as unknown which prevents additional retries to the same coordinator.
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         final AtomicBoolean asyncCallbackInvoked = new AtomicBoolean(false);
         Map<TopicPartition, OffsetCommitRequest.PartitionData> offsets = singletonMap(
@@ -265,7 +263,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testNotCoordinator() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // not_coordinator will mark coordinator as unknown
         time.sleep(sessionTimeoutMs);
@@ -275,7 +273,7 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(heartbeatResponse(Errors.NOT_COORDINATOR));
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.failed());
@@ -286,7 +284,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testIllegalGeneration() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // illegal_generation will cause re-partition
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
@@ -299,7 +297,7 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(heartbeatResponse(Errors.ILLEGAL_GENERATION));
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.failed());
@@ -310,7 +308,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testUnknownConsumerId() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // illegal_generation will cause re-partition
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
@@ -323,7 +321,7 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(heartbeatResponse(Errors.UNKNOWN_MEMBER_ID));
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.failed());
@@ -334,7 +332,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCoordinatorDisconnect() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // coordinator disconnect will mark coordinator as unknown
         time.sleep(sessionTimeoutMs);
@@ -344,7 +342,7 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(heartbeatResponse(Errors.NONE), true); // return disconnected
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertTrue(future.isDone());
         assertTrue(future.failed());
@@ -363,11 +361,11 @@ public class ConsumerCoordinatorTest {
         metadata.update(cluster, Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         client.prepareResponse(joinGroupLeaderResponse(0, consumerId, Collections.<String, List<String>>emptyMap(),
                 Errors.INVALID_GROUP_ID));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
     }
 
     @Test
@@ -381,7 +379,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(cluster, Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, singletonList(topic1));
@@ -397,7 +395,7 @@ public class ConsumerCoordinatorTest {
                         sync.groupAssignment().containsKey(consumerId);
             }
         }, syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -420,7 +418,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(TestUtils.singletonCluster(topic1, 1), Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, singletonList(topic1));
@@ -439,7 +437,7 @@ public class ConsumerCoordinatorTest {
         // expect client to force updating the metadata, if yes gives it both topics
         client.prepareMetadataUpdate(cluster, Collections.<String>emptySet());
 
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(2, subscriptions.assignedPartitions().size());
@@ -462,7 +460,7 @@ public class ConsumerCoordinatorTest {
         assertEquals(singleton(topic1), subscriptions.subscription());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         Map<String, List<String>> initialSubscription = singletonMap(consumerId, singletonList(topic1));
         partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p)));
@@ -502,7 +500,7 @@ public class ConsumerCoordinatorTest {
         }, joinGroupLeaderResponse(2, consumerId, updatedSubscriptions, Errors.NONE));
         client.prepareResponse(syncGroupResponse(newAssignment, Errors.NONE));
 
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(updatedSubscriptionSet, subscriptions.subscription());
@@ -525,7 +523,7 @@ public class ConsumerCoordinatorTest {
         assertEquals(singleton(topic1), subscriptions.subscription());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // Instrument the test so that metadata will contain two topics after the next refresh.
         client.prepareMetadataUpdate(cluster, Collections.emptySet());
@@ -544,7 +542,7 @@ public class ConsumerCoordinatorTest {
         partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p)));
 
         // This will trigger rebalance.
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         // Make sure that the metadata was refreshed during the rebalance and thus subscriptions now contain two topics.
         final Set<String> updatedSubscriptionSet = new HashSet<>(Arrays.asList(topic1, topic2));
@@ -568,7 +566,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(cluster, Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, singletonList(topic1));
         partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p)));
@@ -578,14 +576,14 @@ public class ConsumerCoordinatorTest {
         consumerClient.wakeup();
 
         try {
-            coordinator.poll(Long.MAX_VALUE);
+            coordinator.poll(time.timer(Long.MAX_VALUE));
         } catch (WakeupException e) {
             // ignore
         }
 
         // now complete the second half
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -602,7 +600,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
@@ -616,7 +614,7 @@ public class ConsumerCoordinatorTest {
             }
         }, syncGroupResponse(singletonList(t1p), Errors.NONE));
 
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -634,14 +632,14 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // Join the group, but signal a coordinator change after the first heartbeat
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
         client.prepareResponse(heartbeatResponse(Errors.NOT_COORDINATOR));
 
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         time.sleep(heartbeatIntervalMs);
 
         // Await the first heartbeat which forces us to find a new coordinator
@@ -649,7 +647,7 @@ public class ConsumerCoordinatorTest {
                 "Failed to observe expected heartbeat from background thread");
 
         assertTrue(coordinator.coordinatorUnknown());
-        assertFalse(coordinator.poll(0));
+        assertFalse(coordinator.poll(time.timer(0)));
         assertEquals(time.milliseconds(), heartbeat.lastPollTime());
 
         time.sleep(rebalanceTimeoutMs - 1);
@@ -668,7 +666,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(TestUtils.singletonCluster(topic1, 1), Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
@@ -684,7 +682,7 @@ public class ConsumerCoordinatorTest {
         // expect the client to force a metadata update; when it does, the prepared response hands it both topics
         client.prepareMetadataUpdate(cluster, Collections.emptySet());
 
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(2, subscriptions.assignedPartitions().size());
@@ -711,7 +709,7 @@ public class ConsumerCoordinatorTest {
                         leaveRequest.groupId().equals(groupId);
             }
         }, new LeaveGroupResponse(Errors.NONE));
-        coordinator.close(0);
+        coordinator.close(time.timer(0));
         assertTrue(received.get());
     }
 
@@ -746,12 +744,12 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // join initially, but let coordinator rebalance on sync
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.UNKNOWN_SERVER_ERROR));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
     }
 
     @Test
@@ -761,7 +759,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // join initially, but have the coordinator return an unknown member id
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
@@ -777,7 +775,7 @@ public class ConsumerCoordinatorTest {
         }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
 
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -790,7 +788,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // join initially, but let coordinator rebalance on sync
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
@@ -800,7 +798,7 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
 
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -813,7 +811,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // join initially, but let coordinator rebalance on sync
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
@@ -829,7 +827,7 @@ public class ConsumerCoordinatorTest {
         }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
 
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -846,7 +844,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, singletonList(topic1));
         partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p)));
@@ -855,7 +853,7 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
 
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
 
@@ -883,7 +881,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(TestUtils.singletonCluster(topic1, 1), Collections.<String>emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // prepare initial rebalance
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, topics);
@@ -912,7 +910,7 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp1, tp2), Errors.NONE));
 
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(new HashSet<>(Arrays.asList(tp1, tp2)), subscriptions.assignedPartitions());
@@ -945,7 +943,7 @@ public class ConsumerCoordinatorTest {
         metadata.update(TestUtils.singletonCluster(topic, 1), Collections.emptySet(), time.milliseconds());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // prepare initial rebalance
         partitionAssignor.prepare(singletonMap(consumerId, Collections.singletonList(partition)));
@@ -956,13 +954,13 @@ public class ConsumerCoordinatorTest {
 
         // The first call to poll should raise the exception from the rebalance listener
         try {
-            coordinator.poll(Long.MAX_VALUE);
+            coordinator.poll(time.timer(Long.MAX_VALUE));
             fail("Expected exception thrown from assignment callback");
         } catch (WakeupException e) {
         }
 
         // The second call should retry the assignment callback and succeed
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(1, rebalanceListener.revokedCount);
@@ -1003,14 +1001,14 @@ public class ConsumerCoordinatorTest {
             subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         Map<String, List<String>> memberSubscriptions = singletonMap(consumerId, singletonList(topic1));
         partitionAssignor.prepare(Collections.<String, List<TopicPartition>>emptyMap());
 
         client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE));
         client.prepareResponse(syncGroupResponse(Collections.<TopicPartition>emptyList(), Errors.NONE));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         if (!assign) {
             assertFalse(coordinator.rejoinNeededOrPending());
             assertEquals(Collections.<TopicPartition>emptySet(), rebalanceListener.assigned);
@@ -1021,7 +1019,7 @@ public class ConsumerCoordinatorTest {
         client.poll(0, time.milliseconds());
         client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         assertFalse("Metadata refresh requested unnecessarily", metadata.updateRequested());
         if (!assign) {
@@ -1067,7 +1065,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(new HashSet<>(Arrays.asList(topic1, otherTopic)), rebalanceListener);
         client.prepareResponse(joinGroupFollowerResponse(2, "consumer", "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(2, rebalanceListener.revokedCount);
         assertEquals(singleton(t1p), rebalanceListener.revoked);
@@ -1080,14 +1078,14 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // a disconnect from the original coordinator triggers rediscovery and a fresh join
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE), true);
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertFalse(coordinator.rejoinNeededOrPending());
         assertEquals(singleton(t1p), subscriptions.assignedPartitions());
@@ -1101,11 +1099,11 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // coordinator doesn't like the session timeout
         client.prepareResponse(joinGroupFollowerResponse(0, "consumer", "", Errors.INVALID_SESSION_TIMEOUT));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
     }
 
     @Test
@@ -1113,7 +1111,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.assignFromUser(singleton(t1p));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
 
@@ -1135,7 +1133,7 @@ public class ConsumerCoordinatorTest {
 
     private void testInFlightRequestsFailedAfterCoordinatorMarkedDead(Errors error) {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // Send two async commits and fail the first one with an error.
         // This should cause a coordinator disconnect which will cancel the second request.
@@ -1168,7 +1166,7 @@ public class ConsumerCoordinatorTest {
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
         time.sleep(autoCommitIntervalMs);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertFalse(client.hasPendingResponses());
     }
 
@@ -1185,23 +1183,23 @@ public class ConsumerCoordinatorTest {
 
         // Send an offset commit, but let it fail with a retriable error
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NOT_COORDINATOR);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertTrue(coordinator.coordinatorUnknown());
 
         // After the disconnect, we should rediscover the coordinator
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         subscriptions.seek(t1p, 200);
 
         // Until the retry backoff has expired, we should not retry the offset commit
         time.sleep(retryBackoffMs / 2);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(0, client.inFlightRequestCount());
 
         // Once the backoff expires, we should retry
         time.sleep(retryBackoffMs / 2);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(1, client.inFlightRequestCount());
         respondToOffsetCommitRequest(singletonMap(t1p, 200L), Errors.NONE);
     }
@@ -1218,28 +1216,28 @@ public class ConsumerCoordinatorTest {
         time.sleep(autoCommitIntervalMs);
 
         // Send the offset commit request, but do not respond
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(1, client.inFlightRequestCount());
 
         time.sleep(autoCommitIntervalMs / 2);
 
         // Ensure that no additional offset commit is sent
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(1, client.inFlightRequestCount());
 
         respondToOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(0, client.inFlightRequestCount());
 
         subscriptions.seek(t1p, 200);
 
         // If we poll again before the auto-commit interval, there should be no new sends
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(0, client.inFlightRequestCount());
 
         // After the remainder of the interval passes, we send a new offset commit
         time.sleep(autoCommitIntervalMs / 2);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertEquals(1, client.inFlightRequestCount());
         respondToOffsetCommitRequest(singletonMap(t1p, 200L), Errors.NONE);
     }
@@ -1254,21 +1252,21 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // haven't joined, so should not cause a commit
         time.sleep(autoCommitIntervalMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
 
         subscriptions.seek(t1p, 100);
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
         time.sleep(autoCommitIntervalMs);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertFalse(client.hasPendingResponses());
     }
 
@@ -1281,11 +1279,11 @@ public class ConsumerCoordinatorTest {
         subscriptions.seek(t1p, 100);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
         time.sleep(autoCommitIntervalMs);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertFalse(client.hasPendingResponses());
     }
 
@@ -1298,18 +1296,18 @@ public class ConsumerCoordinatorTest {
         subscriptions.seek(t1p, 100);
 
         // no commit initially since coordinator is unknown
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         time.sleep(autoCommitIntervalMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         // now find the coordinator
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // sleep only for the retry backoff
         time.sleep(retryBackoffMs);
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
         assertFalse(client.hasPendingResponses());
     }
 
@@ -1318,7 +1316,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.assignFromUser(singleton(t1p));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
 
@@ -1334,7 +1332,7 @@ public class ConsumerCoordinatorTest {
     public void testCommitOffsetAsyncWithDefaultCallback() {
         int invokedBeforeTest = mockOffsetCommitCallback.invoked;
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
         coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), mockOffsetCommitCallback);
         coordinator.invokeCompletedOffsetCommitCallbacks();
@@ -1375,7 +1373,7 @@ public class ConsumerCoordinatorTest {
     public void testCommitOffsetAsyncFailedWithDefaultCallback() {
         int invokedBeforeTest = mockOffsetCommitCallback.invoked;
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE);
         coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), mockOffsetCommitCallback);
         coordinator.invokeCompletedOffsetCommitCallbacks();
@@ -1386,7 +1384,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetAsyncCoordinatorNotAvailable() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // async commit with coordinator not available
         MockCommitCallback cb = new MockCommitCallback();
@@ -1402,7 +1400,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetAsyncNotCoordinator() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // async commit that hits a NOT_COORDINATOR error
         MockCommitCallback cb = new MockCommitCallback();
@@ -1418,7 +1416,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetAsyncDisconnected() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // async commit with coordinator disconnected
         MockCommitCallback cb = new MockCommitCallback();
@@ -1434,43 +1432,43 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetSyncNotCoordinator() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // sync commit that hits NOT_COORDINATOR (should rediscover the coordinator and then resubmit the commit request)
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NOT_COORDINATOR);
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE));
     }
 
     @Test
     public void testCommitOffsetSyncCoordinatorNotAvailable() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // sync commit that hits COORDINATOR_NOT_AVAILABLE (should rediscover the coordinator and then resubmit the commit request)
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE);
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE));
     }
 
     @Test
     public void testCommitOffsetSyncCoordinatorDisconnected() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request)
         prepareOffsetCommitRequestDisconnect(singletonMap(t1p, 100L));
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE));
     }
 
     @Test
     public void testAsyncCommitCallbacksInvokedPriorToSyncCommitCompletion() throws Exception {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         final List<OffsetAndMetadata> committedOffsets = Collections.synchronizedList(new ArrayList<OffsetAndMetadata>());
         final OffsetAndMetadata firstOffset = new OffsetAndMetadata(0L);
@@ -1487,7 +1485,7 @@ public class ConsumerCoordinatorTest {
         Thread thread = new Thread() {
             @Override
             public void run() {
-                coordinator.commitOffsetsSync(singletonMap(t1p, secondOffset), 10000);
+                coordinator.commitOffsetsSync(singletonMap(t1p, secondOffset), time.timer(10000));
                 committedOffsets.add(secondOffset);
             }
         };
@@ -1506,68 +1504,73 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testRetryCommitUnknownTopicOrPartition() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION)));
         client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.NONE)));
 
-        assertTrue(coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L, "metadata")), 10000));
+        assertTrue(coordinator.commitOffsetsSync(singletonMap(t1p,
+                new OffsetAndMetadata(100L, "metadata")), time.timer(10000)));
     }
 
     @Test(expected = OffsetMetadataTooLarge.class)
     public void testCommitOffsetMetadataTooLarge() {
         // since offset metadata is provided by the user, we have to propagate the exception so they can handle it
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.OFFSET_METADATA_TOO_LARGE);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L, "metadata")), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p,
+                new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetIllegalGeneration() {
         // we cannot retry if a rebalance occurs before the commit completes
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.ILLEGAL_GENERATION);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L, "metadata")), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p,
+                new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetUnknownMemberId() {
         // we cannot retry if a rebalance occurs before the commit completes
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L, "metadata")), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p,
+                new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetRebalanceInProgress() {
         // we cannot retry if a rebalance occurs before the commit completes
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L, "metadata")), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p,
+                new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = KafkaException.class)
     public void testCommitOffsetSyncCallbackWithNonRetriableException() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // sync commit with invalid partitions should throw if we have no callback
         prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_SERVER_ERROR);
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE));
     }
 
     @Test(expected = IllegalArgumentException.class)
     public void testCommitSyncNegativeOffset() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(-1L)), Long.MAX_VALUE);
+        coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(-1L)), time.timer(Long.MAX_VALUE));
     }
 
     @Test
@@ -1583,18 +1586,18 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetSyncWithoutFutureGetsCompleted() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
-        assertFalse(coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), 0));
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
+        assertFalse(coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(0)));
     }
 
     @Test
     public void testRefreshOffset() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L));
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
@@ -1604,12 +1607,12 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testRefreshOffsetLoadInProgress() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS));
         client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L));
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
@@ -1619,12 +1622,12 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testRefreshOffsetsGroupNotAuthorized() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(Errors.GROUP_AUTHORIZATION_FAILED));
         try {
-            coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+            coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
             fail("Expected group authorization error");
         } catch (GroupAuthorizationException e) {
             assertEquals(groupId, e.groupId());
@@ -1634,23 +1637,23 @@ public class ConsumerCoordinatorTest {
     @Test(expected = KafkaException.class)
     public void testRefreshOffsetUnknownTopicOrPartition() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION, "", 100L));
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
     }
 
     @Test
     public void testRefreshOffsetNotCoordinatorForConsumer() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(Errors.NOT_COORDINATOR));
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L));
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
@@ -1660,11 +1663,11 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testRefreshOffsetWithNoFetchableOffsets() {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         subscriptions.assignFromUser(singleton(t1p));
         client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", -1L));
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.singleton(t1p), subscriptions.missingFetchPositions());
         assertEquals(Collections.emptySet(), subscriptions.partitionsNeedingReset(time.milliseconds()));
@@ -1678,7 +1681,7 @@ public class ConsumerCoordinatorTest {
 
         subscriptions.assignFromUser(singleton(t1p));
         subscriptions.seek(t1p, 500L);
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
@@ -1692,7 +1695,7 @@ public class ConsumerCoordinatorTest {
 
         subscriptions.assignFromUser(singleton(t1p));
         subscriptions.requestOffsetReset(t1p, OffsetResetStrategy.EARLIEST);
-        coordinator.refreshCommittedOffsetsIfNeeded(Long.MAX_VALUE);
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertFalse(subscriptions.hasAllFetchPositions());
@@ -1860,17 +1863,17 @@ public class ConsumerCoordinatorTest {
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
                 ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, autoCommit, leaveGroup);
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         if (useGroupManagement) {
             subscriptions.subscribe(singleton(topic1), rebalanceListener);
             client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
             client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
-            coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+            coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
         } else
             subscriptions.assignFromUser(singleton(t1p));
 
         subscriptions.seek(t1p, 100);
-        coordinator.poll(Long.MAX_VALUE);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         return coordinator;
     }
@@ -1880,7 +1883,7 @@ public class ConsumerCoordinatorTest {
         coordinator.sendHeartbeatRequest();
         client.prepareResponse(heartbeatResponse(error));
         time.sleep(sessionTimeoutMs);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(coordinator.coordinatorUnknown());
     }
 
@@ -1896,7 +1899,7 @@ public class ConsumerCoordinatorTest {
             Future<?> future = executor.submit(new Runnable() {
                 @Override
                 public void run() {
-                    coordinator.close(Math.min(closeTimeoutMs, requestTimeoutMs));
+                    coordinator.close(time.timer(Math.min(closeTimeoutMs, requestTimeoutMs)));
                 }
             });
             // Wait for close to start. If coordinator is known, wait for close to queue
@@ -2031,10 +2034,10 @@ public class ConsumerCoordinatorTest {
                                                     ConsumerCoordinator coordinator,
                                                     List<TopicPartition> assignment) {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(assignment, Errors.NONE));
-        coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
+        coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
     }
 
     private void prepareOffsetCommitRequest(Map<TopicPartition, Long> expectedOffsets, Errors error) {
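
The hunks above are all the same substitution: a raw `long` timeout (sometimes paired with a separate `nowMs` argument) becomes a single `Timer` obtained from `time.timer(...)`, so elapsed time is tracked in one object instead of being recomputed by every caller. A minimal sketch of the pattern, assuming the `update()`/`isExpired()`/`remainingMs()` accessors exercised by the new `TimerTest` (everything here other than `time.timer(...)` is illustrative, not a reference):

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class TimerMigrationSketch {
        public static void main(String[] args) throws InterruptedException {
            Time time = Time.SYSTEM;

            // Old shape, removed by this patch:
            //   coordinator.joinGroupIfNeeded(Long.MAX_VALUE, time.milliseconds());
            // New shape:
            //   coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));

            // One Timer carries the whole budget through a retry loop.
            Timer timer = time.timer(50L);   // 50 ms total budget
            while (!timer.isExpired()) {
                Thread.sleep(10L);           // stand-in for one blocking network attempt
                timer.update();              // fold the elapsed wall-clock time back in
            }
            System.out.println("remaining = " + timer.remainingMs()); // 0 once expired
        }
    }

The benefit shows up in nested calls: an inner poll bounded by the same `Timer` consumes from the same budget as the outer operation, so a slow inner step cannot silently extend the overall timeout.
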
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
index d5ec382..4494fd5 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
@@ -98,7 +98,7 @@ public class ConsumerNetworkClientTest {
         assertEquals(2, consumerClient.pendingRequestCount());
         assertEquals(2, consumerClient.pendingRequestCount(node));
 
-        consumerClient.awaitPendingRequests(node, Long.MAX_VALUE);
+        consumerClient.awaitPendingRequests(node, time.timer(Long.MAX_VALUE));
         assertTrue(future1.succeeded());
         assertTrue(future2.succeeded());
     }
@@ -157,7 +157,7 @@ public class ConsumerNetworkClientTest {
 
         EasyMock.replay(mockNetworkClient);
 
-        consumerClient.poll(Long.MAX_VALUE, time.milliseconds(), new ConsumerNetworkClient.PollCondition() {
+        consumerClient.poll(time.timer(Long.MAX_VALUE), new ConsumerNetworkClient.PollCondition() {
             @Override
             public boolean shouldBlock() {
                 return false;
@@ -180,7 +180,7 @@ public class ConsumerNetworkClientTest {
 
         EasyMock.replay(mockNetworkClient);
 
-        consumerClient.poll(timeout, time.milliseconds(), new ConsumerNetworkClient.PollCondition() {
+        consumerClient.poll(time.timer(timeout), new ConsumerNetworkClient.PollCondition() {
             @Override
             public boolean shouldBlock() {
                 return true;
@@ -203,7 +203,7 @@ public class ConsumerNetworkClientTest {
 
         EasyMock.replay(mockNetworkClient);
 
-        consumerClient.poll(Long.MAX_VALUE, time.milliseconds(), new ConsumerNetworkClient.PollCondition() {
+        consumerClient.poll(time.timer(Long.MAX_VALUE), new ConsumerNetworkClient.PollCondition() {
             @Override
             public boolean shouldBlock() {
                 return true;
@@ -218,7 +218,7 @@ public class ConsumerNetworkClientTest {
         RequestFuture<ClientResponse> future = consumerClient.send(node, heartbeat());
         consumerClient.wakeup();
         try {
-            consumerClient.poll(0);
+            consumerClient.poll(time.timer(0));
             fail();
         } catch (WakeupException e) {
         }
@@ -290,7 +290,7 @@ public class ConsumerNetworkClientTest {
 
     @Test
     public void testAwaitForMetadataUpdateWithTimeout() {
-        assertFalse(consumerClient.awaitMetadataUpdate(10L));
+        assertFalse(consumerClient.awaitMetadataUpdate(time.timer(10L)));
     }
 
     @Test
@@ -325,7 +325,7 @@ public class ConsumerNetworkClientTest {
         assertFalse(future2.isDone());
 
         // First send should have expired and second send still pending
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(future1.isDone());
         assertFalse(future1.succeeded());
         assertEquals(1, consumerClient.pendingRequestCount());
@@ -346,7 +346,7 @@ public class ConsumerNetworkClientTest {
         assertEquals(1, consumerClient.pendingRequestCount());
         assertEquals(1, consumerClient.pendingRequestCount(node));
         disconnected.set(true);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(future3.isDone());
         assertFalse(future3.succeeded());
         assertEquals(0, consumerClient.pendingRequestCount());
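
The `ConsumerNetworkClient` changes repeat the substitution: `poll(timeout)` and `poll(timeout, nowMs, condition)` become `poll(timer)` and `poll(timer, condition)`. Because these tests drive a `MockTime`, a timer built from it only expires when the mock clock is advanced, which is what lets an assertion like `awaitMetadataUpdate(time.timer(10L))` return `false` deterministically. A hedged sketch of that interplay, again assuming the accessor names from `TimerTest`:

    import org.apache.kafka.common.utils.MockTime;
    import org.apache.kafka.common.utils.Timer;

    public class MockTimeTimerSketch {
        public static void main(String[] args) {
            MockTime time = new MockTime();     // mock clock; sleep() just advances it
            Timer timer = time.timer(100L);     // 100 ms budget on the mock clock

            time.sleep(60L);                    // no real waiting
            timer.update();                     // re-read the (mock) clock
            System.out.println(timer.remainingMs()); // expected: 40
            System.out.println(timer.isExpired());   // expected: false

            time.sleep(60L);                    // overshoot the deadline
            timer.update();
            System.out.println(timer.isExpired());   // expected: true
        }
    }

This is why patterns like `time.sleep(autoCommitIntervalMs)` followed by `coordinator.poll(time.timer(Long.MAX_VALUE))` keep working in the coordinator tests above: the sleep moves the mock clock, and a freshly created timer picks up the new instant as its start time.
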
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index f97c266..1a82faa 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -192,7 +192,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
@@ -246,7 +246,7 @@ public class FetcherTest {
         buffer.flip();
 
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
@@ -269,7 +269,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
@@ -312,7 +312,7 @@ public class FetcherTest {
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         // The fetcher should block on a deserialization error
         for (int i = 0; i < 2; i++) {
             try {
@@ -372,7 +372,7 @@ public class FetcherTest {
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         // the first fetchedRecords() should return the first valid message
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
@@ -410,7 +410,7 @@ public class FetcherTest {
         fetcher.fetchedRecords();
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> records = fetcher.fetchedRecords().get(tp0);
         assertEquals(1, records.size());
@@ -443,7 +443,7 @@ public class FetcherTest {
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         // fetchedRecords() should always throw an exception due to the bad batch.
         for (int i = 0; i < 2; i++) {
@@ -474,7 +474,7 @@ public class FetcherTest {
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         try {
             fetcher.fetchedRecords();
             fail("fetchedRecords should have raised");
@@ -509,7 +509,7 @@ public class FetcherTest {
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         records = fetcher.fetchedRecords().get(tp0);
 
         assertEquals(3, records.size());
@@ -540,7 +540,7 @@ public class FetcherTest {
         client.prepareResponse(matchesOffset(tp0, 4), fullFetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         records = fetcher.fetchedRecords().get(tp0);
         assertEquals(2, records.size());
         assertEquals(3L, subscriptions.position(tp0).longValue());
@@ -548,14 +548,14 @@ public class FetcherTest {
         assertEquals(2, records.get(1).offset());
 
         assertEquals(0, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         records = fetcher.fetchedRecords().get(tp0);
         assertEquals(1, records.size());
         assertEquals(4L, subscriptions.position(tp0).longValue());
         assertEquals(3, records.get(0).offset());
 
         assertTrue(fetcher.sendFetches() > 0);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         records = fetcher.fetchedRecords().get(tp0);
         assertEquals(2, records.size());
         assertEquals(6L, subscriptions.position(tp0).longValue());
@@ -580,7 +580,7 @@ public class FetcherTest {
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         records = fetcher.fetchedRecords().get(tp0);
         assertEquals(2, records.size());
         assertEquals(3L, subscriptions.position(tp0).longValue());
@@ -592,7 +592,7 @@ public class FetcherTest {
         subscriptions.seek(tp1, 4);
 
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
         assertNull(fetchedRecords.get(tp0));
         records = fetchedRecords.get(tp1);
@@ -621,7 +621,7 @@ public class FetcherTest {
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         consumerRecords = fetcher.fetchedRecords().get(tp0);
         assertEquals(3, consumerRecords.size());
         assertEquals(31L, subscriptions.position(tp0).longValue()); // this is the next fetching position
@@ -681,7 +681,7 @@ public class FetcherTest {
         MemoryRecords partialRecord = MemoryRecords.readableRecords(
             ByteBuffer.wrap(new byte[]{0, 0, 0, 0, 0, 0, 0, 0}));
         client.prepareResponse(fullFetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
     }
 
@@ -693,7 +693,7 @@ public class FetcherTest {
         // fetch with a TOPIC_AUTHORIZATION_FAILED response; fetchedRecords() should surface it as an exception
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         try {
             fetcher.fetchedRecords();
             fail("fetchedRecords should have thrown");
@@ -713,7 +713,7 @@ public class FetcherTest {
         // Now the rebalance happens and fetch positions are cleared
         subscriptions.assignFromSubscribed(singleton(tp0));
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         // The active fetch should be ignored since its position is no longer valid
         assertTrue(fetcher.fetchedRecords().isEmpty());
@@ -728,7 +728,7 @@ public class FetcherTest {
         subscriptions.pause(tp0);
 
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertNull(fetcher.fetchedRecords().get(tp0));
     }
 
@@ -749,7 +749,7 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
     }
@@ -761,7 +761,7 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
     }
@@ -773,7 +773,7 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertTrue(subscriptions.isOffsetResetNeeded(tp0));
         assertEquals(null, subscriptions.position(tp0));
@@ -789,7 +789,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         subscriptions.seek(tp0, 1);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertEquals(1, subscriptions.position(tp0).longValue());
@@ -802,7 +802,7 @@ public class FetcherTest {
 
         assertTrue(fetcherNoAutoReset.sendFetches() > 0);
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
         subscriptionsNoAutoReset.seek(tp0, 2);
         assertEquals(0, fetcherNoAutoReset.fetchedRecords().size());
@@ -815,7 +815,7 @@ public class FetcherTest {
 
         fetcherNoAutoReset.sendFetches();
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
         for (int i = 0; i < 2; i++) {
@@ -839,14 +839,14 @@ public class FetcherTest {
 
         assertEquals(1, fetcherNoAutoReset.sendFetches());
 
-        Map<TopicPartition, FetchResponse.PartitionData> partitions = new LinkedHashMap<>();
-        partitions.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100,
+        Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new LinkedHashMap<>();
+        partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100,
             FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records));
-        partitions.put(tp0, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
+        partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.OFFSET_OUT_OF_RANGE, 100,
             FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY));
-        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, new LinkedHashMap<>(partitions),
             0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
         List<OffsetOutOfRangeException> exceptions = new ArrayList<>();
@@ -884,18 +884,18 @@ public class FetcherTest {
 
         assertEquals(1, fetcherNoAutoReset.sendFetches());
 
-        Map<TopicPartition, FetchResponse.PartitionData> partitions = new LinkedHashMap<>();
-        partitions.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET,
+        Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new LinkedHashMap<>();
+        partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET,
                 FetchResponse.INVALID_LOG_START_OFFSET, null, records));
-        partitions.put(tp0, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
+        partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.OFFSET_OUT_OF_RANGE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY));
-        partitions.put(tp2, new FetchResponse.PartitionData(Errors.NONE, 100L, 4,
+        partitions.put(tp2, new FetchResponse.PartitionData<>(Errors.NONE, 100L, 4,
                 0L, null, nextRecords));
-        partitions.put(tp3, new FetchResponse.PartitionData(Errors.NONE, 100L, 4,
+        partitions.put(tp3, new FetchResponse.PartitionData<>(Errors.NONE, 100L, 4,
                 0L, null, partialRecords));
-        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, new LinkedHashMap<>(partitions),
                 0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
         for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
@@ -947,11 +947,11 @@ public class FetcherTest {
         subscriptionsNoAutoReset.assignFromUser(Utils.mkSet(tp0));
         subscriptionsNoAutoReset.seek(tp0, 1);
         assertEquals(1, fetcher.sendFetches());
-        Map<TopicPartition, FetchResponse.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100,
+        Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new HashMap<>();
+        partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records));
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
 
         assertEquals(2, fetcher.fetchedRecords().get(tp0).size());
 
@@ -959,10 +959,10 @@ public class FetcherTest {
         subscriptionsNoAutoReset.seek(tp1, 1);
         assertEquals(1, fetcher.sendFetches());
         partitions = new HashMap<>();
-        partitions.put(tp1, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
+        partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.OFFSET_OUT_OF_RANGE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY));
-        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), 0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, new LinkedHashMap<>(partitions), 0, INVALID_SESSION_ID));
+        consumerClient.poll(time.timer(0));
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
 
         subscriptionsNoAutoReset.seek(tp1, 10);
@@ -977,7 +977,7 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true);
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
 
         // disconnects should have no effect on subscription state
@@ -1332,7 +1332,7 @@ public class FetcherTest {
         // sending response before request, as getTopicMetadata is a blocking call
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
 
-        Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(5000L);
+        Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L));
 
         assertEquals(cluster.topics().size(), allTopics.size());
     }
@@ -1342,21 +1342,21 @@ public class FetcherTest {
         // first try gets a disconnect, next succeeds
         client.prepareResponse(null, true);
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
-        Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(5000L);
+        Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L));
         assertEquals(cluster.topics().size(), allTopics.size());
     }
 
     @Test(expected = TimeoutException.class)
     public void testGetAllTopicsTimeout() {
         // since no response is prepared, the request should timeout
-        fetcher.getAllTopicMetadata(50L);
+        fetcher.getAllTopicMetadata(time.timer(50L));
     }
 
     @Test
     public void testGetAllTopicsUnauthorized() {
         client.prepareResponse(newMetadataResponse(topicName, Errors.TOPIC_AUTHORIZATION_FAILED));
         try {
-            fetcher.getAllTopicMetadata(10L);
+            fetcher.getAllTopicMetadata(time.timer(10L));
             fail();
         } catch (TopicAuthorizationException e) {
             assertEquals(singleton(topicName), e.unauthorizedTopics());
@@ -1367,7 +1367,7 @@ public class FetcherTest {
     public void testGetTopicMetadataInvalidTopic() {
         client.prepareResponse(newMetadataResponse(topicName, Errors.INVALID_TOPIC_EXCEPTION));
         fetcher.getTopicMetadata(
-                new MetadataRequest.Builder(Collections.singletonList(topicName), true), 5000L);
+                new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L));
     }
 
     @Test
@@ -1375,7 +1375,7 @@ public class FetcherTest {
         client.prepareResponse(newMetadataResponse(topicName, Errors.UNKNOWN_TOPIC_OR_PARTITION));
 
         Map<String, List<PartitionInfo>> topicMetadata = fetcher.getTopicMetadata(
-                new MetadataRequest.Builder(Collections.singletonList(topicName), true), 5000L);
+                new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L));
         assertNull(topicMetadata.get(topicName));
     }
 
@@ -1385,7 +1385,7 @@ public class FetcherTest {
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
 
         Map<String, List<PartitionInfo>> topicMetadata = fetcher.getTopicMetadata(
-                new MetadataRequest.Builder(Collections.singletonList(topicName), true), 5000L);
+                new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L));
         assertTrue(topicMetadata.containsKey(topicName));
     }
 
@@ -1426,7 +1426,8 @@ public class FetcherTest {
         client.prepareResponse(altered);
 
         Map<String, List<PartitionInfo>> topicMetadata =
-            fetcher.getTopicMetadata(new MetadataRequest.Builder(Collections.singletonList(topicName), false), 5000L);
+            fetcher.getTopicMetadata(new MetadataRequest.Builder(Collections.singletonList(topicName), false),
+                    time.timer(5000L));
 
         Assert.assertNotNull(topicMetadata);
         Assert.assertNotNull(topicMetadata.get(topicName));
@@ -1619,7 +1620,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(Utils.mkSet(tp1, tp2));
 
         int expectedBytes = 0;
-        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> fetchPartitionData = new LinkedHashMap<>();
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> fetchPartitionData = new LinkedHashMap<>();
 
         for (TopicPartition tp : Utils.mkSet(tp1, tp2)) {
             subscriptions.seek(tp, 0);
@@ -1632,13 +1633,13 @@ public class FetcherTest {
             for (Record record : records.records())
                 expectedBytes += record.sizeInBytes();
 
-            fetchPartitionData.put(tp, new FetchResponse.PartitionData(Errors.NONE, 15L,
+            fetchPartitionData.put(tp, new FetchResponse.PartitionData<>(Errors.NONE, 15L,
                     FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records));
         }
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(new FetchResponse(Errors.NONE, fetchPartitionData, 0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, fetchPartitionData, 0, INVALID_SESSION_ID));
+        consumerClient.poll(time.timer(0));
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
         assertEquals(3, fetchedRecords.get(tp1).size());
@@ -1693,16 +1694,16 @@ public class FetcherTest {
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
         MemoryRecords records = builder.build();
 
-        Map<TopicPartition, FetchResponse.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100,
+        Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new HashMap<>();
+        partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records));
-        partitions.put(tp1, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
+        partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.OFFSET_OUT_OF_RANGE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, MemoryRecords.EMPTY));
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, new LinkedHashMap<>(partitions),
                 0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         fetcher.fetchedRecords();
 
         int expectedBytes = 0;
@@ -1733,16 +1734,16 @@ public class FetcherTest {
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
         MemoryRecords records = builder.build();
 
-        Map<TopicPartition, FetchResponse.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100,
+        Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new HashMap<>();
+        partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records));
-        partitions.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100,
+        partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null,
                 MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("val".getBytes()))));
 
-        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+        client.prepareResponse(new FetchResponse<>(Errors.NONE, new LinkedHashMap<>(partitions),
                 0, INVALID_SESSION_ID));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         fetcher.fetchedRecords();
 
         // we should have ignored the record at the wrong offset
@@ -1768,7 +1769,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
@@ -1795,7 +1796,7 @@ public class FetcherTest {
             TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         return fetcher.fetchedRecords();
     }
 
@@ -1803,14 +1804,14 @@ public class FetcherTest {
             TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, long logStartOffset, int throttleTime) {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, logStartOffset, throttleTime));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         return fetcher.fetchedRecords();
     }
 
     @Test
     public void testGetOffsetsForTimesTimeout() {
         try {
-            fetcher.offsetsByTimes(Collections.singletonMap(new TopicPartition(topicName, 2), 1000L), 100L);
+            fetcher.offsetsByTimes(Collections.singletonMap(new TopicPartition(topicName, 2), 1000L), time.timer(100L));
             fail("Should throw timeout exception.");
         } catch (TimeoutException e) {
             // let it go.
@@ -1820,7 +1821,7 @@ public class FetcherTest {
     @Test
     public void testGetOffsetsForTimes() {
         // Empty map
-        assertTrue(fetcher.offsetsByTimes(new HashMap<TopicPartition, Long>(), 100L).isEmpty());
+        assertTrue(fetcher.offsetsByTimes(new HashMap<TopicPartition, Long>(), time.timer(100L)).isEmpty());
         // Unknown Offset
         testGetOffsetsForTimesWithUnknownOffset();
         // Error code none with unknown offset
@@ -1851,7 +1852,7 @@ public class FetcherTest {
         offsetsToSearch.put(tp0, ListOffsetRequest.EARLIEST_TIMESTAMP);
         offsetsToSearch.put(tp1, ListOffsetRequest.EARLIEST_TIMESTAMP);
 
-        fetcher.offsetsByTimes(offsetsToSearch, 0);
+        fetcher.offsetsByTimes(offsetsToSearch, time.timer(0));
     }
 
     @Test
@@ -1881,7 +1882,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -1920,7 +1921,7 @@ public class FetcherTest {
             }
         }, fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
 
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -1990,7 +1991,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -2037,7 +2038,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -2082,7 +2083,7 @@ public class FetcherTest {
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
                 abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
@@ -2119,7 +2120,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetcher.fetchedRecords();
@@ -2154,7 +2155,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetcher.fetchedRecords();
@@ -2216,7 +2217,7 @@ public class FetcherTest {
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
                 abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
@@ -2253,7 +2254,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -2286,7 +2287,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
@@ -2306,16 +2307,16 @@ public class FetcherTest {
         subscriptions.seek(tp1, 1);
 
         // Fetch some records and establish an incremental fetch session.
-        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions1 = new LinkedHashMap<>();
-        partitions1.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 2L,
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions1 = new LinkedHashMap<>();
+        partitions1.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 2L,
                 2, 0L, null, this.records));
-        partitions1.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100L,
+        partitions1.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100L,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, emptyRecords));
-        FetchResponse resp1 = new FetchResponse(Errors.NONE, partitions1, 0, 123);
+        FetchResponse resp1 = new FetchResponse<>(Errors.NONE, partitions1, 0, 123);
         client.prepareResponse(resp1);
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp1));
@@ -2336,25 +2337,24 @@ public class FetcherTest {
         assertEquals(4L, subscriptions.position(tp0).longValue());
 
         // The second response contains no new records.
-        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions2 = new LinkedHashMap<>();
-        FetchResponse resp2 = new FetchResponse(Errors.NONE, partitions2, 0, 123);
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions2 = new LinkedHashMap<>();
+        FetchResponse resp2 = new FetchResponse<>(Errors.NONE, partitions2, 0, 123);
         client.prepareResponse(resp2);
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         fetchedRecords = fetcher.fetchedRecords();
         assertTrue(fetchedRecords.isEmpty());
         assertEquals(4L, subscriptions.position(tp0).longValue());
         assertEquals(1L, subscriptions.position(tp1).longValue());
 
         // The third response contains some new records for tp0.
-        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions3 = new LinkedHashMap<>();
-        partitions3.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100L,
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions3 = new LinkedHashMap<>();
+        partitions3.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 100L,
                 4, 0L, null, this.nextRecords));
-        new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions1), 0, INVALID_SESSION_ID);
-        FetchResponse resp3 = new FetchResponse(Errors.NONE, partitions3, 0, 123);
+        FetchResponse resp3 = new FetchResponse<>(Errors.NONE, partitions3, 0, 123);
         client.prepareResponse(resp3);
         assertEquals(1, fetcher.sendFetches());
-        consumerClient.poll(0);
+        consumerClient.poll(time.timer(0));
         fetchedRecords = fetcher.fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
@@ -2428,7 +2428,8 @@ public class FetcherTest {
         Map<TopicPartition, Long> timestampToSearch = new HashMap<>();
         timestampToSearch.put(t2p0, 0L);
         timestampToSearch.put(tp1, 0L);
-        Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap = fetcher.offsetsByTimes(timestampToSearch, Long.MAX_VALUE);
+        Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap =
+                fetcher.offsetsByTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
 
         if (expectedOffsetForP0 == null)
             assertNull(offsetAndTimestampMap.get(t2p0));
@@ -2459,7 +2460,8 @@ public class FetcherTest {
 
         Map<TopicPartition, Long> timestampToSearch = new HashMap<>();
         timestampToSearch.put(tp0, 0L);
-        Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap = fetcher.offsetsByTimes(timestampToSearch, Long.MAX_VALUE);
+        Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap =
+                fetcher.offsetsByTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
 
         assertTrue(offsetAndTimestampMap.containsKey(tp0));
         assertNull(offsetAndTimestampMap.get(tp0));
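[Editor's note] Every FetcherTest change above is the same mechanical substitution: consumerClient.poll(0) becomes consumerClient.poll(time.timer(0)), while blocking calls such as getAllTopicMetadata and offsetsByTimes now receive a positive-duration timer. Judging by the `Timer` semantics exercised in TimerTest further below, a zero-duration timer starts out expired, so the zero case still amounts to a single non-blocking pass. A minimal sketch of the distinction (illustrative only, not part of the patch):

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class ZeroTimerSketch {
        public static void main(String[] args) {
            Time time = Time.SYSTEM;

            // A zero-duration timer is already expired: poll(...) should do
            // at most one non-blocking pass with it.
            Timer nonBlocking = time.timer(0);
            System.out.println(nonBlocking.isExpired());  // true

            // A positive-duration timer carries the remaining wait budget.
            Timer bounded = time.timer(5000L);
            System.out.println(bounded.remainingMs());    // ~5000
        }
    }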
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
index 7db7820..c382de6 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
@@ -31,7 +31,8 @@ public class HeartbeatTest {
     private int maxPollIntervalMs = 900;
     private long retryBackoffMs = 10L;
     private MockTime time = new MockTime();
-    private Heartbeat heartbeat = new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs, maxPollIntervalMs, retryBackoffMs);
+    private Heartbeat heartbeat = new Heartbeat(time, sessionTimeoutMs, heartbeatIntervalMs,
+            maxPollIntervalMs, retryBackoffMs);
 
     @Test
     public void testShouldHeartbeat() {
@@ -49,24 +50,58 @@ public class HeartbeatTest {
 
     @Test
     public void testTimeToNextHeartbeat() {
-        heartbeat.sentHeartbeat(0);
-        assertEquals(100, heartbeat.timeToNextHeartbeat(0));
-        assertEquals(0, heartbeat.timeToNextHeartbeat(100));
-        assertEquals(0, heartbeat.timeToNextHeartbeat(200));
+        heartbeat.sentHeartbeat(time.milliseconds());
+        assertEquals(heartbeatIntervalMs, heartbeat.timeToNextHeartbeat(time.milliseconds()));
+
+        time.sleep(heartbeatIntervalMs);
+        assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds()));
+
+        time.sleep(heartbeatIntervalMs);
+        assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds()));
     }
 
     @Test
     public void testSessionTimeoutExpired() {
         heartbeat.sentHeartbeat(time.milliseconds());
-        time.sleep(305);
+        time.sleep(sessionTimeoutMs + 5);
         assertTrue(heartbeat.sessionTimeoutExpired(time.milliseconds()));
     }
 
     @Test
     public void testResetSession() {
         heartbeat.sentHeartbeat(time.milliseconds());
-        time.sleep(305);
-        heartbeat.resetTimeouts(time.milliseconds());
+        time.sleep(sessionTimeoutMs + 5);
+        heartbeat.resetSessionTimeout();
+        assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds()));
+
+        // Resetting the session timeout should not reset the poll timeout
+        time.sleep(maxPollIntervalMs + 1);
+        heartbeat.resetSessionTimeout();
+        assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds()));
+    }
+
+    @Test
+    public void testResetTimeouts() {
+        time.sleep(maxPollIntervalMs);
+        assertTrue(heartbeat.sessionTimeoutExpired(time.milliseconds()));
+        assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds()));
+        assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds()));
+
+        heartbeat.resetTimeouts();
         assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds()));
+        assertEquals(heartbeatIntervalMs, heartbeat.timeToNextHeartbeat(time.milliseconds()));
+        assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds()));
+    }
+
+    @Test
+    public void testPollTimeout() {
+        assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds()));
+        time.sleep(maxPollIntervalMs / 2);
+
+        assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds()));
+        time.sleep(maxPollIntervalMs / 2 + 1);
+
+        assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds()));
     }
+
 }
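[Editor's note] The rewritten HeartbeatTest reflects the new constructor, which now takes the `Time` instance up front, and the split between the session timeout and the poll timeout. A rough sketch of how a caller might consult it (hypothetical driver code, not from the commit; `Heartbeat` is an internal class):

    import org.apache.kafka.clients.consumer.internals.Heartbeat;
    import org.apache.kafka.common.utils.Time;

    public class HeartbeatSketch {
        public static void main(String[] args) {
            Time time = Time.SYSTEM;
            // Argument order as in HeartbeatTest: time, sessionTimeoutMs,
            // heartbeatIntervalMs, maxPollIntervalMs, retryBackoffMs.
            Heartbeat heartbeat = new Heartbeat(time, 300, 100, 900, 10L);

            long now = time.milliseconds();
            if (heartbeat.sessionTimeoutExpired(now)) {
                System.out.println("session expired: rejoin the group");
            } else if (heartbeat.pollTimeoutExpired(now)) {
                System.out.println("poll interval exceeded: leave the group");
            } else if (heartbeat.timeToNextHeartbeat(now) == 0) {
                heartbeat.sentHeartbeat(now);  // record the send time
                System.out.println("heartbeat due: send one now");
            }
        }
    }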
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
index 011eba2..4dbd03c 100644
--- a/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
@@ -71,11 +71,6 @@ public class MockTime implements Time {
         return highResTimeNs.get();
     }
 
-    @Override
-    public long hiResClockMs() {
-        return TimeUnit.NANOSECONDS.toMillis(nanoseconds());
-    }
-
     private void maybeSleep(long ms) {
         if (ms != 0)
             sleep(ms);
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java b/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java
new file mode 100644
index 0000000..ea48c5a
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.utils;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TimerTest {
+
+    private final MockTime time = new MockTime();
+
+    @Test
+    public void testTimerUpdate() {
+        Timer timer = time.timer(500);
+        assertEquals(500, timer.remainingMs());
+        assertEquals(0, timer.elapsedMs());
+
+        time.sleep(100);
+        timer.update();
+
+        assertEquals(400, timer.remainingMs());
+        assertEquals(100, timer.elapsedMs());
+
+        time.sleep(400);
+        timer.update(time.milliseconds());
+
+        assertEquals(0, timer.remainingMs());
+        assertEquals(500, timer.elapsedMs());
+        assertTrue(timer.isExpired());
+
+        // Going over the expiration is fine and the elapsed time can exceed
+        // the initial timeout. However, remaining time should be stuck at 0.
+        time.sleep(200);
+        timer.update(time.milliseconds());
+        assertTrue(timer.isExpired());
+        assertEquals(0, timer.remainingMs());
+        assertEquals(700, timer.elapsedMs());
+    }
+
+    @Test
+    public void testTimerUpdateAndReset() {
+        Timer timer = time.timer(500);
+        timer.sleep(200);
+        assertEquals(300, timer.remainingMs());
+        assertEquals(200, timer.elapsedMs());
+
+        timer.updateAndReset(400);
+        assertEquals(400, timer.remainingMs());
+        assertEquals(0, timer.elapsedMs());
+
+        timer.sleep(400);
+        assertTrue(timer.isExpired());
+
+        timer.updateAndReset(200);
+        assertEquals(200, timer.remainingMs());
+        assertEquals(0, timer.elapsedMs());
+        assertFalse(timer.isExpired());
+    }
+
+    @Test
+    public void testTimerResetUsesCurrentTime() {
+        Timer timer = time.timer(500);
+        timer.sleep(200);
+        assertEquals(300, timer.remainingMs());
+        assertEquals(200, timer.elapsedMs());
+
+        time.sleep(300);
+        timer.reset(500);
+        assertEquals(500, timer.remainingMs());
+
+        timer.update();
+        assertEquals(200, timer.remainingMs());
+    }
+
+    @Test
+    public void testTimeoutOverflow() {
+        Timer timer = time.timer(Long.MAX_VALUE);
+        assertEquals(Long.MAX_VALUE - timer.currentTimeMs(), timer.remainingMs());
+        assertEquals(0, timer.elapsedMs());
+    }
+
+    @Test
+    public void testNonMonotonicUpdate() {
+        Timer timer = time.timer(100);
+        long currentTimeMs = timer.currentTimeMs();
+
+        timer.update(currentTimeMs - 1);
+        assertEquals(currentTimeMs, timer.currentTimeMs());
+
+        assertEquals(100, timer.remainingMs());
+        assertEquals(0, timer.elapsedMs());
+    }
+
+    @Test
+    public void testTimerSleep() {
+        Timer timer = time.timer(500);
+        long currentTimeMs = timer.currentTimeMs();
+
+        timer.sleep(200);
+        assertEquals(time.milliseconds(), timer.currentTimeMs());
+        assertEquals(currentTimeMs + 200, timer.currentTimeMs());
+
+        timer.sleep(1000);
+        assertEquals(time.milliseconds(), timer.currentTimeMs());
+        assertEquals(currentTimeMs + 500, timer.currentTimeMs());
+        assertTrue(timer.isExpired());
+    }
+
+}
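[Editor's note] TimerTest pins down the contract the rest of the patch relies on: remainingMs() bottoms out at zero, elapsedMs() keeps growing past expiration, updates never move the timer's clock backwards, and sleep() both waits (capped at the remaining time) and refreshes the timer. A sketch of the bounded-retry pattern this enables (hypothetical; tryOnce is a placeholder):

    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    public class TimerLoopSketch {
        public static void main(String[] args) {
            Time time = Time.SYSTEM;

            // One Timer bounds the whole operation, including nested waits,
            // instead of re-deriving deadlines at every call site.
            Timer timer = time.timer(5000L);
            while (!timer.isExpired()) {
                if (tryOnce(timer.remainingMs()))  // pass the remaining budget down
                    break;
                timer.sleep(100L);  // waits (capped at remainingMs) and updates the timer
            }
            System.out.println("finished after " + timer.elapsedMs() + " ms");
        }

        private static boolean tryOnce(long remainingMs) {
            return false;  // placeholder for a single bounded attempt
        }
    }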
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
index 796ecdf..103a323 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
@@ -26,6 +26,7 @@ import org.apache.kafka.common.requests.JoinGroupRequest.ProtocolMetadata;
 import org.apache.kafka.common.utils.CircularIterator;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.connect.storage.ConfigBackingStore;
 import org.apache.kafka.connect.util.ConnectorTaskId;
 import org.slf4j.Logger;
@@ -105,8 +106,8 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
 
     // expose for tests
     @Override
-    protected synchronized boolean ensureCoordinatorReady(final long timeoutMs) {
-        return super.ensureCoordinatorReady(timeoutMs);
+    protected synchronized boolean ensureCoordinatorReady(final Timer timer) {
+        return super.ensureCoordinatorReady(timer);
     }
 
     public void poll(long timeout) {
@@ -117,7 +118,7 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
 
         do {
             if (coordinatorUnknown()) {
-                ensureCoordinatorReady(Long.MAX_VALUE);
+                ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
                 now = time.milliseconds();
             }
 
@@ -133,7 +134,8 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
 
             // Note that because the network client is shared with the background heartbeat thread,
             // we do not want to block in poll longer than the time to the next heartbeat.
-            client.poll(Math.min(Math.max(0, remaining), timeToNextHeartbeat(now)));
+            long pollTimeout = Math.min(Math.max(0, remaining), timeToNextHeartbeat(now));
+            client.poll(time.timer(pollTimeout));
 
             now = time.milliseconds();
             elapsed = now - start;
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
index e1017f2..ede6c71 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
@@ -208,7 +208,7 @@ public class WorkerCoordinatorTest {
         final String consumerId = "leader";
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         Map<String, Long> memberConfigOffsets = new HashMap<>();
@@ -248,7 +248,7 @@ public class WorkerCoordinatorTest {
         final String memberId = "member";
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // normal join group
         client.prepareResponse(joinGroupFollowerResponse(1, memberId, "leader", Errors.NONE));
@@ -289,7 +289,7 @@ public class WorkerCoordinatorTest {
         final String memberId = "member";
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // config mismatch results in assignment error
         client.prepareResponse(joinGroupFollowerResponse(1, memberId, "leader", Errors.NONE));
@@ -320,7 +320,7 @@ public class WorkerCoordinatorTest {
         PowerMock.replayAll();
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
-        coordinator.ensureCoordinatorReady(Long.MAX_VALUE);
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // join the group once
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE));
diff --git a/core/src/main/scala/kafka/admin/AdminClient.scala b/core/src/main/scala/kafka/admin/AdminClient.scala
index 1009bc5..239844d 100644
--- a/core/src/main/scala/kafka/admin/AdminClient.scala
+++ b/core/src/main/scala/kafka/admin/AdminClient.scala
@@ -59,7 +59,7 @@ class AdminClient(val time: Time,
     override def run() {
       try {
         while (running)
-          client.poll(Long.MaxValue)
+          client.poll(time.timer(Long.MaxValue))
       } catch {
         case t : Throwable =>
           error("admin-client-network-thread exited", t)

