kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From j...@apache.org
Subject [kafka] branch trunk updated: KAFKA-7831; Do not modify subscription state from background thread (#6221)
Date Fri, 08 Mar 2019 00:29:36 GMT
This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 460e46c  KAFKA-7831; Do not modify subscription state from background thread (#6221)
460e46c is described below

commit 460e46c3bb76a361d0706b263c03696005e12566
Author: Jason Gustafson <jason@confluent.io>
AuthorDate: Thu Mar 7 16:29:19 2019 -0800

    KAFKA-7831; Do not modify subscription state from background thread (#6221)
    
    Metadata may be updated from the background thread, so we need to protect access to SubscriptionState. This patch restructures the metadata handling so that we only check pattern subscriptions in the foreground. Additionally, it improves the following:
    
    1. SubscriptionState is now the source of truth for the topics that will be fetched. We had a lot of messy logic previously to try and keep the topic set in Metadata consistent with the subscription, so this simplifies the logic.
    2. The metadata needs for the producer and consumer are quite different, so it made sense to separate the custom logic into separate extensions of Metadata. For example, only the producer requires topic expiration.
    3. We've always had an edge case in which a metadata change with an inflight request may cause us to effectively miss an expected update. This patch implements a separate version inside Metadata which is bumped when the needed topics change.
    4. This patch removes the MetadataListener, which was the cause of https://issues.apache.org/jira/browse/KAFKA-7764.
    
    Reviewers: David Arthur <mumrah@gmail.com>, Rajini Sivaram <rajinisivaram@googlemail.com>
---
 .../java/org/apache/kafka/clients/Metadata.java    | 349 ++++-------
 .../org/apache/kafka/clients/NetworkClient.java    |  36 +-
 .../kafka/clients/consumer/ConsumerConfig.java     |   4 +-
 .../kafka/clients/consumer/KafkaConsumer.java      |  69 ++-
 .../consumer/internals/ConsumerCoordinator.java    | 194 +++---
 .../consumer/internals/ConsumerMetadata.java       |  77 +++
 .../consumer/internals/ConsumerNetworkClient.java  |   9 +-
 .../kafka/clients/consumer/internals/Fetcher.java  |  79 +--
 .../consumer/internals/SubscriptionState.java      | 101 ++-
 .../kafka/clients/producer/KafkaProducer.java      |  48 +-
 .../producer/internals/ProducerMetadata.java       | 129 ++++
 .../kafka/clients/producer/internals/Sender.java   |   4 +-
 .../kafka/common/requests/MetadataResponse.java    |  24 -
 .../org/apache/kafka/common/utils/SystemTime.java  |  20 +
 .../java/org/apache/kafka/common/utils/Time.java   |  15 +
 .../org/apache/kafka/clients/MetadataTest.java     | 370 +++--------
 .../java/org/apache/kafka/clients/MockClient.java  |  24 +-
 .../apache/kafka/clients/NetworkClientTest.java    |  13 +-
 .../kafka/clients/consumer/KafkaConsumerTest.java  | 187 +++---
 .../internals/AbstractCoordinatorTest.java         |  10 +-
 .../internals/ConsumerCoordinatorTest.java         | 100 ++-
 .../consumer/internals/ConsumerMetadataTest.java   | 164 +++++
 .../internals/ConsumerNetworkClientTest.java       |  40 +-
 .../clients/consumer/internals/FetcherTest.java    | 680 +++++++++++++--------
 .../kafka/clients/producer/KafkaProducerTest.java  |  69 ++-
 .../producer/internals/ProducerMetadataTest.java   | 205 +++++++
 .../clients/producer/internals/SenderTest.java     |  48 +-
 .../producer/internals/TransactionManagerTest.java |   5 +-
 .../org/apache/kafka/common/utils/MockTime.java    |  24 +
 .../apache/kafka/common/utils/MockTimeTest.java    |  31 +-
 .../apache/kafka/common/utils/SystemTimeTest.java} |  26 +-
 .../org/apache/kafka/common/utils/TimeTest.java    |  83 +++
 .../test/java/org/apache/kafka/test/TestUtils.java |   4 -
 .../runtime/distributed/WorkerGroupMember.java     |   4 +-
 .../runtime/distributed/WorkerCoordinatorTest.java |   9 +-
 core/src/main/scala/kafka/admin/AdminClient.scala  |  16 +-
 .../apache/kafka/streams/TopologyTestDriver.java   |   7 +
 37 files changed, 1958 insertions(+), 1319 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/Metadata.java b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
index 9dc9b87..dfd461a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/Metadata.java
+++ b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
@@ -21,21 +21,22 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthenticationException;
-import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.errors.InvalidMetadataException;
+import org.apache.kafka.common.errors.InvalidTopicException;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.net.InetSocketAddress;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -43,7 +44,6 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
-import java.util.stream.Collectors;
 
 /**
  * A class encapsulating some of the logic around metadata.
@@ -58,63 +58,43 @@ import java.util.stream.Collectors;
  * manage topics while producers rely on topic expiry to limit the refresh set.
  */
 public class Metadata implements Closeable {
-
-    private static final Logger log = LoggerFactory.getLogger(Metadata.class);
-
-    public static final long TOPIC_EXPIRY_MS = 5 * 60 * 1000;
-    private static final long TOPIC_EXPIRY_NEEDS_UPDATE = -1L;
-
+    private final Logger log;
     private final long refreshBackoffMs;
     private final long metadataExpireMs;
-    private int version;
+    private int updateVersion;  // bumped on every metadata response
+    private int requestVersion; // bumped on every new topic addition
     private long lastRefreshMs;
     private long lastSuccessfulRefreshMs;
     private AuthenticationException authenticationException;
+    private KafkaException metadataException;
     private MetadataCache cache = MetadataCache.empty();
     private boolean needUpdate;
-    /* Topics with expiry time */
-    private final Map<String, Long> topics;
-    private final List<Listener> listeners;
     private final ClusterResourceListeners clusterResourceListeners;
-    private boolean needMetadataForAllTopics;
-    private final boolean allowAutoTopicCreation;
-    private final boolean topicExpiryEnabled;
     private boolean isClosed;
     private final Map<TopicPartition, Integer> lastSeenLeaderEpochs;
 
-    public Metadata(long refreshBackoffMs,
-                    long metadataExpireMs,
-                    boolean allowAutoTopicCreation) {
-        this(refreshBackoffMs, metadataExpireMs, allowAutoTopicCreation, false, new ClusterResourceListeners());
-    }
-
     /**
      * Create a new Metadata instance
-     * @param refreshBackoffMs The minimum amount of time that must expire between metadata refreshes to avoid busy
-     *        polling
-     * @param metadataExpireMs The maximum amount of time that metadata can be retained without refresh
-     * @param allowAutoTopicCreation If this and the broker config 'auto.create.topics.enable' are true, topics that
-     *                               don't exist will be created by the broker when a metadata request is sent
-     * @param topicExpiryEnabled If true, enable expiry of unused topics
+     *
+     * @param refreshBackoffMs         The minimum amount of time that must expire between metadata refreshes to avoid busy
+     *                                 polling
+     * @param metadataExpireMs         The maximum amount of time that metadata can be retained without refresh
+     * @param logContext               Log context corresponding to the containing client
      * @param clusterResourceListeners List of ClusterResourceListeners which will receive metadata updates.
      */
     public Metadata(long refreshBackoffMs,
                     long metadataExpireMs,
-                    boolean allowAutoTopicCreation,
-                    boolean topicExpiryEnabled,
+                    LogContext logContext,
                     ClusterResourceListeners clusterResourceListeners) {
+        this.log = logContext.logger(Metadata.class);
         this.refreshBackoffMs = refreshBackoffMs;
         this.metadataExpireMs = metadataExpireMs;
-        this.allowAutoTopicCreation = allowAutoTopicCreation;
-        this.topicExpiryEnabled = topicExpiryEnabled;
         this.lastRefreshMs = 0L;
         this.lastSuccessfulRefreshMs = 0L;
-        this.version = 0;
+        this.requestVersion = 0;
+        this.updateVersion = 0;
         this.needUpdate = false;
-        this.topics = new HashMap<>();
-        this.listeners = new ArrayList<>();
         this.clusterResourceListeners = clusterResourceListeners;
-        this.needMetadataForAllTopics = false;
         this.isClosed = false;
         this.lastSeenLeaderEpochs = new HashMap<>();
     }
@@ -127,17 +107,6 @@ public class Metadata implements Closeable {
     }
 
     /**
-     * Add the topic to maintain in the metadata. If topic expiry is enabled, expiry time
-     * will be reset on the next update.
-     */
-    public synchronized void add(String topic) {
-        Objects.requireNonNull(topic, "topic cannot be null");
-        if (topics.put(topic, TOPIC_EXPIRY_NEEDS_UPDATE) == null) {
-            requestUpdateForNewTopics();
-        }
-    }
-
-    /**
      * Return the next time when the current cluster info can be updated (i.e., backoff time has elapsed).
      *
      * @param nowMs current time in ms
@@ -161,11 +130,11 @@ public class Metadata implements Closeable {
     }
 
     /**
-     * Request an update of the current cluster metadata info, return the current version before the update
+     * Request an update of the current cluster metadata info, return the current updateVersion before the update
      */
     public synchronized int requestUpdate() {
         this.needUpdate = true;
-        return this.version;
+        return this.updateVersion;
     }
 
     /**
@@ -184,9 +153,9 @@ public class Metadata implements Closeable {
     /**
      * Conditionally update the leader epoch for a partition
      *
-     * @param topicPartition topic+partition to update the epoch for
-     * @param epoch the new epoch
-     * @param epochTest a predicate to determine if the old epoch should be replaced
+     * @param topicPartition       topic+partition to update the epoch for
+     * @param epoch                the new epoch
+     * @param epochTest            a predicate to determine if the old epoch should be replaced
      * @param setRequestUpdateFlag sets the "needUpdate" flag to true if the epoch is updated
      * @return true if the epoch was updated, false otherwise
      */
@@ -211,6 +180,7 @@ public class Metadata implements Closeable {
 
     /**
      * Check whether an update has been explicitly requested.
+     *
      * @return true if an update was requested, false otherwise
      */
     public synchronized boolean updateRequested() {
@@ -243,140 +213,101 @@ public class Metadata implements Closeable {
             return null;
     }
 
-    /**
-     * Wait for metadata update until the current version is larger than the last version we know of
-     */
-    public synchronized void awaitUpdate(final int lastVersion, final long maxWaitMs) throws InterruptedException {
-        if (maxWaitMs < 0)
-            throw new IllegalArgumentException("Max time to wait for metadata updates should not be < 0 milliseconds");
-
-        long begin = System.currentTimeMillis();
-        long remainingWaitMs = maxWaitMs;
-        while ((this.version <= lastVersion) && !isClosed()) {
-            AuthenticationException ex = getAndClearAuthenticationException();
-            if (ex != null)
-                throw ex;
-            if (remainingWaitMs != 0)
-                wait(remainingWaitMs);
-            long elapsed = System.currentTimeMillis() - begin;
-            if (elapsed >= maxWaitMs)
-                throw new TimeoutException("Failed to update metadata after " + maxWaitMs + " ms.");
-            remainingWaitMs = maxWaitMs - elapsed;
-        }
-        if (isClosed())
-            throw new KafkaException("Requested metadata update after close");
-    }
-
-    /**
-     * Replace the current set of topics maintained to the one provided.
-     * If topic expiry is enabled, expiry time of the topics will be
-     * reset on the next update.
-     * @param topics
-     */
-    public synchronized void setTopics(Collection<String> topics) {
-        Set<TopicPartition> partitionsToRemove = lastSeenLeaderEpochs.keySet()
-                .stream()
-                .filter(tp -> !topics.contains(tp.topic()))
-                .collect(Collectors.toSet());
-        partitionsToRemove.forEach(lastSeenLeaderEpochs::remove);
-
-        cache.retainTopics(topics);
-
-        if (!this.topics.keySet().containsAll(topics)) {
-            requestUpdateForNewTopics();
-        }
-        this.topics.clear();
-        for (String topic : topics)
-            this.topics.put(topic, TOPIC_EXPIRY_NEEDS_UPDATE);
-    }
-
-    /**
-     * Get the list of topics we are currently maintaining metadata for
-     */
-    public synchronized Set<String> topics() {
-        return new HashSet<>(this.topics.keySet());
-    }
-
-    /**
-     * Check if a topic is already in the topic set.
-     * @param topic topic to check
-     * @return true if the topic exists, false otherwise
-     */
-    public synchronized boolean containsTopic(String topic) {
-        return this.topics.containsKey(topic);
+    synchronized KafkaException getAndClearMetadataException() {
+        if (this.metadataException != null) {
+            KafkaException metadataException = this.metadataException;
+            this.metadataException = null;
+            return metadataException;
+        } else
+            return null;
     }
 
     public synchronized void bootstrap(List<InetSocketAddress> addresses, long now) {
         this.needUpdate = true;
         this.lastRefreshMs = now;
         this.lastSuccessfulRefreshMs = now;
-        this.version += 1;
+        this.updateVersion += 1;
         this.cache = MetadataCache.bootstrap(addresses);
     }
 
     /**
+     * Update metadata assuming the current request version. This is mainly for convenience in testing.
+     */
+    public synchronized void update(MetadataResponse response, long now) {
+        this.update(this.requestVersion, response, now);
+    }
+
+    /**
      * Updates the cluster metadata. If topic expiry is enabled, expiry time
      * is set for topics if required and expired topics are removed from the metadata.
      *
-     * @param metadataResponse metadata response received from the broker
+     * @param requestVersion The request version corresponding to the update response, as provided by
+     *     {@link #newMetadataRequestAndVersion()}.
+     * @param response metadata response received from the broker
      * @param now current time in milliseconds
      */
-    public synchronized void update(MetadataResponse metadataResponse, long now) {
-        Objects.requireNonNull(metadataResponse, "Metadata response cannot be null");
+    public synchronized void update(int requestVersion, MetadataResponse response, long now) {
+        Objects.requireNonNull(response, "Metadata response cannot be null");
         if (isClosed())
             throw new IllegalStateException("Update requested after metadata close");
 
-        this.needUpdate = false;
+        if (requestVersion == this.requestVersion)
+            this.needUpdate = false;
+        else
+            requestUpdate();
+
         this.lastRefreshMs = now;
         this.lastSuccessfulRefreshMs = now;
-        this.version += 1;
-
-        if (topicExpiryEnabled) {
-            // Handle expiry of topics from the metadata refresh set.
-            for (Iterator<Map.Entry<String, Long>> it = topics.entrySet().iterator(); it.hasNext(); ) {
-                Map.Entry<String, Long> entry = it.next();
-                long expireMs = entry.getValue();
-                if (expireMs == TOPIC_EXPIRY_NEEDS_UPDATE)
-                    entry.setValue(now + TOPIC_EXPIRY_MS);
-                else if (expireMs <= now) {
-                    it.remove();
-                    log.debug("Removing unused topic {} from the metadata list, expiryMs {} now {}", entry.getKey(), expireMs, now);
-                }
-            }
-        }
+        this.updateVersion += 1;
 
         String previousClusterId = cache.cluster().clusterResource().clusterId();
 
-        this.cache = handleMetadataResponse(metadataResponse, topic -> true);
-        Set<String> unavailableTopics = metadataResponse.unavailableTopics();
-        Cluster clusterForListeners = this.cache.cluster();
-        fireListeners(clusterForListeners, unavailableTopics);
+        this.cache = handleMetadataResponse(response, topic -> retainTopic(topic.topic(), topic.isInternal(), now));
 
-        if (this.needMetadataForAllTopics) {
-            // the listener may change the interested topics, which could cause another metadata refresh.
-            // If we have already fetched all topics, however, another fetch should be unnecessary.
-            this.needUpdate = false;
-            this.cache = handleMetadataResponse(metadataResponse, topics.keySet()::contains);
-        }
+        Cluster cluster = cache.cluster();
+        maybeSetMetadataError(cluster);
+
+        this.lastSeenLeaderEpochs.keySet().removeIf(tp -> !retainTopic(tp.topic(), false, now));
 
         String newClusterId = cache.cluster().clusterResource().clusterId();
         if (!Objects.equals(previousClusterId, newClusterId)) {
             log.info("Cluster ID: {}", newClusterId);
         }
-        clusterResourceListeners.onUpdate(clusterForListeners.clusterResource());
+        clusterResourceListeners.onUpdate(cache.cluster().clusterResource());
 
-        notifyAll();
-        log.debug("Updated cluster metadata version {} to {}", this.version, this.cache);
+        log.debug("Updated cluster metadata updateVersion {} to {}", this.updateVersion, this.cache);
+    }
+
+    private void maybeSetMetadataError(Cluster cluster) {
+        // if we encounter any invalid topics, cache the exception to later throw to the user
+        metadataException = null;
+        checkInvalidTopics(cluster);
+        checkUnauthorizedTopics(cluster);
+    }
+
+    private void checkInvalidTopics(Cluster cluster) {
+        if (!cluster.invalidTopics().isEmpty()) {
+            log.error("Metadata response reported invalid topics {}", cluster.invalidTopics());
+            metadataException = new InvalidTopicException(cluster.invalidTopics());
+        }
+    }
+
+    private void checkUnauthorizedTopics(Cluster cluster) {
+        if (!cluster.unauthorizedTopics().isEmpty()) {
+            log.error("Topic authorization failed for topics {}", cluster.unauthorizedTopics());
+            metadataException = new TopicAuthorizationException(new HashSet<>(cluster.unauthorizedTopics()));
+        }
     }
 
     /**
      * Transform a MetadataResponse into a new MetadataCache instance.
      */
-    private MetadataCache handleMetadataResponse(MetadataResponse metadataResponse, Predicate<String> topicsToRetain) {
+    private MetadataCache handleMetadataResponse(MetadataResponse metadataResponse,
+                                                 Predicate<MetadataResponse.TopicMetadata> topicsToRetain) {
         Set<String> internalTopics = new HashSet<>();
         List<MetadataCache.PartitionInfoAndEpoch> partitions = new ArrayList<>();
         for (MetadataResponse.TopicMetadata metadata : metadataResponse.topicMetadata()) {
-            if (!topicsToRetain.test(metadata.topic()))
+            if (!topicsToRetain.test(metadata))
                 continue;
 
             if (metadata.error() == Errors.NONE) {
@@ -387,7 +318,16 @@ public class Metadata implements Closeable {
                         int epoch = partitionMetadata.leaderEpoch().orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH);
                         partitions.add(new MetadataCache.PartitionInfoAndEpoch(partitionInfo, epoch));
                     });
+
+                    if (partitionMetadata.error().exception() instanceof InvalidMetadataException) {
+                        log.debug("Requesting metadata update for partition {} due to error {}",
+                                new TopicPartition(metadata.topic(), partitionMetadata.partition()), partitionMetadata.error());
+                        requestUpdate();
+                    }
                 }
+            } else if (metadata.error().exception() instanceof InvalidMetadataException) {
+                log.debug("Requesting metadata update for topic {} due to error {}", metadata.topic(), metadata.error());
+                requestUpdate();
             }
         }
 
@@ -415,14 +355,6 @@ public class Metadata implements Closeable {
                 PartitionInfo previousInfo = cache.cluster().partition(tp);
                 if (previousInfo != null) {
                     partitionInfoConsumer.accept(previousInfo);
-                } else {
-                    if (containsTopic(topic)) {
-                        log.debug("Got an older epoch in partition metadata response for {}, but we are not tracking this topic. " +
-                                "Ignoring metadata update for this partition", tp);
-                    } else {
-                        log.warn("Got an older epoch in partition metadata response for {}, but could not find previous partition " +
-                                "info to use. Refusing to update metadata for this partition", tp);
-                    }
                 }
             }
         } else {
@@ -432,9 +364,14 @@ public class Metadata implements Closeable {
         }
     }
 
-    private void fireListeners(Cluster newCluster, Set<String> unavailableTopics) {
-        for (Listener listener: listeners)
-            listener.onMetadataUpdate(newCluster, unavailableTopics);
+    public synchronized void maybeThrowException() {
+        AuthenticationException authenticationException = getAndClearAuthenticationException();
+        if (authenticationException != null)
+            throw authenticationException;
+
+        KafkaException metadataException = getAndClearMetadataException();
+        if (metadataException != null)
+            throw metadataException;
     }
 
     /**
@@ -444,15 +381,13 @@ public class Metadata implements Closeable {
     public synchronized void failedUpdate(long now, AuthenticationException authenticationException) {
         this.lastRefreshMs = now;
         this.authenticationException = authenticationException;
-        if (authenticationException != null)
-            this.notifyAll();
     }
 
     /**
-     * @return The current metadata version
+     * @return The current metadata updateVersion
      */
-    public synchronized int version() {
-        return this.version;
+    public synchronized int updateVersion() {
+        return this.updateVersion;
     }
 
     /**
@@ -462,79 +397,51 @@ public class Metadata implements Closeable {
         return this.lastSuccessfulRefreshMs;
     }
 
-    public boolean allowAutoTopicCreation() {
-        return allowAutoTopicCreation;
-    }
-
-    /**
-     * Set state to indicate if metadata for all topics in Kafka cluster is required or not.
-     * @param needMetadataForAllTopics boolean indicating need for metadata of all topics in cluster.
-     */
-    public synchronized void needMetadataForAllTopics(boolean needMetadataForAllTopics) {
-        if (needMetadataForAllTopics && !this.needMetadataForAllTopics) {
-            requestUpdateForNewTopics();
-        }
-        this.needMetadataForAllTopics = needMetadataForAllTopics;
-    }
-
-    /**
-     * Get whether metadata for all topics is needed or not
-     */
-    public synchronized boolean needMetadataForAllTopics() {
-        return this.needMetadataForAllTopics;
-    }
-
-    /**
-     * Add a Metadata listener that gets notified of metadata updates
-     */
-    public synchronized void addListener(Listener listener) {
-        this.listeners.add(listener);
-    }
-
-    /**
-     * Stop notifying the listener of metadata updates
-     */
-    public synchronized void removeListener(Listener listener) {
-        this.listeners.remove(listener);
-    }
-
     /**
-     * "Close" this metadata instance to indicate that metadata updates are no longer possible. This is typically used
-     * when the thread responsible for performing metadata updates is exiting and needs a way to relay this information
-     * to any other thread(s) that could potentially wait on metadata update to come through.
+     * Close this metadata instance to indicate that metadata updates are no longer possible.
      */
     @Override
     public synchronized void close() {
         this.isClosed = true;
-        this.notifyAll();
     }
 
     /**
      * Check if this metadata instance has been closed. See {@link #close()} for more information.
+     *
      * @return True if this instance has been closed; false otherwise
      */
     public synchronized boolean isClosed() {
         return this.isClosed;
     }
 
-    /**
-     * MetadataUpdate Listener
-     */
-    public interface Listener {
-        /**
-         * Callback invoked on metadata update.
-         *
-         * @param cluster the cluster containing metadata for topics with valid metadata
-         * @param unavailableTopics topics which are non-existent or have one or more partitions whose
-         *        leader is not known
-         */
-        void onMetadataUpdate(Cluster cluster, Set<String> unavailableTopics);
-    }
-
-    private synchronized void requestUpdateForNewTopics() {
+    public synchronized void requestUpdateForNewTopics() {
         // Override the timestamp of last refresh to let immediate update.
         this.lastRefreshMs = 0;
+        this.requestVersion++;
         requestUpdate();
     }
 
+    public synchronized MetadataRequestAndVersion newMetadataRequestAndVersion() {
+        return new MetadataRequestAndVersion(newMetadataRequestBuilder(), requestVersion);
+    }
+
+    protected MetadataRequest.Builder newMetadataRequestBuilder() {
+        return MetadataRequest.Builder.allTopics();
+    }
+
+    protected boolean retainTopic(String topic, boolean isInternal, long nowMs) {
+        return true;
+    }
+
+    public static class MetadataRequestAndVersion {
+        public final MetadataRequest.Builder requestBuilder;
+        public final int requestVersion;
+
+        private MetadataRequestAndVersion(MetadataRequest.Builder requestBuilder,
+                                          int requestVersion) {
+            this.requestBuilder = requestBuilder;
+            this.requestVersion = requestVersion;
+        }
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
index 44446b3..e334295 100644
--- a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
@@ -937,12 +937,12 @@ public class NetworkClient implements KafkaClient {
         /* the current cluster metadata */
         private final Metadata metadata;
 
-        /* true iff there is a metadata request that has been sent and for which we have not yet received a response */
-        private boolean metadataFetchInProgress;
+        // Defined if there is a request in progress, null otherwise
+        private Integer inProgressRequestVersion;
 
         DefaultMetadataUpdater(Metadata metadata) {
             this.metadata = metadata;
-            this.metadataFetchInProgress = false;
+            this.inProgressRequestVersion = null;
         }
 
         @Override
@@ -952,14 +952,18 @@ public class NetworkClient implements KafkaClient {
 
         @Override
         public boolean isUpdateDue(long now) {
-            return !this.metadataFetchInProgress && this.metadata.timeToNextUpdate(now) == 0;
+            return !hasFetchInProgress() && this.metadata.timeToNextUpdate(now) == 0;
+        }
+
+        private boolean hasFetchInProgress() {
+            return inProgressRequestVersion != null;
         }
 
         @Override
         public long maybeUpdate(long now) {
             // should we update our metadata?
             long timeToNextMetadataUpdate = metadata.timeToNextUpdate(now);
-            long waitForMetadataFetch = this.metadataFetchInProgress ? defaultRequestTimeoutMs : 0;
+            long waitForMetadataFetch = hasFetchInProgress() ? defaultRequestTimeoutMs : 0;
 
             long metadataTimeout = Math.max(timeToNextMetadataUpdate, waitForMetadataFetch);
 
@@ -992,20 +996,18 @@ public class NetworkClient implements KafkaClient {
                     log.warn("Bootstrap broker {} disconnected", node);
             }
 
-            metadataFetchInProgress = false;
+            inProgressRequestVersion = null;
         }
 
         @Override
         public void handleAuthenticationFailure(AuthenticationException exception) {
-            metadataFetchInProgress = false;
             if (metadata.updateRequested())
                 metadata.failedUpdate(time.milliseconds(), exception);
+            inProgressRequestVersion = null;
         }
 
         @Override
         public void handleCompletedMetadataResponse(RequestHeader requestHeader, long now, MetadataResponse response) {
-            this.metadataFetchInProgress = false;
-
             // If any partition has leader with missing listeners, log a few for diagnosing broker configuration
             // issues. This could be a transient issue if listeners were added dynamically to brokers.
             List<TopicPartition> missingListenerPartitions = response.topicMetadata().stream().flatMap(topicMetadata ->
@@ -1030,8 +1032,10 @@ public class NetworkClient implements KafkaClient {
                 log.trace("Ignoring empty metadata response with correlation id {}.", requestHeader.correlationId());
                 this.metadata.failedUpdate(now, null);
             } else {
-                this.metadata.update(response, now);
+                this.metadata.update(inProgressRequestVersion, response, now);
             }
+
+            inProgressRequestVersion = null;
         }
 
         @Override
@@ -1063,15 +1067,9 @@ public class NetworkClient implements KafkaClient {
             String nodeConnectionId = node.idString();
 
             if (canSendRequest(nodeConnectionId, now)) {
-                this.metadataFetchInProgress = true;
-                MetadataRequest.Builder metadataRequest;
-                if (metadata.needMetadataForAllTopics())
-                    metadataRequest = MetadataRequest.Builder.allTopics();
-                else
-                    metadataRequest = new MetadataRequest.Builder(new ArrayList<>(metadata.topics()),
-                            metadata.allowAutoTopicCreation());
-
-
+                Metadata.MetadataRequestAndVersion requestAndVersion = metadata.newMetadataRequestAndVersion();
+                this.inProgressRequestVersion = requestAndVersion.requestVersion;
+                MetadataRequest.Builder metadataRequest = requestAndVersion.requestBuilder;
                 log.debug("Sending metadata request {} to node {}", metadataRequest, node);
                 sendInternalMetadataRequest(metadataRequest, nodeConnectionId, now);
                 return defaultRequestTimeoutMs;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
index 9cd5766..b92cbf9 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
@@ -235,8 +235,8 @@ public class ConsumerConfig extends AbstractConfig {
 
     /** <code>exclude.internal.topics</code> */
     public static final String EXCLUDE_INTERNAL_TOPICS_CONFIG = "exclude.internal.topics";
-    private static final String EXCLUDE_INTERNAL_TOPICS_DOC = "Whether records from internal topics (such as offsets) should be exposed to the consumer. "
-                                                            + "If set to <code>true</code> the only way to receive records from an internal topic is subscribing to it.";
+    private static final String EXCLUDE_INTERNAL_TOPICS_DOC = "Whether internal topics matching a subscribed pattern should " +
+            "be excluded from the subscription. It is always possible to explicitly subscribe to an internal topic.";
     public static final boolean DEFAULT_EXCLUDE_INTERNAL_TOPICS = true;
 
     /**
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 7f5b2c0..4cee56a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -19,13 +19,13 @@ package org.apache.kafka.clients.consumer;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientUtils;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerCoordinator;
 import org.apache.kafka.clients.consumer.internals.ConsumerInterceptors;
-import org.apache.kafka.clients.consumer.internals.ConsumerMetrics;
+import org.apache.kafka.clients.consumer.internals.ConsumerMetadata;
 import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient;
 import org.apache.kafka.clients.consumer.internals.Fetcher;
+import org.apache.kafka.clients.consumer.internals.FetcherMetricsRegistry;
 import org.apache.kafka.clients.consumer.internals.Heartbeat;
 import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
@@ -553,6 +553,7 @@ import java.util.regex.Pattern;
  */
 public class KafkaConsumer<K, V> implements Consumer<K, V> {
 
+    private static final String CLIENT_ID_METRIC_TAG = "client-id";
     private static final long NO_CURRENT_THREAD = -1L;
     private static final AtomicInteger CONSUMER_CLIENT_ID_SEQUENCE = new AtomicInteger(1);
     private static final String JMX_PREFIX = "kafka.consumer";
@@ -573,7 +574,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     private final Time time;
     private final ConsumerNetworkClient client;
     private final SubscriptionState subscriptions;
-    private final Metadata metadata;
+    private final ConsumerMetadata metadata;
     private final long retryBackoffMs;
     private final long requestTimeoutMs;
     private final int defaultApiTimeoutMs;
@@ -683,16 +684,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             this.requestTimeoutMs = config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG);
             this.defaultApiTimeoutMs = config.getInt(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG);
             this.time = Time.SYSTEM;
-
-            Map<String, String> metricsTags = Collections.singletonMap("client-id", clientId);
-            MetricConfig metricConfig = new MetricConfig().samples(config.getInt(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG))
-                    .timeWindow(config.getLong(ConsumerConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS)
-                    .recordLevel(Sensor.RecordingLevel.forName(config.getString(ConsumerConfig.METRICS_RECORDING_LEVEL_CONFIG)))
-                    .tags(metricsTags);
-            List<MetricsReporter> reporters = config.getConfiguredInstances(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG,
-                    MetricsReporter.class, Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId));
-            reporters.add(new JmxReporter(JMX_PREFIX));
-            this.metrics = new Metrics(metricConfig, reporters, time);
+            this.metrics = buildMetrics(config, time, clientId);
             this.retryBackoffMs = config.getLong(ConsumerConfig.RETRY_BACKOFF_MS_CONFIG);
 
             // load interceptors and make sure they get clientId
@@ -715,18 +707,24 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                 config.ignore(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG);
                 this.valueDeserializer = valueDeserializer;
             }
-            ClusterResourceListeners clusterResourceListeners = configureClusterResourceListeners(keyDeserializer, valueDeserializer, reporters, interceptorList);
-            this.metadata = new Metadata(retryBackoffMs, config.getLong(ConsumerConfig.METADATA_MAX_AGE_CONFIG),
-                    true, false, clusterResourceListeners);
+            OffsetResetStrategy offsetResetStrategy = OffsetResetStrategy.valueOf(config.getString(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).toUpperCase(Locale.ROOT));
+            this.subscriptions = new SubscriptionState(logContext, offsetResetStrategy);
+            ClusterResourceListeners clusterResourceListeners = configureClusterResourceListeners(keyDeserializer,
+                    valueDeserializer, metrics.reporters(), interceptorList);
+            this.metadata = new ConsumerMetadata(retryBackoffMs,
+                    config.getLong(ConsumerConfig.METADATA_MAX_AGE_CONFIG),
+                    !config.getBoolean(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG),
+                    subscriptions, logContext, clusterResourceListeners);
             List<InetSocketAddress> addresses = ClientUtils.parseAndValidateAddresses(
                     config.getList(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), config.getString(ConsumerConfig.CLIENT_DNS_LOOKUP_CONFIG));
             this.metadata.bootstrap(addresses, time.milliseconds());
             String metricGrpPrefix = "consumer";
-            ConsumerMetrics metricsRegistry = new ConsumerMetrics(metricsTags.keySet(), "consumer");
+
+            FetcherMetricsRegistry metricsRegistry = new FetcherMetricsRegistry(Collections.singleton(CLIENT_ID_METRIC_TAG), metricGrpPrefix);
             ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config, time);
             IsolationLevel isolationLevel = IsolationLevel.valueOf(
                     config.getString(ConsumerConfig.ISOLATION_LEVEL_CONFIG).toUpperCase(Locale.ROOT));
-            Sensor throttleTimeSensor = Fetcher.throttleTimeSensor(metrics, metricsRegistry.fetcherMetrics);
+            Sensor throttleTimeSensor = Fetcher.throttleTimeSensor(metrics, metricsRegistry);
             int heartbeatIntervalMs = config.getInt(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG);
 
             NetworkClient netClient = new NetworkClient(
@@ -753,8 +751,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     retryBackoffMs,
                     config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG),
                     heartbeatIntervalMs); //Will avoid blocking an extended period of time to prevent heartbeat thread starvation
-            OffsetResetStrategy offsetResetStrategy = OffsetResetStrategy.valueOf(config.getString(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).toUpperCase(Locale.ROOT));
-            this.subscriptions = new SubscriptionState(logContext, offsetResetStrategy);
             this.assignors = config.getConfiguredInstances(
                     ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
                     PartitionAssignor.class);
@@ -779,7 +775,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                         enableAutoCommit,
                         config.getInt(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG),
                         this.interceptors,
-                        config.getBoolean(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG),
                         config.getBoolean(ConsumerConfig.LEAVE_GROUP_ON_CLOSE_CONFIG));
             this.fetcher = new Fetcher<>(
                     logContext,
@@ -795,7 +790,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     this.metadata,
                     this.subscriptions,
                     metrics,
-                    metricsRegistry.fetcherMetrics,
+                    metricsRegistry,
                     this.time,
                     this.retryBackoffMs,
                     this.requestTimeoutMs,
@@ -824,7 +819,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                   ConsumerNetworkClient client,
                   Metrics metrics,
                   SubscriptionState subscriptions,
-                  Metadata metadata,
+                  ConsumerMetadata metadata,
                   long retryBackoffMs,
                   long requestTimeoutMs,
                   int defaultApiTimeoutMs,
@@ -849,6 +844,18 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         this.groupId = groupId;
     }
 
+    private static Metrics buildMetrics(ConsumerConfig config, Time time, String clientId) {
+        Map<String, String> metricsTags = Collections.singletonMap(CLIENT_ID_METRIC_TAG, clientId);
+        MetricConfig metricConfig = new MetricConfig().samples(config.getInt(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG))
+                .timeWindow(config.getLong(ConsumerConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS)
+                .recordLevel(Sensor.RecordingLevel.forName(config.getString(ConsumerConfig.METRICS_RECORDING_LEVEL_CONFIG)))
+                .tags(metricsTags);
+        List<MetricsReporter> reporters = config.getConfiguredInstances(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG,
+                MetricsReporter.class, Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId));
+        reporters.add(new JmxReporter(JMX_PREFIX));
+        return new Metrics(metricConfig, reporters, time);
+    }
+
     /**
      * Get the set of partitions currently assigned to this consumer. If subscription happened by directly assigning
      * partitions using {@link #assign(Collection)} then this will simply return the same partitions that
@@ -934,8 +941,8 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                 throwIfNoAssignorsConfigured();
                 fetcher.clearBufferedDataForUnassignedTopics(topics);
                 log.info("Subscribed to topic(s): {}", Utils.join(topics, ", "));
-                this.subscriptions.subscribe(new HashSet<>(topics), listener);
-                metadata.setTopics(subscriptions.groupSubscription());
+                if (this.subscriptions.subscribe(new HashSet<>(topics), listener))
+                    metadata.requestUpdateForNewTopics();
             }
         } finally {
             release();
@@ -998,9 +1005,8 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             throwIfNoAssignorsConfigured();
             log.info("Subscribed to pattern: '{}'", pattern);
             this.subscriptions.subscribe(pattern, listener);
-            this.metadata.needMetadataForAllTopics(true);
             this.coordinator.updatePatternSubscription(metadata.fetch());
-            this.metadata.requestUpdate();
+            this.metadata.requestUpdateForNewTopics();
         } finally {
             release();
         }
@@ -1038,7 +1044,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             this.subscriptions.unsubscribe();
             if (this.coordinator != null)
                 this.coordinator.maybeLeaveGroup();
-            this.metadata.needMetadataForAllTopics(false);
             log.info("Unsubscribed all topics or patterns and assigned partitions");
         } finally {
             release();
@@ -1073,12 +1078,10 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             } else if (partitions.isEmpty()) {
                 this.unsubscribe();
             } else {
-                Set<String> topics = new HashSet<>();
                 for (TopicPartition tp : partitions) {
                     String topic = (tp != null) ? tp.topic() : null;
                     if (topic == null || topic.trim().isEmpty())
                         throw new IllegalArgumentException("Topic partitions to assign to cannot have null or empty topic");
-                    topics.add(topic);
                 }
                 fetcher.clearBufferedDataForUnassignedPartitions(partitions);
 
@@ -1088,8 +1091,8 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     this.coordinator.maybeAutoCommitOffsetsAsync(time.milliseconds());
 
                 log.info("Subscribed to partition(s): {}", Utils.join(partitions, ", "));
-                this.subscriptions.assignFromUser(new HashSet<>(partitions));
-                metadata.setTopics(topics);
+                if (this.subscriptions.assignFromUser(new HashSet<>(partitions)))
+                    metadata.requestUpdateForNewTopics();
             }
         } finally {
             release();
@@ -1962,7 +1965,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     throw new IllegalArgumentException("The target time for partition " + entry.getKey() + " is " +
                             entry.getValue() + ". The target time cannot be negative.");
             }
-            return fetcher.offsetsByTimes(timestampsToSearch, time.timer(timeout));
+            return fetcher.offsetsForTimes(timestampsToSearch, time.timer(timeout));
         } finally {
             release();
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 2fb6fb6..7990707 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
@@ -30,7 +29,6 @@ import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.InterruptException;
-import org.apache.kafka.common.errors.InvalidTopicException;
 import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
@@ -72,14 +70,13 @@ import java.util.stream.Collectors;
 public final class ConsumerCoordinator extends AbstractCoordinator {
     private final Logger log;
     private final List<PartitionAssignor> assignors;
-    private final Metadata metadata;
+    private final ConsumerMetadata metadata;
     private final ConsumerCoordinatorMetrics sensors;
     private final SubscriptionState subscriptions;
     private final OffsetCommitCallback defaultOffsetCommitCallback;
     private final boolean autoCommitEnabled;
     private final int autoCommitIntervalMs;
     private final ConsumerInterceptors<?, ?> interceptors;
-    private final boolean excludeInternalTopics;
     private final AtomicInteger pendingAsyncCommits;
 
     // this collection must be thread-safe because it is modified from the response handler
@@ -123,7 +120,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                                int sessionTimeoutMs,
                                Heartbeat heartbeat,
                                List<PartitionAssignor> assignors,
-                               Metadata metadata,
+                               ConsumerMetadata metadata,
                                SubscriptionState subscriptions,
                                Metrics metrics,
                                String metricGrpPrefix,
@@ -132,7 +129,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                                boolean autoCommitEnabled,
                                int autoCommitIntervalMs,
                                ConsumerInterceptors<?, ?> interceptors,
-                               boolean excludeInternalTopics,
                                final boolean leaveGroupOnClose) {
         super(logContext,
               client,
@@ -147,7 +143,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
               leaveGroupOnClose);
         this.log = logContext.logger(ConsumerCoordinator.class);
         this.metadata = metadata;
-        this.metadataSnapshot = new MetadataSnapshot(subscriptions, metadata.fetch());
+        this.metadataSnapshot = new MetadataSnapshot(subscriptions, metadata.fetch(), metadata.updateVersion());
         this.subscriptions = subscriptions;
         this.defaultOffsetCommitCallback = new DefaultOffsetCommitCallback();
         this.autoCommitEnabled = autoCommitEnabled;
@@ -156,14 +152,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         this.completedOffsetCommits = new ConcurrentLinkedQueue<>();
         this.sensors = new ConsumerCoordinatorMetrics(metrics, metricGrpPrefix);
         this.interceptors = interceptors;
-        this.excludeInternalTopics = excludeInternalTopics;
         this.pendingAsyncCommits = new AtomicInteger();
 
         if (autoCommitEnabled)
             this.nextAutoCommitTimer = time.timer(autoCommitIntervalMs);
 
         this.metadata.requestUpdate();
-        addMetadataListener();
     }
 
     @Override
@@ -173,6 +167,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
     @Override
     protected List<ProtocolMetadata> metadata() {
+        log.debug("Joining group with current subscription: {}", subscriptions.subscription());
         this.joinedSubscription = subscriptions.subscription();
         List<ProtocolMetadata> metadataList = new ArrayList<>();
         for (PartitionAssignor assignor : assignors) {
@@ -184,46 +179,11 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     public void updatePatternSubscription(Cluster cluster) {
-        final Set<String> topicsToSubscribe = new HashSet<>();
-
-        for (String topic : cluster.topics())
-            if (subscriptions.subscribedPattern().matcher(topic).matches() &&
-                    !(excludeInternalTopics && cluster.internalTopics().contains(topic)))
-                topicsToSubscribe.add(topic);
-
-        subscriptions.subscribeFromPattern(topicsToSubscribe);
-
-        // note we still need to update the topics contained in the metadata. Although we have
-        // specified that all topics should be fetched, only those set explicitly will be retained
-        metadata.setTopics(subscriptions.groupSubscription());
-    }
-
-    private void addMetadataListener() {
-        this.metadata.addListener(new Metadata.Listener() {
-            @Override
-            public void onMetadataUpdate(Cluster cluster, Set<String> unavailableTopics) {
-                // if we encounter any unauthorized topics, raise an exception to the user
-                if (!cluster.unauthorizedTopics().isEmpty())
-                    throw new TopicAuthorizationException(new HashSet<>(cluster.unauthorizedTopics()));
-
-                // if we encounter any invalid topics, raise an exception to the user
-                if (!cluster.invalidTopics().isEmpty())
-                    throw new InvalidTopicException(cluster.invalidTopics());
-
-                if (subscriptions.hasPatternSubscription())
-                    updatePatternSubscription(cluster);
-
-                // check if there are any changes to the metadata which should trigger a rebalance
-                if (subscriptions.partitionsAutoAssigned()) {
-                    MetadataSnapshot snapshot = new MetadataSnapshot(subscriptions, cluster);
-                    if (!snapshot.equals(metadataSnapshot))
-                        metadataSnapshot = snapshot;
-                }
-
-                if (!Collections.disjoint(metadata.topics(), unavailableTopics))
-                    metadata.requestUpdate();
-            }
-        });
+        final Set<String> topicsToSubscribe = cluster.topics().stream()
+                .filter(subscriptions::matchesSubscribedPattern)
+                .collect(Collectors.toSet());
+        if (subscriptions.subscribeFromPattern(topicsToSubscribe))
+            metadata.requestUpdateForNewTopics();
     }
 
     private PartitionAssignor lookupAssignor(String name) {
@@ -234,40 +194,26 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         return null;
     }
 
-    @Override
-    protected void onJoinComplete(int generation,
-                                  String memberId,
-                                  String assignmentStrategy,
-                                  ByteBuffer assignmentBuffer) {
-        // only the leader is responsible for monitoring for metadata changes (i.e. partition changes)
-        if (!isLeader)
-            assignmentSnapshot = null;
-
-        PartitionAssignor assignor = lookupAssignor(assignmentStrategy);
-        if (assignor == null)
-            throw new IllegalStateException("Coordinator selected invalid assignment protocol: " + assignmentStrategy);
-
-        Assignment assignment = ConsumerProtocol.deserializeAssignment(assignmentBuffer);
-        if (!subscriptions.assignFromSubscribed(assignment.partitions())) {
-            // was sent assignments that didn't match the original subscription
-            Set<TopicPartition> invalidAssignments = assignment.partitions().stream().filter(topicPartition -> 
+    private void handleAssignmentMismatch(Assignment assignment) {
+        // We received an assignment that doesn't match our current subscription. If the subscription changed,
+        // we can ignore the assignment and rebalance. Otherwise we raise an error.
+        Set<TopicPartition> invalidAssignments = assignment.partitions().stream().filter(topicPartition ->
                 !joinedSubscription.contains(topicPartition.topic())).collect(Collectors.toSet());
-            if (invalidAssignments.size() > 0) {
-                throw new IllegalStateException("Coordinator leader sent assignment that don't correspond to subscription request: " + invalidAssignments);
-            }
-
-            requestRejoin();
-            return;
+        if (invalidAssignments.size() > 0) {
+            throw new IllegalStateException("Consumer was assigned partitions " + invalidAssignments +
+                    " which didn't correspond to subscription request " + joinedSubscription);
         }
 
-        // check if the assignment contains some topics that were not in the original
+        requestRejoin();
+    }
+
+    private void maybeUpdateJoinedSubscription(Set<TopicPartition> assignedPartitions) {
+        // Check if the assignment contains some topics that were not in the original
         // subscription, if yes we will obey what leader has decided and add these topics
         // into the subscriptions as long as they still match the subscribed pattern
-        //
-        // TODO this part of the logic should be removed once we allow regex on leader assign
+
         Set<String> addedTopics = new HashSet<>();
         //this is a copy because its handed to listener below
-        Set<TopicPartition> assignedPartitions = new HashSet<>(subscriptions.assignedPartitions());
         for (TopicPartition tp : assignedPartitions) {
             if (!joinedSubscription.contains(tp.topic()))
                 addedTopics.add(tp.topic());
@@ -279,14 +225,36 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             newSubscription.addAll(addedTopics);
             newJoinedSubscription.addAll(addedTopics);
 
-            this.subscriptions.subscribeFromPattern(newSubscription);
+            if (this.subscriptions.subscribeFromPattern(newSubscription))
+                metadata.requestUpdateForNewTopics();
             this.joinedSubscription = newJoinedSubscription;
         }
+    }
+
+    @Override
+    protected void onJoinComplete(int generation,
+                                  String memberId,
+                                  String assignmentStrategy,
+                                  ByteBuffer assignmentBuffer) {
+        // only the leader is responsible for monitoring for metadata changes (i.e. partition changes)
+        if (!isLeader)
+            assignmentSnapshot = null;
+
+        PartitionAssignor assignor = lookupAssignor(assignmentStrategy);
+        if (assignor == null)
+            throw new IllegalStateException("Coordinator selected invalid assignment protocol: " + assignmentStrategy);
+
+        Assignment assignment = ConsumerProtocol.deserializeAssignment(assignmentBuffer);
+        if (!subscriptions.assignFromSubscribed(assignment.partitions())) {
+            handleAssignmentMismatch(assignment);
+            return;
+        }
 
-        // Update the metadata to include the full group subscription. The leader will trigger a rebalance
-        // if there are any metadata changes affecting any of the consumed partitions (whether or not this
-        // instance is subscribed to the topics).
-        this.metadata.setTopics(subscriptions.groupSubscription());
+        Set<TopicPartition> assignedPartitions = new HashSet<>(subscriptions.assignedPartitions());
+
+        // The leader may have assigned partitions which match our subscription pattern, but which
+        // were not explicitly requested, so we update the joined subscription here.
+        maybeUpdateJoinedSubscription(assignedPartitions);
 
         // give the assignor a chance to update internal state based on the received assignment
         assignor.onAssignment(assignment);
@@ -307,6 +275,20 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
+    void maybeUpdateSubscriptionMetadata() {
+        int version = metadata.updateVersion();
+        if (version > metadataSnapshot.version) {
+            Cluster cluster = metadata.fetch();
+
+            if (subscriptions.hasPatternSubscription())
+                updatePatternSubscription(cluster);
+
+            // Update the current snapshot, which will be used to check for subscription
+            // changes that would require a rebalance (e.g. new partitions).
+            metadataSnapshot = new MetadataSnapshot(subscriptions, cluster, version);
+        }
+    }
+
     /**
      * Poll for coordinator events. This ensures that the coordinator is known and that the consumer
      * has joined the group (if it is using group management). This also handles periodic offset commits
@@ -318,6 +300,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @return true iff the operation succeeded
      */
     public boolean poll(Timer timer) {
+        maybeUpdateSubscriptionMetadata();
+
         invokeCompletedOffsetCommitCallbacks();
 
         if (subscriptions.partitionsAutoAssigned()) {
@@ -340,13 +324,15 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                     // reduce the number of rebalances caused by single topic creation by asking consumer to
                     // refresh metadata before re-joining the group as long as the refresh backoff time has
                     // passed.
-                    if (this.metadata.timeToAllowUpdate(time.milliseconds()) == 0) {
+                    if (this.metadata.timeToAllowUpdate(timer.currentTimeMs()) == 0) {
                         this.metadata.requestUpdate();
                     }
 
                     if (!client.ensureFreshMetadata(timer)) {
                         return false;
                     }
+
+                    maybeUpdateSubscriptionMetadata();
                 }
 
                 if (!ensureActiveGroup(timer)) {
@@ -382,6 +368,20 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         return Math.min(nextAutoCommitTimer.remainingMs(), timeToNextHeartbeat(now));
     }
 
+    private void updateGroupSubscription(Set<String> topics) {
+        // the leader will begin watching for changes to any of the topics the group is interested in,
+        // which ensures that all metadata changes will eventually be seen
+        if (this.subscriptions.groupSubscribe(topics))
+            metadata.requestUpdateForNewTopics();
+
+        // update metadata (if needed) and keep track of the metadata used for assignment so that
+        // we can check after rebalance completion whether anything has changed
+        if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE)))
+            throw new TimeoutException();
+
+        maybeUpdateSubscriptionMetadata();
+    }
+
     @Override
     protected Map<String, ByteBuffer> performAssignment(String leaderId,
                                                         String assignmentStrategy,
@@ -400,12 +400,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         // the leader will begin watching for changes to any of the topics the group is interested in,
         // which ensures that all metadata changes will eventually be seen
-        this.subscriptions.groupSubscribe(allSubscribedTopics);
-        metadata.setTopics(this.subscriptions.groupSubscription());
-
-        // update metadata (if needed) and keep track of the metadata used for assignment so that
-        // we can check after rebalance completion whether anything has changed
-        if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE))) throw new TimeoutException();
+        updateGroupSubscription(allSubscribedTopics);
 
         isLeader = true;
 
@@ -439,9 +434,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                     "fetched from the brokers: {}", newlyAddedTopics);
 
             allSubscribedTopics.addAll(assignedTopics);
-            this.subscriptions.groupSubscribe(allSubscribedTopics);
-            metadata.setTopics(this.subscriptions.groupSubscription());
-            if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE))) throw new TimeoutException();
+            updateGroupSubscription(allSubscribedTopics);
         }
 
         assignmentSnapshot = metadataSnapshot;
@@ -485,7 +478,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             return false;
 
         // we need to rejoin if we performed the assignment and metadata has changed
-        if (assignmentSnapshot != null && !assignmentSnapshot.equals(metadataSnapshot))
+        if (assignmentSnapshot != null && !assignmentSnapshot.matches(metadataSnapshot))
             return true;
 
         // we need to join if our subscription has changed since the last join
@@ -966,26 +959,19 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     private static class MetadataSnapshot {
+        private final int version;
         private final Map<String, Integer> partitionsPerTopic;
 
-        private MetadataSnapshot(SubscriptionState subscription, Cluster cluster) {
+        private MetadataSnapshot(SubscriptionState subscription, Cluster cluster, int version) {
             Map<String, Integer> partitionsPerTopic = new HashMap<>();
             for (String topic : subscription.groupSubscription())
                 partitionsPerTopic.put(topic, cluster.partitionCountForTopic(topic));
             this.partitionsPerTopic = partitionsPerTopic;
+            this.version = version;
         }
 
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            MetadataSnapshot that = (MetadataSnapshot) o;
-            return partitionsPerTopic != null ? partitionsPerTopic.equals(that.partitionsPerTopic) : that.partitionsPerTopic == null;
-        }
-
-        @Override
-        public int hashCode() {
-            return partitionsPerTopic != null ? partitionsPerTopic.hashCode() : 0;
+        boolean matches(MetadataSnapshot other) {
+            return version == other.version || partitionsPerTopic.equals(other.partitionsPerTopic);
         }
     }
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java
new file mode 100644
index 0000000..c87849d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.requests.MetadataRequest;
+import org.apache.kafka.common.utils.LogContext;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class ConsumerMetadata extends Metadata {
+    private final boolean includeInternalTopics;
+    private final SubscriptionState subscription;
+    private final Set<String> transientTopics;
+
+    public ConsumerMetadata(long refreshBackoffMs,
+                            long metadataExpireMs,
+                            boolean includeInternalTopics,
+                            SubscriptionState subscription,
+                            LogContext logContext,
+                            ClusterResourceListeners clusterResourceListeners) {
+        super(refreshBackoffMs, metadataExpireMs, logContext, clusterResourceListeners);
+        this.includeInternalTopics = includeInternalTopics;
+        this.subscription = subscription;
+        this.transientTopics = new HashSet<>();
+    }
+
+    @Override
+    public synchronized MetadataRequest.Builder newMetadataRequestBuilder() {
+        if (subscription.hasPatternSubscription())
+            return MetadataRequest.Builder.allTopics();
+        List<String> topics = new ArrayList<>();
+        topics.addAll(subscription.groupSubscription());
+        topics.addAll(transientTopics);
+        return new MetadataRequest.Builder(topics, true);
+    }
+
+    synchronized void addTransientTopics(Set<String> topics) {
+        this.transientTopics.addAll(topics);
+        if (!fetch().topics().containsAll(topics))
+            requestUpdateForNewTopics();
+    }
+
+    synchronized void clearTransientTopics() {
+        this.transientTopics.clear();
+    }
+
+    @Override
+    protected synchronized boolean retainTopic(String topic, boolean isInternal, long nowMs) {
+        if (transientTopics.contains(topic) || subscription.isGroupSubscribed(topic))
+            return true;
+
+        if (isInternal && !includeInternalTopics)
+            return false;
+
+        return subscription.matchesSubscribedPattern(topic);
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
index 9aa8eaa..753fdb0 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
@@ -159,11 +159,8 @@ public class ConsumerNetworkClient implements Closeable {
         int version = this.metadata.requestUpdate();
         do {
             poll(timer);
-            AuthenticationException ex = this.metadata.getAndClearAuthenticationException();
-            if (ex != null)
-                throw ex;
-        } while (this.metadata.version() == version && timer.notExpired());
-        return this.metadata.version() > version;
+        } while (this.metadata.updateVersion() == version && timer.notExpired());
+        return this.metadata.updateVersion() > version;
     }
 
     /**
@@ -295,6 +292,8 @@ public class ConsumerNetworkClient implements Closeable {
 
         // called without the lock to avoid deadlock potential if handlers need to acquire locks
         firePendingCompletedRequests();
+
+        metadata.maybeThrowException();
     }
 
     /**
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 8531960..9009ffe 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -18,7 +18,6 @@ package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.FetchSessionHandler;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MetadataCache;
 import org.apache.kafka.clients.StaleMetadataException;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
@@ -88,6 +87,8 @@ import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
 
@@ -121,7 +122,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     private final long requestTimeoutMs;
     private final int maxPollRecords;
     private final boolean checkCrcs;
-    private final Metadata metadata;
+    private final ConsumerMetadata metadata;
     private final FetchManagerMetrics sensors;
     private final SubscriptionState subscriptions;
     private final ConcurrentLinkedQueue<CompletedFetch> completedFetches;
@@ -144,7 +145,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                    boolean checkCrcs,
                    Deserializer<K> keyDeserializer,
                    Deserializer<V> valueDeserializer,
-                   Metadata metadata,
+                   ConsumerMetadata metadata,
                    SubscriptionState subscriptions,
                    Metrics metrics,
                    FetcherMetricsRegistry metricsRegistry,
@@ -391,24 +392,30 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
         resetOffsetsAsync(offsetResetTimestamps);
     }
 
-    public Map<TopicPartition, OffsetAndTimestamp> offsetsByTimes(Map<TopicPartition, Long> timestampsToSearch,
-                                                                  Timer timer) {
-        Map<TopicPartition, OffsetData> fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch,
-                timer, true).fetchedOffsets;
-
-        HashMap<TopicPartition, OffsetAndTimestamp> offsetsByTimes = new HashMap<>(timestampsToSearch.size());
-        for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet())
-            offsetsByTimes.put(entry.getKey(), null);
-
-        for (Map.Entry<TopicPartition, OffsetData> entry : fetchedOffsets.entrySet()) {
-            // 'entry.getValue().timestamp' will not be null since we are guaranteed
-            // to work with a v1 (or later) ListOffset request
-            OffsetData offsetData = entry.getValue();
-            offsetsByTimes.put(entry.getKey(), new OffsetAndTimestamp(offsetData.offset, offsetData.timestamp,
-                    offsetData.leaderEpoch));
-        }
+    public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(Map<TopicPartition, Long> timestampsToSearch,
+                                                                   Timer timer) {
+        metadata.addTransientTopics(topicsForPartitions(timestampsToSearch.keySet()));
+
+        try {
+            Map<TopicPartition, OffsetData> fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch,
+                    timer, true).fetchedOffsets;
+
+            HashMap<TopicPartition, OffsetAndTimestamp> offsetsByTimes = new HashMap<>(timestampsToSearch.size());
+            for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet())
+                offsetsByTimes.put(entry.getKey(), null);
+
+            for (Map.Entry<TopicPartition, OffsetData> entry : fetchedOffsets.entrySet()) {
+                // 'entry.getValue().timestamp' will not be null since we are guaranteed
+                // to work with a v1 (or later) ListOffset request
+                OffsetData offsetData = entry.getValue();
+                offsetsByTimes.put(entry.getKey(), new OffsetAndTimestamp(offsetData.offset, offsetData.timestamp,
+                        offsetData.leaderEpoch));
+            }
 
-        return offsetsByTimes;
+            return offsetsByTimes;
+        } finally {
+            metadata.clearTransientTopics();
+        }
     }
 
     private ListOffsetResult fetchOffsetsByTimes(Map<TopicPartition, Long> timestampsToSearch,
@@ -457,15 +464,18 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     private Map<TopicPartition, Long> beginningOrEndOffset(Collection<TopicPartition> partitions,
                                                            long timestamp,
                                                            Timer timer) {
-        Map<TopicPartition, Long> timestampsToSearch = new HashMap<>();
-        for (TopicPartition tp : partitions)
-            timestampsToSearch.put(tp, timestamp);
-        Map<TopicPartition, Long> offsets = new HashMap<>();
-        ListOffsetResult result = fetchOffsetsByTimes(timestampsToSearch, timer, false);
-        for (Map.Entry<TopicPartition, OffsetData> entry : result.fetchedOffsets.entrySet()) {
-            offsets.put(entry.getKey(), entry.getValue().offset);
+        metadata.addTransientTopics(topicsForPartitions(partitions));
+        try {
+            Map<TopicPartition, Long> timestampsToSearch = partitions.stream()
+                    .collect(Collectors.toMap(Function.identity(), tp -> timestamp));
+
+            ListOffsetResult result = fetchOffsetsByTimes(timestampsToSearch, timer, false);
+
+            return result.fetchedOffsets.entrySet().stream()
+                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().offset));
+        } finally {
+            metadata.clearTransientTopics();
         }
-        return offsets;
     }
 
     /**
@@ -589,10 +599,6 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     }
 
     private void resetOffsetsAsync(Map<TopicPartition, Long> partitionResetTimestamps) {
-        // Add the topics to the metadata to do a single metadata fetch.
-        for (TopicPartition tp : partitionResetTimestamps.keySet())
-            metadata.add(tp.topic());
-
         Map<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> timestampsToSearchByNode =
                 groupListOffsetRequests(partitionResetTimestamps, new HashSet<>());
         for (Map.Entry<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> entry : timestampsToSearchByNode.entrySet()) {
@@ -639,10 +645,6 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
      */
     private RequestFuture<ListOffsetResult> sendListOffsetsRequests(final Map<TopicPartition, Long> timestampsToSearch,
                                                                     final boolean requireTimestamps) {
-        // Add the topics to the metadata to do a single metadata fetch.
-        for (TopicPartition tp : timestampsToSearch.keySet())
-            metadata.add(tp.topic());
-
         final Set<TopicPartition> partitionsToRetry = new HashSet<>();
         Map<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> timestampsToSearchByNode =
                 groupListOffsetRequests(timestampsToSearch, partitionsToRetry);
@@ -697,7 +699,6 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
             TopicPartition tp  = entry.getKey();
             Optional<MetadataCache.PartitionInfoAndEpoch> currentInfo = metadata.partitionInfoIfCurrent(tp);
             if (!currentInfo.isPresent()) {
-                metadata.add(tp.topic());
                 log.debug("Leader for partition {} is unknown for fetching offset", tp);
                 metadata.requestUpdate();
                 partitionsToRetry.add(tp);
@@ -1539,4 +1540,8 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
         decompressionBufferSupplier.close();
     }
 
+    private Set<String> topicsForPartitions(Collection<TopicPartition> partitions) {
+        return partitions.stream().map(TopicPartition::topic).collect(Collectors.toSet());
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 3298980..fe944c5 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -34,6 +34,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Predicate;
 import java.util.regex.Pattern;
 import java.util.stream.Collector;
@@ -54,6 +55,11 @@ import java.util.stream.Collectors;
  *
  * Note that pause state as well as fetch/consumed positions are not preserved when partition
  * assignment is changed whether directly by the user or through a group rebalance.
+ *
+ * Thread Safety: this class is generally not thread-safe. It should only be accessed in the
+ * consumer's calling thread. The only exception is {@link ConsumerMetadata} which accesses
+ * the subscription state needed to build and handle Metadata requests. The thread-safe methods
+ * are documented below.
  */
 public class SubscriptionState {
     private static final String SUBSCRIPTION_EXCEPTION_MESSAGE =
@@ -66,15 +72,17 @@ public class SubscriptionState {
     }
 
     /* the type of subscription */
-    private SubscriptionType subscriptionType;
+    private volatile SubscriptionType subscriptionType;
 
     /* the pattern user has requested */
-    private Pattern subscribedPattern;
+    private volatile Pattern subscribedPattern;
 
     /* the list of topics the user has requested */
     private Set<String> subscription;
 
-    /* the list of topics the group has subscribed to (set only for the leader on join group completion) */
+    /* The list of topics the group has subscribed to. This may include some topics which are not part
+     * of `subscription` for the leader of a group since it is responsible for detecting metadata changes
+     * which require a group rebalance. */
     private final Set<String> groupSubscription;
 
     /* the partitions that are currently assigned, note that the order of partition matters (see FetchBuilder for more details) */
@@ -94,7 +102,7 @@ public class SubscriptionState {
         this.defaultResetStrategy = defaultResetStrategy;
         this.subscription = Collections.emptySet();
         this.assignment = new PartitionStates<>();
-        this.groupSubscription = new HashSet<>();
+        this.groupSubscription = ConcurrentHashMap.newKeySet();
         this.subscribedPattern = null;
         this.subscriptionType = SubscriptionType.NONE;
     }
@@ -112,7 +120,7 @@ public class SubscriptionState {
             throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE);
     }
 
-    public void subscribe(Set<String> topics, ConsumerRebalanceListener listener) {
+    public boolean subscribe(Set<String> topics, ConsumerRebalanceListener listener) {
         if (listener == null)
             throw new IllegalArgumentException("RebalanceListener cannot be null");
 
@@ -120,22 +128,24 @@ public class SubscriptionState {
 
         this.rebalanceListener = listener;
 
-        changeSubscription(topics);
+        return changeSubscription(topics);
     }
 
-    public void subscribeFromPattern(Set<String> topics) {
+    public boolean subscribeFromPattern(Set<String> topics) {
         if (subscriptionType != SubscriptionType.AUTO_PATTERN)
             throw new IllegalArgumentException("Attempt to subscribe from pattern while subscription type set to " +
                     subscriptionType);
 
-        changeSubscription(topics);
+        return changeSubscription(topics);
     }
 
-    private void changeSubscription(Set<String> topicsToSubscribe) {
-        if (!this.subscription.equals(topicsToSubscribe)) {
-            this.subscription = topicsToSubscribe;
-            this.groupSubscription.addAll(topicsToSubscribe);
-        }
+    private boolean changeSubscription(Set<String> topicsToSubscribe) {
+        if (subscription.equals(topicsToSubscribe))
+            return false;
+
+        this.subscription = topicsToSubscribe;
+        this.groupSubscription.addAll(topicsToSubscribe);
+        return true;
     }
 
     /**
@@ -143,10 +153,10 @@ public class SubscriptionState {
      * that it receives metadata updates for all topics that the group is interested in.
      * @param topics The topics to add to the group subscription
      */
-    public void groupSubscribe(Collection<String> topics) {
-        if (this.subscriptionType == SubscriptionType.USER_ASSIGNED)
+    public boolean groupSubscribe(Collection<String> topics) {
+        if (!partitionsAutoAssigned())
             throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE);
-        this.groupSubscription.addAll(topics);
+        return this.groupSubscription.addAll(topics);
     }
 
     /**
@@ -161,21 +171,25 @@ public class SubscriptionState {
      * note this is different from {@link #assignFromSubscribed(Collection)}
      * whose input partitions are provided from the subscribed topics.
      */
-    public void assignFromUser(Set<TopicPartition> partitions) {
+    public boolean assignFromUser(Set<TopicPartition> partitions) {
         setSubscriptionType(SubscriptionType.USER_ASSIGNED);
 
-        if (!this.assignment.partitionSet().equals(partitions)) {
-            fireOnAssignment(partitions);
+        if (this.assignment.partitionSet().equals(partitions))
+            return false;
 
-            Map<TopicPartition, TopicPartitionState> partitionToState = new HashMap<>();
-            for (TopicPartition partition : partitions) {
-                TopicPartitionState state = assignment.stateValue(partition);
-                if (state == null)
-                    state = new TopicPartitionState();
-                partitionToState.put(partition, state);
-            }
-            this.assignment.set(partitionToState);
+        fireOnAssignment(partitions);
+
+        Set<String> manualSubscribedTopics = new HashSet<>();
+        Map<TopicPartition, TopicPartitionState> partitionToState = new HashMap<>();
+        for (TopicPartition partition : partitions) {
+            TopicPartitionState state = assignment.stateValue(partition);
+            if (state == null)
+                state = new TopicPartitionState();
+            partitionToState.put(partition, state);
+            manualSubscribedTopics.add(partition.topic());
         }
+        this.assignment.set(partitionToState);
+        return changeSubscription(manualSubscribedTopics);
     }
 
     /**
@@ -229,6 +243,10 @@ public class SubscriptionState {
         this.subscribedPattern = pattern;
     }
 
+    /**
+     * Check whether pattern subscription is in use. This is thread-safe.
+     * @return true if the subscription type is AUTO_PATTERN, false otherwise
+     */
     public boolean hasPatternSubscription() {
         return this.subscriptionType == SubscriptionType.AUTO_PATTERN;
     }
@@ -239,18 +257,31 @@ public class SubscriptionState {
 
     public void unsubscribe() {
         this.subscription = Collections.emptySet();
+        this.groupSubscription.clear();
         this.assignment.clear();
         this.subscribedPattern = null;
         this.subscriptionType = SubscriptionType.NONE;
         fireOnAssignment(Collections.emptySet());
     }
 
-    public Pattern subscribedPattern() {
-        return this.subscribedPattern;
+    /**
+     * Check whether a topic matches a subscribed pattern.
+     *
+     * This is thread-safe, but it may not always reflect the most recent subscription pattern.
+     *
+     * @return true if pattern subscription is in use and the topic matches the subscribed pattern, false otherwise
+     */
+    public boolean matchesSubscribedPattern(String topic) {
+        Pattern pattern = this.subscribedPattern;
+        if (hasPatternSubscription() && pattern != null)
+            return pattern.matcher(topic).matches();
+        return false;
     }
 
     public Set<String> subscription() {
-        return this.subscription;
+        if (partitionsAutoAssigned())
+            return this.subscription;
+        return Collections.emptySet();
     }
 
     public Set<TopicPartition> pausedPartitions() {
@@ -264,6 +295,9 @@ public class SubscriptionState {
      * require rebalancing. The leader fetches metadata for all topics in the group so that it
      * can do the partition assignment (which requires at least partition counts for all topics
      * to be assigned).
+     *
+     * Note this is thread-safe since the Set is backed by a ConcurrentMap.
+     *
      * @return The union of all subscribed topics in the group if this member is the leader
      *   of the current generation; otherwise it returns the same set as {@link #subscription()}
      */
@@ -271,6 +305,13 @@ public class SubscriptionState {
         return this.groupSubscription;
     }
 
+    /**
+     * Note this is thread-safe since the Set is backed by a ConcurrentMap.
+     */
+    public boolean isGroupSubscribed(String topic) {
+        return groupSubscription.contains(topic);
+    }
+
     private TopicPartitionState assignedState(TopicPartition tp) {
         TopicPartitionState state = this.assignment.stateValue(tp);
         if (state == null)
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index 5a383ea..3a0130f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -16,30 +16,17 @@
  */
 package org.apache.kafka.clients.producer;
 
-import java.net.InetSocketAddress;
-import java.time.Duration;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Properties;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Future;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientUtils;
 import org.apache.kafka.clients.KafkaClient;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.clients.producer.internals.BufferPool;
 import org.apache.kafka.clients.producer.internals.ProducerInterceptors;
+import org.apache.kafka.clients.producer.internals.ProducerMetadata;
 import org.apache.kafka.clients.producer.internals.ProducerMetrics;
 import org.apache.kafka.clients.producer.internals.RecordAccumulator;
 import org.apache.kafka.clients.producer.internals.Sender;
@@ -61,7 +48,6 @@ import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.errors.RecordTooLargeException;
 import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.errors.TimeoutException;
-import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
@@ -83,6 +69,20 @@ import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.slf4j.Logger;
 
+import java.net.InetSocketAddress;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Properties;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
 
 /**
  * A Kafka client that publishes records to the Kafka cluster.
@@ -240,7 +240,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
     private final Partitioner partitioner;
     private final int maxRequestSize;
     private final long totalMemorySize;
-    private final Metadata metadata;
+    private final ProducerMetadata metadata;
     private final RecordAccumulator accumulator;
     private final Sender sender;
     private final Thread ioThread;
@@ -318,7 +318,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
     KafkaProducer(Map<String, Object> configs,
                   Serializer<K> keySerializer,
                   Serializer<V> valueSerializer,
-                  Metadata metadata,
+                  ProducerMetadata metadata,
                   KafkaClient kafkaClient,
                   ProducerInterceptors interceptors,
                   Time time) {
@@ -410,8 +410,11 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
             if (metadata != null) {
                 this.metadata = metadata;
             } else {
-                this.metadata = new Metadata(retryBackoffMs, config.getLong(ProducerConfig.METADATA_MAX_AGE_CONFIG),
-                    true, true, clusterResourceListeners);
+                this.metadata = new ProducerMetadata(retryBackoffMs,
+                        config.getLong(ProducerConfig.METADATA_MAX_AGE_CONFIG),
+                        logContext,
+                        clusterResourceListeners,
+                        Time.SYSTEM);
                 this.metadata.bootstrap(addresses, time.milliseconds());
             }
             this.errors = this.metrics.sensor("errors");
@@ -431,7 +434,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
     }
 
     // visible for testing
-    Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
+    Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) {
         int maxInflightRequests = configureInflightRequests(producerConfig, transactionManager != null);
         int requestTimeoutMs = producerConfig.getInt(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG);
         ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(producerConfig, time);
@@ -998,10 +1001,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
                         String.format("Partition %d of topic %s with partition count %d is not present in metadata after %d ms.",
                                 partition, topic, partitionsCount, maxWaitMs));
             }
-            if (cluster.unauthorizedTopics().contains(topic))
-                throw new TopicAuthorizationException(topic);
-            if (cluster.invalidTopics().contains(topic))
-                throw new InvalidTopicException(topic);
+            metadata.maybeThrowException();
             remainingWaitMs = maxWaitMs - elapsed;
             partitionsCount = cluster.partitionCountForTopic(topic);
         } while (partitionsCount == null || (partition != null && partition >= partitionsCount));
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java
new file mode 100644
index 0000000..90e7970
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.producer.internals;
+
+import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.errors.AuthenticationException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.requests.MetadataRequest;
+import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+public class ProducerMetadata extends Metadata {
+    private static final long TOPIC_EXPIRY_NEEDS_UPDATE = -1L;
+    static final long TOPIC_EXPIRY_MS = 5 * 60 * 1000;
+
+    /* Topics with expiry time */
+    private final Map<String, Long> topics = new HashMap<>();
+    private final Logger log;
+    private final Time time;
+
+    public ProducerMetadata(long refreshBackoffMs,
+                            long metadataExpireMs,
+                            LogContext logContext,
+                            ClusterResourceListeners clusterResourceListeners,
+                            Time time) {
+        super(refreshBackoffMs, metadataExpireMs, logContext, clusterResourceListeners);
+        this.log = logContext.logger(ProducerMetadata.class);
+        this.time = time;
+    }
+
+    @Override
+    public synchronized MetadataRequest.Builder newMetadataRequestBuilder() {
+        return new MetadataRequest.Builder(new ArrayList<>(topics.keySet()), true);
+    }
+
+    public synchronized void add(String topic) {
+        Objects.requireNonNull(topic, "topic cannot be null");
+        if (topics.put(topic, TOPIC_EXPIRY_NEEDS_UPDATE) == null) {
+            requestUpdateForNewTopics();
+        }
+    }
+
+    // Visible for testing
+    synchronized Set<String> topics() {
+        return topics.keySet();
+    }
+
+    public synchronized boolean containsTopic(String topic) {
+        return topics.containsKey(topic);
+    }
+
+    @Override
+    public synchronized boolean retainTopic(String topic, boolean isInternal, long nowMs) {
+        Long expireMs = topics.get(topic);
+        if (expireMs == null) {
+            return false;
+        } else if (expireMs == TOPIC_EXPIRY_NEEDS_UPDATE) {
+            topics.put(topic, nowMs + TOPIC_EXPIRY_MS);
+            return true;
+        } else if (expireMs <= nowMs) {
+            log.debug("Removing unused topic {} from the metadata list, expiryMs {} now {}", topic, expireMs, nowMs);
+            topics.remove(topic);
+            return false;
+        } else {
+            return true;
+        }
+    }
+
+    /**
+     * Wait for metadata update until the current version is larger than the last version we know of
+     */
+    public synchronized void awaitUpdate(final int lastVersion, final long timeoutMs) throws InterruptedException {
+        long currentTimeMs = time.milliseconds();
+        long deadlineMs = currentTimeMs + timeoutMs < 0 ? Long.MAX_VALUE : currentTimeMs + timeoutMs;
+        time.waitObject(this, () -> {
+            maybeThrowException();
+            return updateVersion() > lastVersion || isClosed();
+        }, deadlineMs);
+
+        if (isClosed())
+            throw new KafkaException("Requested metadata update after close");
+    }
+
+    @Override
+    public synchronized void update(int requestVersion, MetadataResponse response, long now) {
+        super.update(requestVersion, response, now);
+        notifyAll();
+    }
+
+    @Override
+    public synchronized void failedUpdate(long now, AuthenticationException authenticationException) {
+        super.failedUpdate(now, authenticationException);
+        if (authenticationException != null)
+            notifyAll();
+    }
+
+    /**
+     * Close this instance and notify any awaiting threads.
+     */
+    @Override
+    public synchronized void close() {
+        super.close();
+        notifyAll();
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index d003f4d..6189aae 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -82,7 +82,7 @@ public class Sender implements Runnable {
     private final RecordAccumulator accumulator;
 
     /* the metadata for the client */
-    private final Metadata metadata;
+    private final ProducerMetadata metadata;
 
     /* the flag indicating whether the producer should guarantee the message order on the broker or not. */
     private final boolean guaranteeMessageOrder;
@@ -125,7 +125,7 @@ public class Sender implements Runnable {
 
     public Sender(LogContext logContext,
                   KafkaClient client,
-                  Metadata metadata,
+                  ProducerMetadata metadata,
                   RecordAccumulator accumulator,
                   boolean guaranteeMessageOrder,
                   int maxRequestSize,
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
index 461c2d7..f90876f 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
@@ -19,7 +19,6 @@ package org.apache.kafka.common.requests;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
-import org.apache.kafka.common.errors.InvalidMetadataException;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.protocol.types.Field;
@@ -341,29 +340,6 @@ public class MetadataResponse extends AbstractResponse {
     }
 
     /**
-     * Returns the set of topics with an error indicating invalid metadata
-     * and topics with any partition whose error indicates invalid metadata.
-     * This includes all non-existent topics specified in the metadata request
-     * and any topic returned with one or more partitions whose leader is not known.
-     */
-    public Set<String> unavailableTopics() {
-        Set<String> invalidMetadataTopics = new HashSet<>();
-        for (TopicMetadata topicMetadata : this.topicMetadata) {
-            if (topicMetadata.error.exception() instanceof InvalidMetadataException)
-                invalidMetadataTopics.add(topicMetadata.topic);
-            else {
-                for (PartitionMetadata partitionMetadata : topicMetadata.partitionMetadata) {
-                    if (partitionMetadata.error.exception() instanceof InvalidMetadataException) {
-                        invalidMetadataTopics.add(topicMetadata.topic);
-                        break;
-                    }
-                }
-            }
-        }
-        return invalidMetadataTopics;
-    }
-
-    /**
      * Get a snapshot of the cluster metadata from this response
      * @return the cluster snapshot
      */
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
index c8b79ab..9ef096f 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
@@ -16,6 +16,10 @@
  */
 package org.apache.kafka.common.utils;
 
+import org.apache.kafka.common.errors.TimeoutException;
+
+import java.util.function.Supplier;
+
 /**
  * A time implementation that uses the system clock and sleep call. Use `Time.SYSTEM` instead of creating an instance
  * of this class.
@@ -42,4 +46,20 @@ public class SystemTime implements Time {
         }
     }
 
+    @Override
+    public void waitObject(Object obj, Supplier<Boolean> condition, long deadlineMs) throws InterruptedException {
+        synchronized (obj) {
+            while (true) {
+                if (condition.get())
+                    return;
+
+                long currentTimeMs = milliseconds();
+                if (currentTimeMs >= deadlineMs)
+                    throw new TimeoutException("Condition not satisfied before deadline");
+
+                obj.wait(deadlineMs - currentTimeMs);
+            }
+        }
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Time.java b/clients/src/main/java/org/apache/kafka/common/utils/Time.java
index 90190cb..04aefa4 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/Time.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Time.java
@@ -18,6 +18,7 @@ package org.apache.kafka.common.utils;
 
 import java.time.Duration;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 
 /**
  * An interface abstracting the clock to use in unit testing classes that make use of clock time.
@@ -59,6 +60,19 @@ public interface Time {
     void sleep(long ms);
 
     /**
+     * Wait for a condition using the monitor of a given object. This avoids the implicit
+     * dependence on system time when calling {@link Object#wait()}.
+     *
+     * @param obj The object that will be waited on with {@link Object#wait()}. Note that it is the responsibility
+     *      of the caller to call notify on this object when the condition is satisfied.
+     * @param condition The condition we are awaiting
+     * @param timeoutMs The absolute deadline in milliseconds by which the condition must be satisfied (note: the SystemTime implementation treats this as a deadline, not a relative timeout; callers pass currentTimeMs + timeout)
+     *
+     * @throws org.apache.kafka.common.errors.TimeoutException if the timeout expires before the condition is satisfied
+     */
+    void waitObject(Object obj, Supplier<Boolean> condition, long timeoutMs) throws InterruptedException;
+
+    /**
      * Get a timer which is bound to this time instance and expires after the given timeout
      */
     default Timer timer(long timeoutMs) {
@@ -71,4 +85,5 @@ public interface Time {
     default Timer timer(Duration timeout) {
         return timer(timeout.toMillis());
     }
+
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
index 7142196..3d28297 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
@@ -17,48 +17,39 @@
 package org.apache.kafka.clients;
 
 import org.apache.kafka.common.Cluster;
-import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.errors.InvalidTopicException;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.internals.Topic;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.test.MockClusterResourceListener;
 import org.apache.kafka.test.TestUtils;
-import org.junit.After;
 import org.junit.Test;
 
 import java.net.InetSocketAddress;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicReference;
 
-import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 public class MetadataTest {
 
     private long refreshBackoffMs = 100;
     private long metadataExpireMs = 1000;
-    private Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true);
-    private AtomicReference<Exception> backgroundError = new AtomicReference<>();
-
-    @After
-    public void tearDown() {
-        assertNull("Exception in background thread : " + backgroundError.get(), backgroundError.get());
-    }
+    private Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(),
+            new ClusterResourceListeners());
 
     private static MetadataResponse emptyMetadataResponse() {
         return new MetadataResponse(
@@ -67,55 +58,6 @@ public class MetadataTest {
                 -1,
                 Collections.emptyList());
     }
-    
-    @Test
-    public void testMetadata() throws Exception {
-        long time = 0;
-        metadata.update(emptyMetadataResponse(), time);
-        assertFalse("No update needed.", metadata.timeToNextUpdate(time) == 0);
-        metadata.requestUpdate();
-        assertFalse("Still no updated needed due to backoff", metadata.timeToNextUpdate(time) == 0);
-        time += refreshBackoffMs;
-        assertTrue("Update needed now that backoff time expired", metadata.timeToNextUpdate(time) == 0);
-        String topic = "my-topic";
-        Thread t1 = asyncFetch(topic, 500);
-        Thread t2 = asyncFetch(topic, 500);
-        assertTrue("Awaiting update", t1.isAlive());
-        assertTrue("Awaiting update", t2.isAlive());
-        // Perform metadata update when an update is requested on the async fetch thread
-        // This simulates the metadata update sequence in KafkaProducer
-        while (t1.isAlive() || t2.isAlive()) {
-            if (metadata.timeToNextUpdate(time) == 0) {
-                MetadataResponse response = TestUtils.metadataUpdateWith(1, Collections.singletonMap(topic, 1));
-                metadata.update(response, time);
-                time += refreshBackoffMs;
-            }
-            Thread.sleep(1);
-        }
-        t1.join();
-        t2.join();
-        assertFalse("No update needed.", metadata.timeToNextUpdate(time) == 0);
-        time += metadataExpireMs;
-        assertTrue("Update needed due to stale metadata.", metadata.timeToNextUpdate(time) == 0);
-    }
-
-    @Test
-    public void testMetadataAwaitAfterClose() throws InterruptedException {
-        long time = 0;
-        metadata.update(emptyMetadataResponse(), time);
-        assertFalse("No update needed.", metadata.timeToNextUpdate(time) == 0);
-        metadata.requestUpdate();
-        assertFalse("Still no updated needed due to backoff", metadata.timeToNextUpdate(time) == 0);
-        time += refreshBackoffMs;
-        assertTrue("Update needed now that backoff time expired", metadata.timeToNextUpdate(time) == 0);
-        String topic = "my-topic";
-        metadata.close();
-        Thread t1 = asyncFetch(topic, 500);
-        t1.join();
-        assertTrue(backgroundError.get().getClass() == KafkaException.class);
-        assertTrue(backgroundError.get().toString().contains("Requested metadata update after close"));
-        clearBackgroundError();
-    }
 
     @Test(expected = IllegalStateException.class)
     public void testMetadataUpdateAfterClose() {
@@ -136,7 +78,8 @@ public class MetadataTest {
         }
 
         long largerOfBackoffAndExpire = Math.max(refreshBackoffMs, metadataExpireMs);
-        Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true);
+        Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(),
+                new ClusterResourceListeners());
 
         assertEquals(0, metadata.timeToNextUpdate(now));
 
@@ -192,65 +135,6 @@ public class MetadataTest {
     }
 
     @Test
-    public void testTimeToNextUpdate_OverwriteBackoff() {
-        long now = 10000;
-
-        // New topic added to fetch set and update requested. It should allow immediate update.
-        metadata.update(emptyMetadataResponse(), now);
-        metadata.add("new-topic");
-        assertEquals(0, metadata.timeToNextUpdate(now));
-
-        // Even though setTopics called, immediate update isn't necessary if the new topic set isn't
-        // containing a new topic,
-        metadata.update(emptyMetadataResponse(), now);
-        metadata.setTopics(metadata.topics());
-        assertEquals(metadataExpireMs, metadata.timeToNextUpdate(now));
-
-        // If the new set of topics containing a new topic then it should allow immediate update.
-        metadata.setTopics(Collections.singletonList("another-new-topic"));
-        assertEquals(0, metadata.timeToNextUpdate(now));
-
-        // If metadata requested for all topics it should allow immediate update.
-        metadata.update(emptyMetadataResponse(), now);
-        metadata.needMetadataForAllTopics(true);
-        assertEquals(0, metadata.timeToNextUpdate(now));
-
-        // However if metadata is already capable to serve all topics it shouldn't override backoff.
-        metadata.update(emptyMetadataResponse(), now);
-        metadata.needMetadataForAllTopics(true);
-        assertEquals(metadataExpireMs, metadata.timeToNextUpdate(now));
-    }
-
-    /**
-     * Tests that {@link org.apache.kafka.clients.Metadata#awaitUpdate(int, long)} doesn't
-     * wait forever with a max timeout value of 0
-     *
-     * @throws Exception
-     * @see <a href=https://issues.apache.org/jira/browse/KAFKA-1836>KAFKA-1836</a>
-     */
-    @Test
-    public void testMetadataUpdateWaitTime() throws Exception {
-        long time = 0;
-        metadata.update(emptyMetadataResponse(), time);
-        assertFalse("No update needed.", metadata.timeToNextUpdate(time) == 0);
-        // first try with a max wait time of 0 and ensure that this returns back without waiting forever
-        try {
-            metadata.awaitUpdate(metadata.requestUpdate(), 0);
-            fail("Wait on metadata update was expected to timeout, but it didn't");
-        } catch (TimeoutException te) {
-            // expected
-        }
-        // now try with a higher timeout value once
-        final long twoSecondWait = 2000;
-        try {
-            metadata.awaitUpdate(metadata.requestUpdate(), twoSecondWait);
-            fail("Wait on metadata update was expected to timeout, but it didn't");
-        } catch (TimeoutException te) {
-            // expected
-        }
-    }
-
-    @Test
     public void testFailedUpdate() {
         long time = 100;
         metadata.update(emptyMetadataResponse(), time);
@@ -261,39 +145,17 @@ public class MetadataTest {
         assertEquals(100, metadata.timeToNextUpdate(1100));
         assertEquals(100, metadata.lastSuccessfulUpdate());
 
-        metadata.needMetadataForAllTopics(true);
         metadata.update(emptyMetadataResponse(), time);
         assertEquals(100, metadata.timeToNextUpdate(1000));
     }
 
     @Test
-    public void testUpdateWithNeedMetadataForAllTopics() {
-        long time = 0;
-        metadata.update(emptyMetadataResponse(), time);
-        metadata.needMetadataForAllTopics(true);
-
-        final List<String> expectedTopics = Collections.singletonList("topic");
-        metadata.setTopics(expectedTopics);
-
-        Map<String, Integer> partitionCounts = new HashMap<>();
-        partitionCounts.put("topic", 1);
-        partitionCounts.put("topic1", 1);
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith(1, partitionCounts);
-        metadata.update(metadataResponse, 100);
-
-        assertArrayEquals("Metadata got updated with wrong set of topics.",
-            expectedTopics.toArray(), metadata.topics().toArray());
-
-        metadata.needMetadataForAllTopics(false);
-    }
-
-    @Test
     public void testClusterListenerGetsNotifiedOfUpdate() {
         long time = 0;
         MockClusterResourceListener mockClusterListener = new MockClusterResourceListener();
         ClusterResourceListeners listeners = new ClusterResourceListeners();
         listeners.maybeAdd(mockClusterListener);
-        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true, false, listeners);
+        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), listeners);
 
         String hostName = "www.example.com";
         metadata.bootstrap(Collections.singletonList(new InetSocketAddress(hostName, 9002)), time);
@@ -312,127 +174,9 @@ public class MetadataTest {
                 MockClusterResourceListener.IS_ON_UPDATE_CALLED.get());
     }
 
-    @Test
-    public void testListenerGetsNotifiedOfUpdate() {
-        long time = 0;
-        final Set<String> topics = new HashSet<>();
-        metadata.update(emptyMetadataResponse(), time);
-        metadata.addListener(new Metadata.Listener() {
-            @Override
-            public void onMetadataUpdate(Cluster cluster, Set<String> unavailableTopics) {
-                topics.clear();
-                topics.addAll(cluster.topics());
-            }
-        });
-
-        Map<String, Integer> partitionCounts = new HashMap<>();
-        partitionCounts.put("topic", 1);
-        partitionCounts.put("topic1", 1);
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, partitionCounts);
-        metadata.update(metadataResponse, 100);
-
-        assertEquals("Listener did not update topics list correctly",
-            new HashSet<>(Arrays.asList("topic", "topic1")), topics);
-    }
-
-    @Test
-    public void testListenerCanUnregister() {
-        long time = 0;
-        final Set<String> topics = new HashSet<>();
-        metadata.update(emptyMetadataResponse(), time);
-        final Metadata.Listener listener = new Metadata.Listener() {
-            @Override
-            public void onMetadataUpdate(Cluster cluster, Set<String> unavailableTopics) {
-                topics.clear();
-                topics.addAll(cluster.topics());
-            }
-        };
-        metadata.addListener(listener);
-
-        Map<String, Integer> partitionCounts = new HashMap<>();
-        partitionCounts.put("topic", 1);
-        partitionCounts.put("topic1", 1);
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, partitionCounts);
-        metadata.update(metadataResponse, 100);
-
-        metadata.removeListener(listener);
-
-        partitionCounts.clear();
-        partitionCounts.put("topic2", 1);
-        partitionCounts.put("topic3", 1);
-        metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, partitionCounts);
-        metadata.update(metadataResponse, 100);
-
-        assertEquals("Listener did not update topics list correctly",
-            new HashSet<>(Arrays.asList("topic", "topic1")), topics);
-    }
-
-    @Test
-    public void testTopicExpiry() throws Exception {
-        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true, true, new ClusterResourceListeners());
-
-        // Test that topic is expired if not used within the expiry interval
-        long time = 0;
-        metadata.add("topic1");
-        metadata.update(emptyMetadataResponse(), time);
-        time += Metadata.TOPIC_EXPIRY_MS;
-        metadata.update(emptyMetadataResponse(), time);
-        assertFalse("Unused topic not expired", metadata.containsTopic("topic1"));
-
-        // Test that topic is not expired if used within the expiry interval
-        metadata.add("topic2");
-        metadata.update(emptyMetadataResponse(), time);
-        for (int i = 0; i < 3; i++) {
-            time += Metadata.TOPIC_EXPIRY_MS / 2;
-            metadata.update(emptyMetadataResponse(), time);
-            assertTrue("Topic expired even though in use", metadata.containsTopic("topic2"));
-            metadata.add("topic2");
-        }
-
-        // Test that topics added using setTopics expire
-        HashSet<String> topics = new HashSet<>();
-        topics.add("topic4");
-        metadata.setTopics(topics);
-        metadata.update(emptyMetadataResponse(), time);
-        time += Metadata.TOPIC_EXPIRY_MS;
-        metadata.update(emptyMetadataResponse(), time);
-        assertFalse("Unused topic not expired", metadata.containsTopic("topic4"));
-    }
-
-    @Test
-    public void testNonExpiringMetadata() throws Exception {
-        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true, false, new ClusterResourceListeners());
-
-        // Test that topic is not expired if not used within the expiry interval
-        long time = 0;
-        metadata.add("topic1");
-        metadata.update(emptyMetadataResponse(), time);
-        time += Metadata.TOPIC_EXPIRY_MS;
-        metadata.update(emptyMetadataResponse(), time);
-        assertTrue("Unused topic expired when expiry disabled", metadata.containsTopic("topic1"));
-
-        // Test that topic is not expired if used within the expiry interval
-        metadata.add("topic2");
-        metadata.update(emptyMetadataResponse(), time);
-        for (int i = 0; i < 3; i++) {
-            time += Metadata.TOPIC_EXPIRY_MS / 2;
-            metadata.update(emptyMetadataResponse(), time);
-            assertTrue("Topic expired even though in use", metadata.containsTopic("topic2"));
-            metadata.add("topic2");
-        }
-
-        // Test that topics added using setTopics don't expire
-        HashSet<String> topics = new HashSet<>();
-        topics.add("topic4");
-        metadata.setTopics(topics);
-        time += metadataExpireMs * 2;
-        metadata.update(emptyMetadataResponse(), time);
-        assertTrue("Unused topic expired when expiry disabled", metadata.containsTopic("topic4"));
-    }
 
     @Test
     public void testRequestUpdate() {
-        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true, false, new ClusterResourceListeners());
         assertFalse(metadata.updateRequested());
 
         int[] epochs =           {42,   42,    41,    41,    42,    43,   43,    42,    41,    44};
@@ -572,12 +316,6 @@ public class MetadataTest {
         assertEquals(metadata.fetch().partitionCountForTopic("topic-1").longValue(), 5);
         assertTrue(metadata.partitionInfoIfCurrent(tp).isPresent());
         assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 101);
-
-        // Change topic subscription, remove metadata for old topic
-        metadata.setTopics(Collections.singletonList("topic-2"));
-        assertNull(metadata.fetch().partition(tp));
-        assertNull(metadata.fetch().partitionCountForTopic("topic-1"));
-        assertFalse(metadata.partitionInfoIfCurrent(tp).isPresent());
     }
 
     @Test
@@ -637,22 +375,78 @@ public class MetadataTest {
         assertEquals(fromMetadataEmpty, fromClusterEmpty);
     }
 
-    private void clearBackgroundError() {
-        backgroundError.set(null);
+    @Test
+    public void testRequestVersion() {
+        Time time = new MockTime();
+
+        metadata.requestUpdate();
+        Metadata.MetadataRequestAndVersion versionAndBuilder = metadata.newMetadataRequestAndVersion();
+        metadata.update(versionAndBuilder.requestVersion,
+                TestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), time.milliseconds());
+        assertFalse(metadata.updateRequested());
+
+        // bump the request version for new topics added to the metadata
+        metadata.requestUpdateForNewTopics();
+
+        // simulating a bump while a metadata request is in flight
+        versionAndBuilder = metadata.newMetadataRequestAndVersion();
+        metadata.requestUpdateForNewTopics();
+        metadata.update(versionAndBuilder.requestVersion,
+                TestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), time.milliseconds());
+
+        // metadata update is still needed
+        assertTrue(metadata.updateRequested());
+
+        // the next update will resolve it
+        versionAndBuilder = metadata.newMetadataRequestAndVersion();
+        metadata.update(versionAndBuilder.requestVersion,
+                TestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), time.milliseconds());
+        assertFalse(metadata.updateRequested());
+    }
+
+    @Test
+    public void testInvalidTopicError() {
+        Time time = new MockTime();
+
+        String invalidTopic = "topic dfsa";
+        MetadataResponse invalidTopicResponse = TestUtils.metadataUpdateWith("clusterId", 1,
+                Collections.singletonMap(invalidTopic, Errors.INVALID_TOPIC_EXCEPTION), Collections.emptyMap());
+        metadata.update(invalidTopicResponse, time.milliseconds());
+
+        InvalidTopicException e = assertThrows(InvalidTopicException.class, () -> metadata.maybeThrowException());
+
+        assertEquals(Collections.singleton(invalidTopic), e.invalidTopics());
+        // We clear the exception once it has been raised to the user
+        assertNull(metadata.getAndClearMetadataException());
+
+        // Reset the invalid topic error
+        metadata.update(invalidTopicResponse, time.milliseconds());
+
+        // If we get a good update, the error should clear even if we haven't had a chance to raise it to the user
+        metadata.update(emptyMetadataResponse(), time.milliseconds());
+        assertNull(metadata.getAndClearMetadataException());
     }
 
-    private Thread asyncFetch(final String topic, final long maxWaitMs) {
-        Thread thread = new Thread() {
-            public void run() {
-                try {
-                    while (metadata.fetch().partitionsForTopic(topic).isEmpty())
-                        metadata.awaitUpdate(metadata.requestUpdate(), maxWaitMs);
-                } catch (Exception e) {
-                    backgroundError.set(e);
-                }
-            }
-        };
-        thread.start();
-        return thread;
+    @Test
+    public void testTopicAuthorizationError() {
+        Time time = new MockTime();
+
+        String invalidTopic = "foo";
+        MetadataResponse unauthorizedTopicResponse = TestUtils.metadataUpdateWith("clusterId", 1,
+                Collections.singletonMap(invalidTopic, Errors.TOPIC_AUTHORIZATION_FAILED), Collections.emptyMap());
+        metadata.update(unauthorizedTopicResponse, time.milliseconds());
+
+        TopicAuthorizationException e = assertThrows(TopicAuthorizationException.class, () -> metadata.maybeThrowException());
+        assertEquals(Collections.singleton(invalidTopic), e.unauthorizedTopics());
+        // We clear the exception once it has been raised to the user
+        assertNull(metadata.getAndClearMetadataException());
+
+        // Reset the unauthorized topic error
+        metadata.update(unauthorizedTopicResponse, time.milliseconds());
+
+        // If we get a good update, the error should clear even if we haven't had a chance to raise it to the user
+        metadata.update(emptyMetadataResponse(), time.milliseconds());
+        assertNull(metadata.getAndClearMetadataException());
     }
+
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/MockClient.java b/clients/src/test/java/org/apache/kafka/clients/MockClient.java
index 0dd42f8..7a1febd 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MockClient.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MockClient.java
@@ -23,6 +23,7 @@ import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.AbstractResponse;
+import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.test.TestCondition;
@@ -661,13 +662,26 @@ public class MockClient implements KafkaClient {
             update(time, lastUpdate);
         }
 
+        private void maybeCheckExpectedTopics(MetadataUpdate update, MetadataRequest.Builder builder) {
+            if (update.expectMatchRefreshTopics) {
+                if (builder.topics() == null)
+                    throw new IllegalStateException("The metadata topics does not match expectation. "
+                            + "Expected topics: " + update.topics()
+                            + ", asked topics: ALL");
+
+                Set<String> requestedTopics = new HashSet<>(builder.topics());
+                if (!requestedTopics.equals(update.topics())) {
+                    throw new IllegalStateException("The metadata topics does not match expectation. "
+                            + "Expected topics: " + update.topics()
+                            + ", asked topics: " + requestedTopics);
+                }
+            }
+        }
+
         @Override
         public void update(Time time, MetadataUpdate update) {
-            if (update.expectMatchRefreshTopics && !metadata.topics().equals(update.topics())) {
-                throw new IllegalStateException("The metadata topics does not match expectation. "
-                        + "Expected topics: " + update.topics()
-                        + ", asked topics: " + metadata.topics());
-            }
+            MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder();
+            maybeCheckExpectedTopics(update, builder);
             metadata.update(update.updateResponse, time.milliseconds());
             this.lastUpdate = update;
         }
diff --git a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
index e098236..b40f690 100644
--- a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
@@ -26,7 +26,6 @@ import org.apache.kafka.common.protocol.types.Struct;
 import org.apache.kafka.common.record.MemoryRecords;
 import org.apache.kafka.common.requests.ApiVersionsResponse;
 import org.apache.kafka.common.requests.MetadataRequest;
-import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.requests.ProduceRequest;
 import org.apache.kafka.common.requests.ResponseHeader;
 import org.apache.kafka.common.utils.LogContext;
@@ -55,10 +54,7 @@ public class NetworkClientTest {
     protected final int defaultRequestTimeoutMs = 1000;
     protected final MockTime time = new MockTime();
     protected final MockSelector selector = new MockSelector(time);
-    protected final Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
-    protected final MetadataResponse initialMetadataResponse = TestUtils.metadataUpdateWith(1,
-            Collections.singletonMap("test", 1));
-    protected final Node node = initialMetadataResponse.brokers().iterator().next();
+    protected final Node node = TestUtils.singletonCluster().nodes().iterator().next();
     protected final long reconnectBackoffMsTest = 10 * 1000;
     protected final long reconnectBackoffMaxMsTest = 10 * 10000;
 
@@ -68,19 +64,19 @@ public class NetworkClientTest {
     private final NetworkClient clientWithNoVersionDiscovery = createNetworkClientWithNoVersionDiscovery();
 
     private NetworkClient createNetworkClient(long reconnectBackoffMaxMs) {
-        return new NetworkClient(selector, metadata, "mock", Integer.MAX_VALUE,
+        return new NetworkClient(selector, new ManualMetadataUpdater(Collections.singletonList(node)), "mock", Integer.MAX_VALUE,
                 reconnectBackoffMsTest, reconnectBackoffMaxMs, 64 * 1024, 64 * 1024,
                 defaultRequestTimeoutMs, ClientDnsLookup.DEFAULT, time, true, new ApiVersions(), new LogContext());
     }
 
     private NetworkClient createNetworkClientWithStaticNodes() {
-        return new NetworkClient(selector, new ManualMetadataUpdater(Arrays.asList(node)),
+        return new NetworkClient(selector, new ManualMetadataUpdater(Collections.singletonList(node)),
                 "mock-static", Integer.MAX_VALUE, 0, 0, 64 * 1024, 64 * 1024, defaultRequestTimeoutMs,
                 ClientDnsLookup.DEFAULT, time, true, new ApiVersions(), new LogContext());
     }
 
     private NetworkClient createNetworkClientWithNoVersionDiscovery() {
-        return new NetworkClient(selector, metadata, "mock", Integer.MAX_VALUE,
+        return new NetworkClient(selector, new ManualMetadataUpdater(Collections.singletonList(node)), "mock", Integer.MAX_VALUE,
                 reconnectBackoffMsTest, reconnectBackoffMaxMsTest,
                 64 * 1024, 64 * 1024, defaultRequestTimeoutMs,
                 ClientDnsLookup.DEFAULT, time, false, new ApiVersions(), new LogContext());
@@ -89,7 +85,6 @@ public class NetworkClientTest {
     @Before
     public void setup() {
         selector.reset();
-        metadata.update(initialMetadataResponse, time.milliseconds());
     }
 
     @Test(expected = IllegalStateException.class)
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index 61702a1..a5161b4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -18,10 +18,10 @@ package org.apache.kafka.clients.consumer;
 
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.KafkaClient;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerCoordinator;
 import org.apache.kafka.clients.consumer.internals.ConsumerInterceptors;
+import org.apache.kafka.clients.consumer.internals.ConsumerMetadata;
 import org.apache.kafka.clients.consumer.internals.ConsumerMetrics;
 import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerProtocol;
@@ -40,6 +40,7 @@ import org.apache.kafka.common.errors.InvalidGroupIdException;
 import org.apache.kafka.common.errors.InvalidTopicException;
 import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.message.LeaveGroupResponseData;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.network.Selectable;
@@ -235,7 +236,7 @@ public class KafkaConsumerTest {
     @Test(expected = IllegalArgumentException.class)
     public void testSubscriptionOnNullTopic() {
         try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            consumer.subscribe(singletonList((String) null));
+            consumer.subscribe(singletonList(null));
         }
     }
 
@@ -371,7 +372,8 @@ public class KafkaConsumerTest {
     @Test
     public void verifyHeartbeatSent() throws Exception {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -379,7 +381,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -404,14 +406,15 @@ public class KafkaConsumerTest {
     @Test
     public void verifyHeartbeatSentWhenFetchedDataReady() throws Exception {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -437,7 +440,8 @@ public class KafkaConsumerTest {
     @Test
     public void verifyPollTimesOutDuringMetadataUpdate() {
         final Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        final ConsumerMetadata metadata = createMetadata(subscription);
         final MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -445,7 +449,7 @@ public class KafkaConsumerTest {
 
         final PartitionAssignor assignor = new RoundRobinAssignor();
 
-        final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        final KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -460,7 +464,8 @@ public class KafkaConsumerTest {
     @Test
     public void verifyDeprecatedPollDoesNotTimeOutDuringMetadataUpdate() {
         final Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        final ConsumerMetadata metadata = createMetadata(subscription);
         final MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -468,7 +473,7 @@ public class KafkaConsumerTest {
 
         final PartitionAssignor assignor = new RoundRobinAssignor();
 
-        final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        final KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -484,13 +489,14 @@ public class KafkaConsumerTest {
     @Test
     public void verifyNoCoordinatorLookupForManualAssignmentWithSeek() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.assign(singleton(tp0));
         consumer.seekToBeginning(singleton(tp0));
 
@@ -511,11 +517,12 @@ public class KafkaConsumerTest {
         // a reset on another partition.
 
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
         initMetadata(client, Collections.singletonMap(topic, 2));
 
-        KafkaConsumer<String, String> consumer = newConsumerNoAutoCommit(time, client, metadata);
+        KafkaConsumer<String, String> consumer = newConsumerNoAutoCommit(time, client, subscription, metadata);
         consumer.assign(Arrays.asList(tp0, tp1));
         consumer.seekToEnd(singleton(tp0));
         consumer.seekToBeginning(singleton(tp1));
@@ -555,7 +562,8 @@ public class KafkaConsumerTest {
     @Test(expected = NoOffsetForPartitionException.class)
     public void testMissingOffsetNoResetPolicy() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -563,8 +571,8 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                OffsetResetStrategy.NONE, true, groupId);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+                true, groupId);
         consumer.assign(singletonList(tp0));
 
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -578,7 +586,8 @@ public class KafkaConsumerTest {
     @Test
     public void testResetToCommittedOffset() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -586,8 +595,8 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                OffsetResetStrategy.NONE, true, groupId);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+                true, groupId);
         consumer.assign(singletonList(tp0));
 
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -602,7 +611,8 @@ public class KafkaConsumerTest {
     @Test
     public void testResetUsingAutoResetPolicy() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.LATEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -610,8 +620,8 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                OffsetResetStrategy.LATEST, true, groupId);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+                true, groupId);
         consumer.assign(singletonList(tp0));
 
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -631,7 +641,8 @@ public class KafkaConsumerTest {
         long offset2 = 20000;
 
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 2));
@@ -639,7 +650,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.assign(singletonList(tp0));
 
         // lookup coordinator
@@ -668,7 +679,8 @@ public class KafkaConsumerTest {
     @Test
     public void testAutoCommitSentBeforePositionUpdate() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -676,7 +688,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -704,7 +716,8 @@ public class KafkaConsumerTest {
     public void testRegexSubscription() {
         String unmatchedTopic = "unmatched";
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> partitionCounts = new HashMap<>();
@@ -715,7 +728,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null);
 
         consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer));
@@ -737,7 +750,8 @@ public class KafkaConsumerTest {
         TopicPartition otherTopicPartition = new TopicPartition(otherTopic, 0);
 
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> partitionCounts = new HashMap<>();
@@ -746,7 +760,7 @@ public class KafkaConsumerTest {
         initMetadata(client, partitionCounts);
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
 
         Node coordinator = prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null);
         consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer));
@@ -769,7 +783,8 @@ public class KafkaConsumerTest {
     @Test
     public void testWakeupWithFetchDataAvailable() throws Exception {
         final Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -777,7 +792,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -818,7 +833,8 @@ public class KafkaConsumerTest {
     @Test
     public void testPollThrowsInterruptExceptionIfInterrupted() {
         final Time time = new MockTime();
-        final Metadata metadata = createMetadata();
+        final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        final ConsumerMetadata metadata = createMetadata(subscription);
         final MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -826,7 +842,7 @@ public class KafkaConsumerTest {
 
         final PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -847,7 +863,8 @@ public class KafkaConsumerTest {
     @Test
     public void fetchResponseWithUnexpectedPartitionIsIgnored() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -855,7 +872,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singletonList(topic), getConsumerRebalanceListener(consumer));
 
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
@@ -882,7 +899,8 @@ public class KafkaConsumerTest {
     @Test
     public void testSubscriptionChangesWithAutoCommitEnabled() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> tpCounts = new HashMap<>();
@@ -894,7 +912,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
 
         // initial subscription
         consumer.subscribe(Arrays.asList(topic, topic2), getConsumerRebalanceListener(consumer));
@@ -996,7 +1014,8 @@ public class KafkaConsumerTest {
     @Test
     public void testSubscriptionChangesWithAutoCommitDisabled() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> tpCounts = new HashMap<>();
@@ -1007,7 +1026,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
 
         // initial subscription
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
@@ -1057,7 +1076,8 @@ public class KafkaConsumerTest {
     @Test
     public void testManualAssignmentChangeWithAutoCommitEnabled() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> tpCounts = new HashMap<>();
@@ -1068,7 +1088,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
 
         // lookup coordinator
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -1112,7 +1132,8 @@ public class KafkaConsumerTest {
     @Test
     public void testManualAssignmentChangeWithAutoCommitDisabled() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         Map<String, Integer> tpCounts = new HashMap<>();
@@ -1123,7 +1144,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
 
         // lookup coordinator
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -1168,7 +1189,8 @@ public class KafkaConsumerTest {
     @Test
     public void testOffsetOfPausedPartitions() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 2));
@@ -1176,7 +1198,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
 
         // lookup coordinator
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
@@ -1362,7 +1384,8 @@ public class KafkaConsumerTest {
     @Test
     public void shouldAttemptToRejoinGroupAfterSyncGroupFailed() throws Exception {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -1370,7 +1393,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         client.prepareResponseFrom(new FindCoordinatorResponse(Errors.NONE, node), node);
         Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port());
@@ -1430,7 +1453,8 @@ public class KafkaConsumerTest {
                                    long waitMs,
                                    boolean interrupt) throws Exception {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -1438,7 +1462,7 @@ public class KafkaConsumerTest {
 
         PartitionAssignor assignor = new RoundRobinAssignor();
 
-        final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, false);
+        final KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -1521,38 +1545,38 @@ public class KafkaConsumerTest {
 
     @Test(expected = AuthenticationException.class)
     public void testPartitionsForAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.partitionsFor("some other topic");
     }
 
     @Test(expected = AuthenticationException.class)
     public void testBeginningOffsetsAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.beginningOffsets(Collections.singleton(tp0));
     }
 
     @Test(expected = AuthenticationException.class)
     public void testEndOffsetsAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.endOffsets(Collections.singleton(tp0));
     }
 
     @Test(expected = AuthenticationException.class)
     public void testPollAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.subscribe(singleton(topic));
         consumer.poll(Duration.ZERO);
     }
 
     @Test(expected = AuthenticationException.class)
     public void testOffsetsForTimesAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.offsetsForTimes(singletonMap(tp0, 0L));
     }
 
     @Test(expected = AuthenticationException.class)
     public void testCommitSyncAuthenticationFailure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
         offsets.put(tp0, new OffsetAndMetadata(10L));
         consumer.commitSync(offsets);
@@ -1560,13 +1584,14 @@ public class KafkaConsumerTest {
 
     @Test(expected = AuthenticationException.class)
     public void testCommittedAuthenticationFaiure() {
-        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthentication();
+        final KafkaConsumer<String, String> consumer = consumerWithPendingAuthenticationError();
         consumer.committed(tp0);
     }
 
-    private KafkaConsumer<String, String> consumerWithPendingAuthentication() {
+    private KafkaConsumer<String, String> consumerWithPendingAuthenticationError() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, singletonMap(topic, 1));
@@ -1575,7 +1600,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RangeAssignor();
 
         client.createPendingAuthenticationError(node, 0);
-        return newConsumer(time, client, metadata, assignor, false);
+        return newConsumer(time, client, subscription, metadata, assignor, false);
     }
 
     private ConsumerRebalanceListener getConsumerRebalanceListener(final KafkaConsumer<String, String> consumer) {
@@ -1593,8 +1618,9 @@ public class KafkaConsumerTest {
         };
     }
 
-    private Metadata createMetadata() {
-        return new Metadata(0, Long.MAX_VALUE, true);
+    private ConsumerMetadata createMetadata(SubscriptionState subscription) {
+        return new ConsumerMetadata(0, Long.MAX_VALUE, false, subscription,
+                new LogContext(), new ClusterResourceListeners());
     }
 
     private Node prepareRebalance(MockClient client, Node node, final Set<String> subscribedTopics, PartitionAssignor assignor, List<TopicPartition> partitions, Node coordinator) {
@@ -1752,30 +1778,25 @@ public class KafkaConsumerTest {
 
     private KafkaConsumer<String, String> newConsumer(Time time,
                                                       KafkaClient client,
-                                                      Metadata metadata,
+                                                      SubscriptionState subscription,
+                                                      ConsumerMetadata metadata,
                                                       PartitionAssignor assignor,
                                                       boolean autoCommitEnabled) {
-        return newConsumer(time, client, metadata, assignor, OffsetResetStrategy.EARLIEST, autoCommitEnabled, groupId);
+        return newConsumer(time, client, subscription, metadata, assignor, autoCommitEnabled, groupId);
     }
 
     private KafkaConsumer<String, String> newConsumerNoAutoCommit(Time time,
                                                                   KafkaClient client,
-                                                                  Metadata metadata) {
-        return newConsumer(time, client, metadata, new RangeAssignor(), OffsetResetStrategy.EARLIEST, false, groupId);
+                                                                  SubscriptionState subscription,
+                                                                  ConsumerMetadata metadata) {
+        return newConsumer(time, client, subscription, metadata, new RangeAssignor(), false, groupId);
     }
 
     private KafkaConsumer<String, String> newConsumer(Time time,
                                                       KafkaClient client,
-                                                      Metadata metadata,
-                                                      String groupId) {
-        return newConsumer(time, client, metadata, new RangeAssignor(), OffsetResetStrategy.LATEST, true, groupId);
-    }
-
-    private KafkaConsumer<String, String> newConsumer(Time time,
-                                                      KafkaClient client,
-                                                      Metadata metadata,
+                                                      SubscriptionState subscription,
+                                                      ConsumerMetadata metadata,
                                                       PartitionAssignor assignor,
-                                                      OffsetResetStrategy resetStrategy,
                                                       boolean autoCommitEnabled,
                                                       String groupId) {
         String clientId = "mock-consumer";
@@ -1783,7 +1804,6 @@ public class KafkaConsumerTest {
         long retryBackoffMs = 100;
         int requestTimeoutMs = 30000;
         int defaultApiTimeoutMs = 30000;
-        boolean excludeInternalTopics = true;
         int minBytes = 1;
         int maxBytes = Integer.MAX_VALUE;
         int maxWaitMs = 500;
@@ -1802,7 +1822,6 @@ public class KafkaConsumerTest {
         ConsumerMetrics metricsRegistry = new ConsumerMetrics(metricGroupPrefix);
 
         LogContext loggerFactory = new LogContext();
-        SubscriptionState subscriptions = new SubscriptionState(loggerFactory, resetStrategy);
         ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(loggerFactory, client, metadata, time,
                 retryBackoffMs, requestTimeoutMs, heartbeatIntervalMs);
 
@@ -1816,7 +1835,7 @@ public class KafkaConsumerTest {
                 heartbeat,
                 assignors,
                 metadata,
-                subscriptions,
+                subscription,
                 metrics,
                 metricGroupPrefix,
                 time,
@@ -1824,7 +1843,6 @@ public class KafkaConsumerTest {
                 autoCommitEnabled,
                 autoCommitIntervalMs,
                 interceptors,
-                excludeInternalTopics,
                 true);
 
         Fetcher<String, String> fetcher = new Fetcher<>(
@@ -1839,7 +1857,7 @@ public class KafkaConsumerTest {
                 keyDeserializer,
                 valueDeserializer,
                 metadata,
-                subscriptions,
+                subscription,
                 metrics,
                 metricsRegistry.fetcherMetrics,
                 time,
@@ -1858,7 +1876,7 @@ public class KafkaConsumerTest {
                 time,
                 consumerClient,
                 metrics,
-                subscriptions,
+                subscription,
                 metadata,
                 retryBackoffMs,
                 requestTimeoutMs,
@@ -1889,7 +1907,8 @@ public class KafkaConsumerTest {
     @Test(expected = InvalidTopicException.class)
     public void testSubscriptionOnInvalidTopic() {
         Time time = new MockTime();
-        Metadata metadata = createMetadata();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
@@ -1908,7 +1927,7 @@ public class KafkaConsumerTest {
                 topicMetadata);
         client.prepareMetadataUpdate(updateResponse);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor, true);
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true);
         consumer.subscribe(singleton(invalidTopicName), getConsumerRebalanceListener(consumer));
 
         consumer.poll(Duration.ZERO);
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index 9b4bdb9..9a68db7 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -16,11 +16,12 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.AbstractRequest;
@@ -86,11 +87,14 @@ public class AbstractCoordinatorTest {
     }
 
     private void setupCoordinator(int retryBackoffMs, int rebalanceTimeoutMs) {
+        LogContext logContext = new LogContext();
         this.mockTime = new MockTime();
-        Metadata metadata = new Metadata(retryBackoffMs, 60 * 60 * 1000L, true);
+        ConsumerMetadata metadata = new ConsumerMetadata(retryBackoffMs, 60 * 60 * 1000L,
+                false, new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST),
+                logContext, new ClusterResourceListeners());
 
         this.mockClient = new MockClient(mockTime, metadata);
-        this.consumerClient = new ConsumerNetworkClient(new LogContext(), mockClient, metadata, mockTime,
+        this.consumerClient = new ConsumerNetworkClient(logContext, mockClient, metadata, mockTime,
                 retryBackoffMs, REQUEST_TIMEOUT_MS, HEARTBEAT_INTERVAL_MS);
         Metrics metrics = new Metrics();
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index 290b428..885b357 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -17,10 +17,8 @@
 package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.ClientResponse;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.consumer.CommitFailedException;
-import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetCommitCallback;
@@ -39,6 +37,7 @@ import org.apache.kafka.common.errors.DisconnectException;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.OffsetMetadataTooLarge;
 import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.internals.Topic;
 import org.apache.kafka.common.message.LeaveGroupResponseData;
 import org.apache.kafka.common.metrics.Metrics;
@@ -124,7 +123,7 @@ public class ConsumerCoordinatorTest {
     });
     private Node node = metadataResponse.brokers().iterator().next();
     private SubscriptionState subscriptions;
-    private Metadata metadata;
+    private ConsumerMetadata metadata;
     private Metrics metrics;
     private ConsumerNetworkClient consumerClient;
     private MockRebalanceListener rebalanceListener;
@@ -135,7 +134,8 @@ public class ConsumerCoordinatorTest {
     public void setup() {
         LogContext logContext = new LogContext();
         this.subscriptions = new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST);
-        this.metadata = new Metadata(0, Long.MAX_VALUE, true);
+        this.metadata = new ConsumerMetadata(0, Long.MAX_VALUE, false,
+                subscriptions, logContext, new ClusterResourceListeners());
         this.client = new MockClient(time, metadata);
         this.client.updateMetadata(metadataResponse);
         this.consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, 100,
@@ -145,7 +145,7 @@ public class ConsumerCoordinatorTest {
         this.mockOffsetCommitCallback = new MockCommitCallback();
         this.partitionAssignor.clear();
 
-        this.coordinator = buildCoordinator(metrics, assignors, ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, false, true);
+        this.coordinator = buildCoordinator(metrics, assignors, false, true);
     }
 
     @After
@@ -367,7 +367,6 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         // ensure metadata is up-to-date for leader
-        metadata.setTopics(singletonList(topic1));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -385,7 +384,6 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         // ensure metadata is up-to-date for leader
-        metadata.setTopics(singletonList(topic1));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -422,10 +420,6 @@ public class ConsumerCoordinatorTest {
 
         subscriptions.subscribe(singleton(topic2), rebalanceListener);
 
-        // ensure metadata is up-to-date for leader
-        metadata.setTopics(Arrays.asList(topic1, topic2));
-        client.updateMetadata(metadataResponse);
-
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -484,10 +478,6 @@ public class ConsumerCoordinatorTest {
 
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
-        // ensure metadata is up-to-date for leader
-        metadata.setTopics(singletonList(topic1));
-        client.updateMetadata(metadataResponse);
-
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -516,7 +506,6 @@ public class ConsumerCoordinatorTest {
 
         // partially update the metadata with one topic first,
         // let the leader to refresh metadata during assignment
-        metadata.setTopics(singletonList(topic1));
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -556,8 +545,8 @@ public class ConsumerCoordinatorTest {
         final String consumerId = "leader";
 
         subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
-        metadata.needMetadataForAllTopics(true);
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
+        coordinator.maybeUpdateSubscriptionMetadata();
 
         assertEquals(singleton(topic1), subscriptions.subscription());
 
@@ -582,6 +571,7 @@ public class ConsumerCoordinatorTest {
                 return true;
             }
         }, syncGroupResponse(singletonList(t1p), Errors.NONE));
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         List<TopicPartition> newAssignment = Arrays.asList(t1p, t2p);
         Set<TopicPartition> newAssignmentSet = new HashSet<>(newAssignment);
@@ -621,6 +611,7 @@ public class ConsumerCoordinatorTest {
 
         subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
+        coordinator.maybeUpdateSubscriptionMetadata();
         assertEquals(singleton(topic1), subscriptions.subscription());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -663,7 +654,6 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
         // ensure metadata is up-to-date for leader
-        metadata.setTopics(singletonList(topic1));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -763,7 +753,6 @@ public class ConsumerCoordinatorTest {
 
         // partially update the metadata with one topic first,
         // let the leader to refresh metadata during assignment
-        metadata.setTopics(singletonList(topic1));
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -974,10 +963,8 @@ public class ConsumerCoordinatorTest {
         final String consumerId = "consumer";
 
         // ensure metadata is up-to-date for leader
-        metadata.setTopics(singletonList(topic1));
-        client.updateMetadata(metadataResponse);
-
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -995,6 +982,7 @@ public class ConsumerCoordinatorTest {
 
         // a new partition is added to the topic
         metadata.update(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 2)), time.milliseconds());
+        coordinator.maybeUpdateSubscriptionMetadata();
 
         // we should detect the change and ask for reassignment
         assertTrue(coordinator.rejoinNeededOrPending());
@@ -1011,7 +999,6 @@ public class ConsumerCoordinatorTest {
         List<String> topics = Arrays.asList(topic1, topic2);
 
         subscriptions.subscribe(new HashSet<>(topics), rebalanceListener);
-        metadata.setTopics(topics);
 
         // we only have metadata for one topic initially
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
@@ -1041,6 +1028,7 @@ public class ConsumerCoordinatorTest {
                 return false;
             }
         }, syncGroupResponse(Collections.singletonList(tp1), Errors.NONE));
+        coordinator.poll(time.timer(Long.MAX_VALUE));
 
         // the metadata update should trigger a second rebalance
         client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE));
@@ -1055,7 +1043,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testWakeupFromAssignmentCallback() {
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, false, true);
+                false, true);
 
         final String topic = "topic1";
         TopicPartition partition = new TopicPartition(topic, 0);
@@ -1073,7 +1061,6 @@ public class ConsumerCoordinatorTest {
         };
 
         subscriptions.subscribe(topics, rebalanceListener);
-        metadata.setTopics(topics);
 
         // we only have metadata for one topic initially
         client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
@@ -1118,23 +1105,17 @@ public class ConsumerCoordinatorTest {
         unavailableTopicTest(true, Collections.singleton("notmatching"));
     }
 
-    @Test
-    public void testAssignWithTopicUnavailable() {
-        unavailableTopicTest(true, Collections.emptySet());
-    }
-
     private void unavailableTopicTest(boolean patternSubscribe, Set<String> unavailableTopicsInLastMetadata) {
         final String consumerId = "consumer";
 
-        metadata.setTopics(singletonList(topic1));
-        client.prepareMetadataUpdate(TestUtils.metadataUpdateWith("kafka-cluster", 1,
-                Collections.singletonMap(topic1, Errors.UNKNOWN_TOPIC_OR_PARTITION), Collections.emptyMap()));
-
         if (patternSubscribe)
             subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener);
         else
             subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
+        client.prepareMetadataUpdate(TestUtils.metadataUpdateWith("kafka-cluster", 1,
+                Collections.singletonMap(topic1, Errors.UNKNOWN_TOPIC_OR_PARTITION), Collections.emptyMap()));
+
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -1167,21 +1148,34 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testExcludeInternalTopicsConfigOption() {
-        subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
-
-        client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(Topic.GROUP_METADATA_TOPIC_NAME, 1)));
-
-        assertFalse(subscriptions.subscription().contains(Topic.GROUP_METADATA_TOPIC_NAME));
+        testInternalTopicInclusion(false);
     }
 
     @Test
     public void testIncludeInternalTopicsConfigOption() {
-        coordinator = buildCoordinator(new Metrics(), assignors, false, false, true);
+        testInternalTopicInclusion(true);
+    }
+
+    private void testInternalTopicInclusion(boolean includeInternalTopics) {
+        metadata = new ConsumerMetadata(0, Long.MAX_VALUE, includeInternalTopics,
+                subscriptions, new LogContext(), new ClusterResourceListeners());
+        client = new MockClient(time, metadata);
+        coordinator = buildCoordinator(new Metrics(), assignors, false, true);
+
         subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
 
-        client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap(Topic.GROUP_METADATA_TOPIC_NAME, 2)));
+        Node node = new Node(0, "localhost", 9999);
+        MetadataResponse.PartitionMetadata partitionMetadata =
+                new MetadataResponse.PartitionMetadata(Errors.NONE, 0, node, Optional.empty(),
+                        singletonList(node), singletonList(node), singletonList(node));
+        MetadataResponse.TopicMetadata topicMetadata = new MetadataResponse.TopicMetadata(Errors.NONE,
+                Topic.GROUP_METADATA_TOPIC_NAME, true, singletonList(partitionMetadata));
+
+        client.updateMetadata(new MetadataResponse(singletonList(node), "clusterId", node.id(),
+                singletonList(topicMetadata)));
+        coordinator.maybeUpdateSubscriptionMetadata();
 
-        assertTrue(subscriptions.subscription().contains(Topic.GROUP_METADATA_TOPIC_NAME));
+        assertEquals(includeInternalTopics, subscriptions.subscription().contains(Topic.GROUP_METADATA_TOPIC_NAME));
     }
 
     @Test
@@ -1295,7 +1289,7 @@ public class ConsumerCoordinatorTest {
         final String consumerId = "consumer";
 
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
 
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
         joinAsFollowerAndReceiveAssignment(consumerId, coordinator, singletonList(t1p));
@@ -1311,7 +1305,7 @@ public class ConsumerCoordinatorTest {
     public void testAutoCommitRetryBackoff() {
         final String consumerId = "consumer";
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
         joinAsFollowerAndReceiveAssignment(consumerId, coordinator, singletonList(t1p));
 
@@ -1345,7 +1339,7 @@ public class ConsumerCoordinatorTest {
     public void testAutoCommitAwaitsInterval() {
         final String consumerId = "consumer";
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
         joinAsFollowerAndReceiveAssignment(consumerId, coordinator, singletonList(t1p));
 
@@ -1384,7 +1378,7 @@ public class ConsumerCoordinatorTest {
         final String consumerId = "consumer";
 
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
 
         subscriptions.subscribe(singleton(topic1), rebalanceListener);
 
@@ -1410,7 +1404,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitManualAssignment() {
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
 
         subscriptions.assignFromUser(singleton(t1p));
         subscriptions.seek(t1p, 100);
@@ -1427,7 +1421,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitManualAssignmentCoordinatorUnknown() {
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
 
         subscriptions.assignFromUser(singleton(t1p));
         subscriptions.seek(t1p, 100);
@@ -1863,7 +1857,7 @@ public class ConsumerCoordinatorTest {
 
         try (Metrics metrics = new Metrics(time)) {
             ConsumerCoordinator coordinator = buildCoordinator(metrics, Arrays.<PartitionAssignor>asList(roundRobin, range),
-                                                               ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, false, true);
+                    false, true);
             List<ProtocolMetadata> metadata = coordinator.metadata();
             assertEquals(2, metadata.size());
             assertEquals(roundRobin.name(), metadata.get(0).name());
@@ -1872,7 +1866,7 @@ public class ConsumerCoordinatorTest {
 
         try (Metrics metrics = new Metrics(time)) {
             ConsumerCoordinator coordinator = buildCoordinator(metrics, Arrays.<PartitionAssignor>asList(range, roundRobin),
-                                                               ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, false, true);
+                    false, true);
             List<ProtocolMetadata> metadata = coordinator.metadata();
             assertEquals(2, metadata.size());
             assertEquals(range.name(), metadata.get(0).name());
@@ -2034,7 +2028,7 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitAfterCoordinatorBackToService() {
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, true, true);
+                true, true);
         subscriptions.assignFromUser(Collections.singleton(t1p));
         subscriptions.seek(t1p, 100L);
 
@@ -2055,7 +2049,7 @@ public class ConsumerCoordinatorTest {
                                                                final boolean leaveGroup) {
         final String consumerId = "consumer";
         ConsumerCoordinator coordinator = buildCoordinator(new Metrics(), assignors,
-                ConsumerConfig.DEFAULT_EXCLUDE_INTERNAL_TOPICS, autoCommit, leaveGroup);
+                autoCommit, leaveGroup);
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         if (useGroupManagement) {
@@ -2146,7 +2140,6 @@ public class ConsumerCoordinatorTest {
 
     private ConsumerCoordinator buildCoordinator(final Metrics metrics,
                                                  final List<PartitionAssignor> assignors,
-                                                 final boolean excludeInternalTopics,
                                                  final boolean autoCommitEnabled,
                                                  final boolean leaveGroup) {
         return new ConsumerCoordinator(
@@ -2166,7 +2159,6 @@ public class ConsumerCoordinatorTest {
                 autoCommitEnabled,
                 autoCommitIntervalMs,
                 null,
-                excludeInternalTopics,
                 leaveGroup);
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
new file mode 100644
index 0000000..871ef30
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.MetadataRequest;
+import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+import static java.util.Collections.singleton;
+import static java.util.Collections.singletonList;
+import static java.util.Collections.singletonMap;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class ConsumerMetadataTest {
+
+    private final Node node = new Node(1, "localhost", 9092);
+    private final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+    private final Time time = new MockTime();
+
+    @Test
+    public void testPatternSubscriptionNoInternalTopics() {
+        testPatternSubscription(false);
+    }
+
+    @Test
+    public void testPatternSubscriptionIncludeInternalTopics() {
+        testPatternSubscription(true);
+    }
+
+    private void testPatternSubscription(boolean includeInternalTopics) {
+        subscription.subscribe(Pattern.compile("__.*"), new NoOpConsumerRebalanceListener());
+        ConsumerMetadata metadata = newConsumerMetadata(includeInternalTopics);
+
+        MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder();
+        assertTrue(builder.isAllTopics());
+
+        List<MetadataResponse.TopicMetadata> topics = new ArrayList<>();
+        topics.add(topicMetadata("__consumer_offsets", true));
+        topics.add(topicMetadata("__matching_topic", false));
+        topics.add(topicMetadata("non_matching_topic", false));
+
+        MetadataResponse response = new MetadataResponse(singletonList(node), "clusterId", node.id(), topics);
+        metadata.update(response, time.milliseconds());
+
+        if (includeInternalTopics)
+            assertEquals(Utils.mkSet("__matching_topic", "__consumer_offsets"), metadata.fetch().topics());
+        else
+            assertEquals(Collections.singleton("__matching_topic"), metadata.fetch().topics());
+    }
+
+    @Test
+    public void testUserAssignment() {
+        subscription.assignFromUser(Utils.mkSet(
+                new TopicPartition("foo", 0),
+                new TopicPartition("bar", 0),
+                new TopicPartition("__consumer_offsets", 0)));
+        testBasicSubscription(Utils.mkSet("foo", "bar"), Utils.mkSet("__consumer_offsets"));
+    }
+
+    @Test
+    public void testNormalSubscription() {
+        subscription.subscribe(Utils.mkSet("foo", "bar", "__consumer_offsets"), new NoOpConsumerRebalanceListener());
+        subscription.groupSubscribe(Utils.mkSet("baz"));
+        testBasicSubscription(Utils.mkSet("foo", "bar", "baz"), Utils.mkSet("__consumer_offsets"));
+    }
+
+    @Test
+    public void testTransientTopics() {
+        subscription.subscribe(singleton("foo"), new NoOpConsumerRebalanceListener());
+        ConsumerMetadata metadata = newConsumerMetadata(false);
+        metadata.update(TestUtils.metadataUpdateWith(1, singletonMap("foo", 1)), time.milliseconds());
+        assertFalse(metadata.updateRequested());
+
+        metadata.addTransientTopics(singleton("foo"));
+        assertFalse(metadata.updateRequested());
+
+        metadata.addTransientTopics(singleton("bar"));
+        assertTrue(metadata.updateRequested());
+
+        Map<String, Integer> topicPartitionCounts = new HashMap<>();
+        topicPartitionCounts.put("foo", 1);
+        topicPartitionCounts.put("bar", 1);
+        metadata.update(TestUtils.metadataUpdateWith(1, topicPartitionCounts), time.milliseconds());
+        assertFalse(metadata.updateRequested());
+
+        assertEquals(Utils.mkSet("foo", "bar"), new HashSet<>(metadata.fetch().topics()));
+
+        metadata.clearTransientTopics();
+        metadata.update(TestUtils.metadataUpdateWith(1, topicPartitionCounts), time.milliseconds());
+        assertEquals(singleton("foo"), new HashSet<>(metadata.fetch().topics()));
+    }
+
+    private void testBasicSubscription(Set<String> expectedTopics, Set<String> expectedInternalTopics) {
+        Set<String> allTopics = new HashSet<>();
+        allTopics.addAll(expectedTopics);
+        allTopics.addAll(expectedInternalTopics);
+
+        ConsumerMetadata metadata = newConsumerMetadata(false);
+
+        MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder();
+        assertEquals(allTopics, new HashSet<>(builder.topics()));
+
+        List<MetadataResponse.TopicMetadata> topics = new ArrayList<>();
+        for (String expectedTopic : expectedTopics)
+            topics.add(topicMetadata(expectedTopic, false));
+        for (String expectedInternalTopic : expectedInternalTopics)
+            topics.add(topicMetadata(expectedInternalTopic, true));
+
+        MetadataResponse response = new MetadataResponse(singletonList(node), "clusterId", node.id(), topics);
+        metadata.update(response, time.milliseconds());
+
+        assertEquals(allTopics, metadata.fetch().topics());
+    }
+
+    private MetadataResponse.TopicMetadata topicMetadata(String topic, boolean isInternal) {
+        MetadataResponse.PartitionMetadata partitionMetadata = new MetadataResponse.PartitionMetadata(Errors.NONE,
+                0, node, Optional.of(5), singletonList(node), singletonList(node), singletonList(node));
+        return new MetadataResponse.TopicMetadata(Errors.NONE, topic, isInternal, singletonList(partitionMetadata));
+    }
+
+    private ConsumerMetadata newConsumerMetadata(boolean includeInternalTopics) {
+        long refreshBackoffMs = 50;
+        long expireMs = 50000;
+        return new ConsumerMetadata(refreshBackoffMs, expireMs, includeInternalTopics, subscription, new LogContext(),
+                new ClusterResourceListeners());
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
index 45b420e..14c2cba 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
@@ -24,20 +24,27 @@ import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.DisconnectException;
+import org.apache.kafka.common.errors.InvalidTopicException;
 import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.HeartbeatRequest;
 import org.apache.kafka.common.requests.HeartbeatResponse;
+import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.test.TestUtils;
 import org.junit.Test;
 
+import java.time.Duration;
+import java.util.Collections;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -52,7 +59,8 @@ public class ConsumerNetworkClientTest {
     private MockTime time = new MockTime(1);
     private Cluster cluster = TestUtils.singletonCluster(topicName, 1);
     private Node node = cluster.nodes().get(0);
-    private Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+    private Metadata metadata = new Metadata(100, 50000, new LogContext(),
+            new ClusterResourceListeners());
     private MockClient client = new MockClient(time, metadata);
     private ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(),
             client, metadata, time, 100, 1000, Integer.MAX_VALUE);
@@ -75,7 +83,7 @@ public class ConsumerNetworkClientTest {
     }
 
     @Test
-    public void sendWithinBlackoutPeriodAfterAuthenticationFailure() throws InterruptedException {
+    public void sendWithinBlackoutPeriodAfterAuthenticationFailure() {
         client.authenticationFailed(node, 300);
         client.prepareResponse(heartbeatResponse(Errors.NONE));
         final RequestFuture<ClientResponse> future = consumerClient.send(node, heartbeat());
@@ -223,6 +231,34 @@ public class ConsumerNetworkClientTest {
     }
 
     @Test
+    public void testAuthenticationExceptionPropagatedFromMetadata() {
+        metadata.failedUpdate(time.milliseconds(), new AuthenticationException("Authentication failed"));
+        try {
+            consumerClient.poll(time.timer(Duration.ZERO));
+            fail("Expected authentication error thrown");
+        } catch (AuthenticationException e) {
+            // After the exception is raised, it should have been cleared
+            assertNull(metadata.getAndClearAuthenticationException());
+        }
+    }
+
+    @Test(expected = InvalidTopicException.class)
+    public void testInvalidTopicExceptionPropagatedFromMetadata() {
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("clusterId", 1,
+                Collections.singletonMap("topic", Errors.INVALID_TOPIC_EXCEPTION), Collections.emptyMap());
+        metadata.update(metadataResponse, time.milliseconds());
+        consumerClient.poll(time.timer(Duration.ZERO));
+    }
+
+    @Test(expected = TopicAuthorizationException.class)
+    public void testTopicAuthorizationExceptionPropagatedFromMetadata() {
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("clusterId", 1,
+                Collections.singletonMap("topic", Errors.TOPIC_AUTHORIZATION_FAILED), Collections.emptyMap());
+        metadata.update(metadataResponse, time.milliseconds());
+        consumerClient.poll(time.timer(Duration.ZERO));
+    }
+
+    @Test
     public void testFutureCompletionOutsidePoll() throws Exception {
         // Tests the scenario in which the request that is being awaited in one thread
         // is received and completed in another thread.
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index b08df1d..3fe7ca0 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -21,7 +21,6 @@ import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientUtils;
 import org.apache.kafka.clients.FetchSessionHandler;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
@@ -44,6 +43,7 @@ import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
@@ -119,13 +119,16 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 @SuppressWarnings("deprecation")
 public class FetcherTest {
+    private static final double EPSILON = 0.0001;
+
     private ConsumerRebalanceListener listener = new NoOpConsumerRebalanceListener();
     private String topicName = "test";
     private String groupId = "test-group";
@@ -134,6 +137,8 @@ public class FetcherTest {
     private TopicPartition tp1 = new TopicPartition(topicName, 1);
     private TopicPartition tp2 = new TopicPartition(topicName, 2);
     private TopicPartition tp3 = new TopicPartition(topicName, 3);
+    private MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap(topicName, 4));
+
     private int minBytes = 1;
     private int maxBytes = Integer.MAX_VALUE;
     private int maxWaitMs = 0;
@@ -141,34 +146,22 @@ public class FetcherTest {
     private long retryBackoffMs = 100;
     private long requestTimeoutMs = 30000;
     private MockTime time = new MockTime(1);
-    private MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap(topicName, 4));
-    private Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
-    private MockClient client = new MockClient(time, metadata);
-    private Node node;
-    private Metrics metrics = new Metrics(time);
-    private FetcherMetricsRegistry metricsRegistry = new FetcherMetricsRegistry("consumer" + groupId);
-
-    private SubscriptionState subscriptions = new SubscriptionState(
-            new LogContext(), OffsetResetStrategy.EARLIEST);
-    private SubscriptionState subscriptionsNoAutoReset = new SubscriptionState(
-            new LogContext(), OffsetResetStrategy.NONE);
-    private static final double EPSILON = 0.0001;
-    private ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(),
-            client, metadata, time, 100, 1000, Integer.MAX_VALUE);
+    private SubscriptionState subscriptions;
+    private ConsumerMetadata metadata;
+    private FetcherMetricsRegistry metricsRegistry;
+    private MockClient client;
+    private Metrics metrics;
+    private ConsumerNetworkClient consumerClient;
+    private Fetcher<?, ?> fetcher;
 
     private MemoryRecords records;
     private MemoryRecords nextRecords;
     private MemoryRecords emptyRecords;
     private MemoryRecords partialRecords;
-    private Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, metrics);
-    private Metrics fetcherMetrics = new Metrics(time);
-    private Fetcher<byte[], byte[]> fetcherNoAutoReset = createFetcher(subscriptionsNoAutoReset, fetcherMetrics);
     private ExecutorService executorService;
 
     @Before
     public void setup() {
-        client.updateMetadata(initialUpdateResponse);
-        node = metadata.fetch().nodes().get(0);
         records = buildRecords(1L, 3, 1);
         nextRecords = buildRecords(4L, 2, 4);
         emptyRecords = buildRecords(0L, 0, 0);
@@ -176,12 +169,17 @@ public class FetcherTest {
         partialRecords.buffer().putInt(Records.SIZE_OFFSET, 10000);
     }
 
+    private void assignFromUser(Set<TopicPartition> partitions) {
+        subscriptions.assignFromUser(partitions);
+        client.updateMetadata(initialUpdateResponse);
+    }
+
     @After
     public void teardown() throws Exception {
-        this.metrics.close();
-        this.fetcherMetrics.close();
-        this.fetcher.close();
-        this.fetcherMetrics.close();
+        if (metrics != null)
+            this.metrics.close();
+        if (fetcher != null)
+            this.fetcher.close();
         if (executorService != null) {
             executorService.shutdownNow();
             assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS));
@@ -190,7 +188,9 @@ public class FetcherTest {
 
     @Test
     public void testFetchNormal() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -201,7 +201,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
 
         List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(tp0);
@@ -216,7 +216,9 @@ public class FetcherTest {
 
     @Test
     public void testMissingLeaderEpochInRecords() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         ByteBuffer buffer = ByteBuffer.allocate(1024);
@@ -234,7 +236,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
         assertEquals(2, partitionRecords.get(tp0).size());
 
@@ -245,7 +247,9 @@ public class FetcherTest {
 
     @Test
     public void testLeaderEpochInConsumerRecord() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         Integer partitionLeaderEpoch = 1;
@@ -283,7 +287,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
         assertEquals(6, partitionRecords.get(tp0).size());
 
@@ -295,7 +299,9 @@ public class FetcherTest {
 
     @Test
     public void testClearBufferedDataForTopicPartitions() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -314,8 +320,12 @@ public class FetcherTest {
 
     @Test
     public void testFetchSkipsBlackedOutNodes() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
+        client.updateMetadata(initialUpdateResponse);
+        Node node = initialUpdateResponse.brokers().iterator().next();
 
         client.blackout(node, 500);
         assertEquals(0, fetcher.sendFetches());
@@ -326,7 +336,9 @@ public class FetcherTest {
 
     @Test
     public void testFetcherIgnoresControlRecords() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -353,7 +365,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
 
         List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(tp0);
@@ -366,7 +378,9 @@ public class FetcherTest {
 
     @Test
     public void testFetchError() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -376,7 +390,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertFalse(partitionRecords.containsKey(tp0));
     }
 
@@ -408,9 +422,9 @@ public class FetcherTest {
             }
         };
 
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time), deserializer, deserializer);
+        buildFetcher(deserializer, deserializer);
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
@@ -431,6 +445,9 @@ public class FetcherTest {
 
     @Test
     public void testParseCorruptedRecord() throws Exception {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
+
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buffer));
 
@@ -470,7 +487,6 @@ public class FetcherTest {
 
         buffer.flip();
 
-        subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -516,7 +532,8 @@ public class FetcherTest {
         client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
-        List<ConsumerRecord<byte[], byte[]>> records = fetcher.fetchedRecords().get(tp0);
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        List<ConsumerRecord<byte[], byte[]>> records = recordsByPartition.get(tp0);
         assertEquals(1, records.size());
         assertEquals(toOffset, records.get(0).offset());
         assertEquals(toOffset + 1, subscriptions.position(tp0).longValue());
@@ -524,6 +541,8 @@ public class FetcherTest {
 
     @Test
     public void testInvalidDefaultRecordBatch() {
+        buildFetcher();
+
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         ByteBufferOutputStream out = new ByteBufferOutputStream(buffer);
 
@@ -541,7 +560,7 @@ public class FetcherTest {
         buffer.put("beef".getBytes());
         buffer.position(0);
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -561,7 +580,8 @@ public class FetcherTest {
     }
 
     @Test
-    public void testParseInvalidRecordBatch() throws Exception {
+    public void testParseInvalidRecordBatch() {
+        buildFetcher();
         MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
                 CompressionType.NONE, TimestampType.CREATE_TIME,
                 new SimpleRecord(1L, "a".getBytes(), "1".getBytes()),
@@ -572,7 +592,7 @@ public class FetcherTest {
         // flip some bits to fail the crc
         buffer.putInt(32, buffer.get(32) ^ 87238423);
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
@@ -590,7 +610,7 @@ public class FetcherTest {
 
     @Test
     public void testHeaders() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time));
+        buildFetcher();
 
         MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 1L);
         builder.append(0L, "key".getBytes(), "value-1".getBytes());
@@ -607,14 +627,15 @@ public class FetcherTest {
         MemoryRecords memoryRecords = builder.build();
 
         List<ConsumerRecord<byte[], byte[]>> records;
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        records = fetcher.fetchedRecords().get(tp0);
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        records = recordsByPartition.get(tp0);
 
         assertEquals(3, records.size());
 
@@ -634,10 +655,10 @@ public class FetcherTest {
 
     @Test
     public void testFetchMaxPollRecords() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time), 2);
+        buildFetcher(2);
 
         List<ConsumerRecord<byte[], byte[]>> records;
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
         client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
@@ -645,7 +666,8 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        records = fetcher.fetchedRecords().get(tp0);
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
         assertEquals(3L, subscriptions.position(tp0).longValue());
         assertEquals(1, records.get(0).offset());
@@ -653,14 +675,16 @@ public class FetcherTest {
 
         assertEquals(0, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        records = fetcher.fetchedRecords().get(tp0);
+        recordsByPartition = fetchedRecords();
+        records = recordsByPartition.get(tp0);
         assertEquals(1, records.size());
         assertEquals(4L, subscriptions.position(tp0).longValue());
         assertEquals(3, records.get(0).offset());
 
         assertTrue(fetcher.sendFetches() > 0);
         consumerClient.poll(time.timer(0));
-        records = fetcher.fetchedRecords().get(tp0);
+        recordsByPartition = fetchedRecords();
+        records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
         assertEquals(6L, subscriptions.position(tp0).longValue());
         assertEquals(4, records.get(0).offset());
@@ -674,10 +698,10 @@ public class FetcherTest {
      */
     @Test
     public void testFetchAfterPartitionWithFetchedRecordsIsUnassigned() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time), 2);
+        buildFetcher(2);
 
         List<ConsumerRecord<byte[], byte[]>> records;
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
         // Returns 3 records while `max.poll.records` is configured to 2
@@ -685,19 +709,20 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        records = fetcher.fetchedRecords().get(tp0);
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
         assertEquals(3L, subscriptions.position(tp0).longValue());
         assertEquals(1, records.get(0).offset());
         assertEquals(2, records.get(1).offset());
 
-        subscriptions.assignFromUser(singleton(tp1));
+        assignFromUser(singleton(tp1));
         client.prepareResponse(matchesOffset(tp1, 4), fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
         subscriptions.seek(tp1, 4);
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertNull(fetchedRecords.get(tp0));
         records = fetchedRecords.get(tp1);
         assertEquals(2, records.size());
@@ -710,6 +735,7 @@ public class FetcherTest {
     public void testFetchNonContinuousRecords() {
         // if we are fetching from a compacted topic, there may be gaps in the returned records
         // this test verifies the fetcher updates the current fetched/consumed positions correctly for this case
+        buildFetcher();
 
         MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE,
                 TimestampType.CREATE_TIME, 0L);
@@ -719,14 +745,15 @@ public class FetcherTest {
         MemoryRecords records = builder.build();
 
         List<ConsumerRecord<byte[], byte[]>> consumerRecords;
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
-        consumerRecords = fetcher.fetchedRecords().get(tp0);
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        consumerRecords = recordsByPartition.get(tp0);
         assertEquals(3, consumerRecords.size());
         assertEquals(31L, subscriptions.position(tp0).longValue()); // this is the next fetching position
 
@@ -742,6 +769,8 @@ public class FetcherTest {
     @Test
     public void testFetchRequestWhenRecordTooLarge() {
         try {
+            buildFetcher();
+
             client.setNodeApiVersions(NodeApiVersions.create(Collections.singletonList(
                 new ApiVersionsResponse.ApiVersion(ApiKeys.FETCH.id, (short) 2, (short) 2))));
             makeFetchRequestWithIncompleteRecord();
@@ -766,6 +795,7 @@ public class FetcherTest {
      */
     @Test
     public void testFetchRequestInternalError() {
+        buildFetcher();
         makeFetchRequestWithIncompleteRecord();
         try {
             fetcher.fetchedRecords();
@@ -778,7 +808,7 @@ public class FetcherTest {
     }
 
     private void makeFetchRequestWithIncompleteRecord() {
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
@@ -791,7 +821,9 @@ public class FetcherTest {
 
     @Test
     public void testUnauthorizedTopic() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         // resize the limit of the buffer to pretend it is only fetch-size large
@@ -808,10 +840,14 @@ public class FetcherTest {
 
     @Test
     public void testFetchDuringRebalance() {
+        buildFetcher();
+
         subscriptions.subscribe(singleton(topicName), listener);
         subscriptions.assignFromSubscribed(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
+        client.updateMetadata(initialUpdateResponse);
+
         assertEquals(1, fetcher.sendFetches());
 
         // Now the rebalance happens and fetch positions are cleared
@@ -825,7 +861,9 @@ public class FetcherTest {
 
     @Test
     public void testInFlightFetchOnPausedPartition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -838,7 +876,9 @@ public class FetcherTest {
 
     @Test
     public void testFetchOnPausedPartition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         subscriptions.pause(tp0);
@@ -848,7 +888,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchNotLeaderForPartition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -860,7 +901,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchUnknownTopicOrPartition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -872,7 +914,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchFencedLeaderEpoch() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -885,7 +928,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchUnknownLeaderEpoch() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -898,6 +942,8 @@ public class FetcherTest {
 
     @Test
     public void testEpochSetInFetchRequest() {
+        buildFetcher();
+        subscriptions.assignFromUser(singleton(tp0));
         client.updateMetadata(initialUpdateResponse);
 
         // Metadata update with leader epochs
@@ -906,7 +952,6 @@ public class FetcherTest {
                     new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(99), replicas, Collections.emptyList(), offlineReplicas));
         client.updateMetadata(metadataResponse);
 
-        subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
 
@@ -930,7 +975,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchOffsetOutOfRange() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -945,7 +991,8 @@ public class FetcherTest {
     public void testStaleOutOfRangeError() {
         // verify that an out of range error which arrives after a seek
         // does not cause us to reset our position or throw an exception
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -959,35 +1006,38 @@ public class FetcherTest {
 
     @Test
     public void testFetchedRecordsAfterSeek() {
-        subscriptionsNoAutoReset.assignFromUser(singleton(tp0));
-        subscriptionsNoAutoReset.seek(tp0, 0);
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED);
 
-        assertTrue(fetcherNoAutoReset.sendFetches() > 0);
+        assignFromUser(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        assertTrue(fetcher.sendFetches() > 0);
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(time.timer(0));
-        assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
-        subscriptionsNoAutoReset.seek(tp0, 2);
-        assertEquals(0, fetcherNoAutoReset.fetchedRecords().size());
+        assertFalse(subscriptions.isOffsetResetNeeded(tp0));
+        subscriptions.seek(tp0, 2);
+        assertEquals(0, fetcher.fetchedRecords().size());
     }
 
     @Test
     public void testFetchOffsetOutOfRangeException() {
-        subscriptionsNoAutoReset.assignFromUser(singleton(tp0));
-        subscriptionsNoAutoReset.seek(tp0, 0);
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED);
 
-        fetcherNoAutoReset.sendFetches();
+        assignFromUser(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        fetcher.sendFetches();
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
-        assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
+        assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         for (int i = 0; i < 2; i++) {
-            try {
-                fetcherNoAutoReset.fetchedRecords();
-                fail("Should have thrown OffsetOutOfRangeException");
-            } catch (OffsetOutOfRangeException e) {
-                assertTrue(e.offsetOutOfRangePartitions().containsKey(tp0));
-                assertEquals(e.offsetOutOfRangePartitions().size(), 1);
-            }
+            OffsetOutOfRangeException e = assertThrows(OffsetOutOfRangeException.class, () ->
+                    fetcher.fetchedRecords());
+            assertEquals(singleton(tp0), e.offsetOutOfRangePartitions().keySet());
+            assertEquals(0L, e.offsetOutOfRangePartitions().get(tp0).longValue());
         }
     }
 
@@ -995,11 +1045,13 @@ public class FetcherTest {
     public void testFetchPositionAfterException() {
         // verify the advancement in the next fetch offset equals to the number of fetched records when
         // some fetched partitions cause Exception. This ensures that consumer won't lose record upon exception
-        subscriptionsNoAutoReset.assignFromUser(Utils.mkSet(tp0, tp1));
-        subscriptionsNoAutoReset.seek(tp0, 1);
-        subscriptionsNoAutoReset.seek(tp1, 1);
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
+        assignFromUser(Utils.mkSet(tp0, tp1));
+        subscriptions.seek(tp0, 1);
+        subscriptions.seek(tp1, 1);
 
-        assertEquals(1, fetcherNoAutoReset.sendFetches());
+        assertEquals(1, fetcher.sendFetches());
 
         Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new LinkedHashMap<>();
         partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100,
@@ -1010,41 +1062,43 @@ public class FetcherTest {
             0, INVALID_SESSION_ID));
         consumerClient.poll(time.timer(0));
 
-        List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
-        List<OffsetOutOfRangeException> exceptions = new ArrayList<>();
+        List<ConsumerRecord<byte[], byte[]>> allFetchedRecords = new ArrayList<>();
+        fetchRecordsInto(allFetchedRecords);
 
-        for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
-            fetchedRecords.addAll(records);
+        assertEquals(1, subscriptions.position(tp0).longValue());
+        assertEquals(4, subscriptions.position(tp1).longValue());
+        assertEquals(3, allFetchedRecords.size());
 
-        assertEquals(fetchedRecords.size(), subscriptionsNoAutoReset.position(tp1) - 1);
+        OffsetOutOfRangeException e = assertThrows(OffsetOutOfRangeException.class, () ->
+                fetchRecordsInto(allFetchedRecords));
 
-        try {
-            for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
-                fetchedRecords.addAll(records);
-        } catch (OffsetOutOfRangeException e) {
-            exceptions.add(e);
-        }
+        assertEquals(singleton(tp0), e.offsetOutOfRangePartitions().keySet());
+        assertEquals(1L, e.offsetOutOfRangePartitions().get(tp0).longValue());
 
-        assertEquals(4, subscriptionsNoAutoReset.position(tp1).longValue());
-        assertEquals(3, fetchedRecords.size());
+        assertEquals(1, subscriptions.position(tp0).longValue());
+        assertEquals(4, subscriptions.position(tp1).longValue());
+        assertEquals(3, allFetchedRecords.size());
+    }
 
-        // Should have received one OffsetOutOfRangeException for partition tp1
-        assertEquals(1, exceptions.size());
-        OffsetOutOfRangeException e = exceptions.get(0);
-        assertTrue(e.offsetOutOfRangePartitions().containsKey(tp0));
-        assertEquals(e.offsetOutOfRangePartitions().size(), 1);
+    private void fetchRecordsInto(List<ConsumerRecord<byte[], byte[]>> allFetchedRecords) {
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
+        fetchedRecords.values().forEach(allFetchedRecords::addAll);
     }
 
     @Test
     public void testCompletedFetchRemoval() {
         // Ensure the removal of completed fetches that cause an Exception if and only if they contain empty records.
-        subscriptionsNoAutoReset.assignFromUser(Utils.mkSet(tp0, tp1, tp2, tp3));
-        subscriptionsNoAutoReset.seek(tp0, 1);
-        subscriptionsNoAutoReset.seek(tp1, 1);
-        subscriptionsNoAutoReset.seek(tp2, 1);
-        subscriptionsNoAutoReset.seek(tp3, 1);
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
+        assignFromUser(Utils.mkSet(tp0, tp1, tp2, tp3));
+        client.updateMetadata(initialUpdateResponse);
 
-        assertEquals(1, fetcherNoAutoReset.sendFetches());
+        subscriptions.seek(tp0, 1);
+        subscriptions.seek(tp1, 1);
+        subscriptions.seek(tp2, 1);
+        subscriptions.seek(tp3, 1);
+
+        assertEquals(1, fetcher.sendFetches());
 
         Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new LinkedHashMap<>();
         partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.NONE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET,
@@ -1060,16 +1114,18 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
-        for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
             fetchedRecords.addAll(records);
 
-        assertEquals(fetchedRecords.size(), subscriptionsNoAutoReset.position(tp1) - 1);
-        assertEquals(4, subscriptionsNoAutoReset.position(tp1).longValue());
+        assertEquals(fetchedRecords.size(), subscriptions.position(tp1) - 1);
+        assertEquals(4, subscriptions.position(tp1).longValue());
         assertEquals(3, fetchedRecords.size());
 
         List<OffsetOutOfRangeException> oorExceptions = new ArrayList<>();
         try {
-            for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
+            recordsByPartition = fetchedRecords();
+            for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
                 fetchedRecords.addAll(records);
         } catch (OffsetOutOfRangeException oor) {
             oorExceptions.add(oor);
@@ -1081,18 +1137,20 @@ public class FetcherTest {
         assertTrue(oor.offsetOutOfRangePartitions().containsKey(tp0));
         assertEquals(oor.offsetOutOfRangePartitions().size(), 1);
 
-        for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
+        recordsByPartition = fetchedRecords();
+        for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
             fetchedRecords.addAll(records);
 
         // Should not have received an Exception for tp2.
-        assertEquals(6, subscriptionsNoAutoReset.position(tp2).longValue());
+        assertEquals(6, subscriptions.position(tp2).longValue());
         assertEquals(5, fetchedRecords.size());
 
         int numExceptionsExpected = 3;
         List<KafkaException> kafkaExceptions = new ArrayList<>();
         for (int i = 1; i <= numExceptionsExpected; i++) {
             try {
-                for (List<ConsumerRecord<byte[], byte[]>> records: fetcherNoAutoReset.fetchedRecords().values())
+                recordsByPartition = fetchedRecords();
+                for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
                     fetchedRecords.addAll(records);
             } catch (KafkaException e) {
                 kafkaExceptions.add(e);
@@ -1104,10 +1162,11 @@ public class FetcherTest {
 
     @Test
     public void testSeekBeforeException() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptionsNoAutoReset, new Metrics(time), 2);
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED);
 
-        subscriptionsNoAutoReset.assignFromUser(Utils.mkSet(tp0));
-        subscriptionsNoAutoReset.seek(tp0, 1);
+        assignFromUser(Utils.mkSet(tp0));
+        subscriptions.seek(tp0, 1);
         assertEquals(1, fetcher.sendFetches());
         Map<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions = new HashMap<>();
         partitions.put(tp0, new FetchResponse.PartitionData<>(Errors.NONE, 100,
@@ -1117,8 +1176,8 @@ public class FetcherTest {
 
         assertEquals(2, fetcher.fetchedRecords().get(tp0).size());
 
-        subscriptionsNoAutoReset.assignFromUser(Utils.mkSet(tp0, tp1));
-        subscriptionsNoAutoReset.seek(tp1, 1);
+        subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+        subscriptions.seek(tp1, 1);
         assertEquals(1, fetcher.sendFetches());
         partitions = new HashMap<>();
         partitions.put(tp1, new FetchResponse.PartitionData<>(Errors.OFFSET_OUT_OF_RANGE, 100,
@@ -1127,14 +1186,16 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
 
-        subscriptionsNoAutoReset.seek(tp1, 10);
+        subscriptions.seek(tp1, 10);
         // Should not throw OffsetOutOfRangeException after the seek
         assertEquals(0, fetcher.fetchedRecords().size());
     }
 
     @Test
     public void testFetchDisconnected() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
@@ -1150,7 +1211,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionNoOpWithPositionSet() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 5L);
 
         fetcher.resetOffsetsIfNeeded();
@@ -1161,7 +1223,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionResetToDefaultOffset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0);
 
         client.prepareResponse(listOffsetRequestMatcher(ListOffsetRequest.EARLIEST_TIMESTAMP),
@@ -1175,9 +1238,12 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionResetToLatestOffset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
+        client.updateMetadata(initialUpdateResponse);
+
         client.prepareResponse(listOffsetRequestMatcher(ListOffsetRequest.LATEST_TIMESTAMP),
                 listOffsetResponse(Errors.NONE, 1L, 5L));
         fetcher.resetOffsetsIfNeeded();
@@ -1192,7 +1258,8 @@ public class FetcherTest {
      */
     @Test
     public void testFetchOffsetErrors() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // Fail with OFFSET_NOT_AVAILABLE
@@ -1227,36 +1294,43 @@ public class FetcherTest {
     }
 
     @Test
-    public void testListOffsetsSendsIsolationLevel() {
-        for (final IsolationLevel isolationLevel : IsolationLevel.values()) {
-            Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
-                    new ByteArrayDeserializer(), Integer.MAX_VALUE, isolationLevel);
+    public void testListOffsetSendsReadUncommitted() {
+        testListOffsetsSendsIsolationLevel(IsolationLevel.READ_UNCOMMITTED);
+    }
+
+    @Test
+    public void testListOffsetSendsReadCommitted() {
+        testListOffsetsSendsIsolationLevel(IsolationLevel.READ_COMMITTED);
+    }
 
-            subscriptions.assignFromUser(singleton(tp0));
-            subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
+    private void testListOffsetsSendsIsolationLevel(IsolationLevel isolationLevel) {
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), new ByteArrayDeserializer(),
+                Integer.MAX_VALUE, isolationLevel);
 
-            client.prepareResponse(new MockClient.RequestMatcher() {
-                @Override
-                public boolean matches(AbstractRequest body) {
-                    ListOffsetRequest request = (ListOffsetRequest) body;
-                    return request.isolationLevel() == isolationLevel;
-                }
-            }, listOffsetResponse(Errors.NONE, 1L, 5L));
-            fetcher.resetOffsetsIfNeeded();
-            consumerClient.pollNoWakeup();
+        assignFromUser(singleton(tp0));
+        subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
-            assertFalse(subscriptions.isOffsetResetNeeded(tp0));
-            assertTrue(subscriptions.isFetchable(tp0));
-            assertEquals(5, subscriptions.position(tp0).longValue());
-        }
+        client.prepareResponse(body -> {
+            ListOffsetRequest request = (ListOffsetRequest) body;
+            return request.isolationLevel() == isolationLevel;
+        }, listOffsetResponse(Errors.NONE, 1L, 5L));
+        fetcher.resetOffsetsIfNeeded();
+        consumerClient.pollNoWakeup();
+
+        assertFalse(subscriptions.isOffsetResetNeeded(tp0));
+        assertTrue(subscriptions.isFetchable(tp0));
+        assertEquals(5, subscriptions.position(tp0).longValue());
     }
 
     @Test
     public void testResetOffsetsSkipsBlackedOutConnections() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST);
 
         // Check that we skip sending the ListOffset request when the node is blacked out
+        client.updateMetadata(initialUpdateResponse);
+        Node node = initialUpdateResponse.brokers().iterator().next();
         client.blackout(node, 500);
         fetcher.resetOffsetsIfNeeded();
         assertEquals(0, consumerClient.pendingRequestCount());
@@ -1277,7 +1351,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionResetToEarliestOffset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST);
 
         client.prepareResponse(listOffsetRequestMatcher(ListOffsetRequest.EARLIEST_TIMESTAMP),
@@ -1292,7 +1367,8 @@ public class FetcherTest {
 
     @Test
     public void testResetOffsetsMetadataRefresh() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // First fetch fails with stale metadata
@@ -1321,7 +1397,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionDisconnect() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // First request gets a disconnect
@@ -1356,7 +1433,8 @@ public class FetcherTest {
 
     @Test
     public void testAssignmentChangeWithInFlightReset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // Send the ListOffsets request to reset the position
@@ -1366,7 +1444,7 @@ public class FetcherTest {
         assertTrue(client.hasInFlightRequests());
 
         // Now we have an assignment change
-        subscriptions.assignFromUser(singleton(tp1));
+        assignFromUser(singleton(tp1));
 
         // The response returns and is discarded
         client.respond(listOffsetResponse(Errors.NONE, 1L, 5L));
@@ -1379,7 +1457,8 @@ public class FetcherTest {
 
     @Test
     public void testSeekWithInFlightReset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // Send the ListOffsets request to reset the position
@@ -1402,7 +1481,8 @@ public class FetcherTest {
 
     @Test
     public void testChangeResetWithInFlightReset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // Send the ListOffsets request to reset the position
@@ -1426,7 +1506,8 @@ public class FetcherTest {
 
     @Test
     public void testIdempotentResetWithInFlightReset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // Send the ListOffsets request to reset the position
@@ -1448,7 +1529,8 @@ public class FetcherTest {
 
     @Test
     public void testRestOffsetsAuthorizationFailure() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         // First request gets a disconnect
@@ -1485,7 +1567,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionOfPausedPartitionsRequiringOffsetReset() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.pause(tp0); // paused partition does not have a valid position
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
@@ -1502,7 +1585,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionOfPausedPartitionsWithoutAValidPosition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0);
         subscriptions.pause(tp0); // paused partition does not have a valid position
 
@@ -1516,7 +1600,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdateFetchPositionOfPausedPartitionsWithAValidPosition() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 10);
         subscriptions.pause(tp0); // paused partition already has a valid position
 
@@ -1531,6 +1616,8 @@ public class FetcherTest {
     @Test
     public void testGetAllTopics() {
         // sending response before request, as getTopicMetadata is a blocking call
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
 
         Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L));
@@ -1541,6 +1628,8 @@ public class FetcherTest {
     @Test
     public void testGetAllTopicsDisconnect() {
         // first try gets a disconnect, next succeeds
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(null, true);
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
         Map<String, List<PartitionInfo>> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L));
@@ -1550,11 +1639,15 @@ public class FetcherTest {
     @Test(expected = TimeoutException.class)
     public void testGetAllTopicsTimeout() {
         // since no response is prepared, the request should timeout
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         fetcher.getAllTopicMetadata(time.timer(50L));
     }
 
     @Test
     public void testGetAllTopicsUnauthorized() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(newMetadataResponse(topicName, Errors.TOPIC_AUTHORIZATION_FAILED));
         try {
             fetcher.getAllTopicMetadata(time.timer(10L));
@@ -1566,6 +1659,8 @@ public class FetcherTest {
 
     @Test(expected = InvalidTopicException.class)
     public void testGetTopicMetadataInvalidTopic() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(newMetadataResponse(topicName, Errors.INVALID_TOPIC_EXCEPTION));
         fetcher.getTopicMetadata(
                 new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L));
@@ -1573,6 +1668,8 @@ public class FetcherTest {
 
     @Test
     public void testGetTopicMetadataUnknownTopic() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(newMetadataResponse(topicName, Errors.UNKNOWN_TOPIC_OR_PARTITION));
 
         Map<String, List<PartitionInfo>> topicMetadata = fetcher.getTopicMetadata(
@@ -1582,6 +1679,8 @@ public class FetcherTest {
 
     @Test
     public void testGetTopicMetadataLeaderNotAvailable() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         client.prepareResponse(newMetadataResponse(topicName, Errors.LEADER_NOT_AVAILABLE));
         client.prepareResponse(newMetadataResponse(topicName, Errors.NONE));
 
@@ -1592,6 +1691,8 @@ public class FetcherTest {
 
     @Test
     public void testGetTopicMetadataOfflinePartitions() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         MetadataResponse originalResponse = newMetadataResponse(topicName, Errors.NONE); //baseline ok response
 
         //create a response based on the above one with all partitions being leaderless
@@ -1642,6 +1743,8 @@ public class FetcherTest {
      */
     @Test
     public void testQuotaMetrics() {
+        buildFetcher();
+
         MockSelector selector = new MockSelector(time);
         Sensor throttleTimeSensor = Fetcher.throttleTimeSensor(metrics, metricsRegistry);
         Cluster cluster = TestUtils.singletonCluster("test", 1);
@@ -1688,7 +1791,8 @@ public class FetcherTest {
      */
     @Test
     public void testFetcherMetrics() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         MetricName maxLagMetric = metrics.metricInstance(metricsRegistry.recordsLagMax);
@@ -1726,7 +1830,9 @@ public class FetcherTest {
 
     @Test
     public void testFetcherLeadMetric() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         MetricName minLeadMetric = metrics.metricInstance(metricsRegistry.recordsLeadMin);
@@ -1765,11 +1871,10 @@ public class FetcherTest {
 
     @Test
     public void testReadCommittedLagMetric() {
-        Metrics metrics = new Metrics();
-        fetcher = createFetcher(subscriptions, metrics, new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
         MetricName maxLagMetric = metrics.metricInstance(metricsRegistry.recordsLagMax);
@@ -1808,18 +1913,20 @@ public class FetcherTest {
 
     @Test
     public void testFetchResponseMetrics() {
+        buildFetcher();
+
         String topic1 = "foo";
         String topic2 = "bar";
         TopicPartition tp1 = new TopicPartition(topic1, 0);
         TopicPartition tp2 = new TopicPartition(topic2, 0);
 
+        subscriptions.assignFromUser(Utils.mkSet(tp1, tp2));
+
         Map<String, Integer> partitionCounts = new HashMap<>();
         partitionCounts.put(topic1, 1);
         partitionCounts.put(topic2, 1);
         client.updateMetadata(TestUtils.metadataUpdateWith(1, partitionCounts));
 
-        subscriptions.assignFromUser(Utils.mkSet(tp1, tp2));
-
         int expectedBytes = 0;
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> fetchPartitionData = new LinkedHashMap<>();
 
@@ -1842,7 +1949,7 @@ public class FetcherTest {
         client.prepareResponse(new FetchResponse<>(Errors.NONE, fetchPartitionData, 0, INVALID_SESSION_ID));
         consumerClient.poll(time.timer(0));
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertEquals(3, fetchedRecords.get(tp1).size());
         assertEquals(3, fetchedRecords.get(tp2).size());
 
@@ -1855,7 +1962,9 @@ public class FetcherTest {
 
     @Test
     public void testFetchResponseMetricsPartialResponse() {
-        subscriptions.assignFromUser(singleton(tp0));
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
         Map<MetricName, KafkaMetric> allMetrics = metrics.metrics();
@@ -1881,7 +1990,8 @@ public class FetcherTest {
 
     @Test
     public void testFetchResponseMetricsWithOnePartitionError() {
-        subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+        buildFetcher();
+        assignFromUser(Utils.mkSet(tp0, tp1));
         subscriptions.seek(tp0, 0);
         subscriptions.seek(tp1, 0);
 
@@ -1917,7 +2027,9 @@ public class FetcherTest {
 
     @Test
     public void testFetchResponseMetricsWithOnePartitionAtTheWrongOffset() {
-        subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+        buildFetcher();
+
+        assignFromUser(Utils.mkSet(tp0, tp1));
         subscriptions.seek(tp0, 0);
         subscriptions.seek(tp1, 0);
 
@@ -1957,22 +2069,19 @@ public class FetcherTest {
     }
 
     @Test
-    public void testFetcherMetricsTemplates() throws Exception {
-        metrics.close();
+    public void testFetcherMetricsTemplates() {
         Map<String, String> clientTags = Collections.singletonMap("client-id", "clientA");
-        metrics = new Metrics(new MetricConfig().tags(clientTags));
-        metricsRegistry = new FetcherMetricsRegistry(clientTags.keySet(), "consumer" + groupId);
-        fetcher.close();
-        fetcher = createFetcher(subscriptions, metrics);
+        buildFetcher(new MetricConfig().tags(clientTags), OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
 
         // Fetch from topic to generate topic metrics
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
 
         // Create throttle metrics
@@ -1998,7 +2107,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime));
         consumerClient.poll(time.timer(0));
-        return fetcher.fetchedRecords();
+        return fetchedRecords();
     }
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchRecords(
@@ -2006,13 +2115,14 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, logStartOffset, throttleTime));
         consumerClient.poll(time.timer(0));
-        return fetcher.fetchedRecords();
+        return fetchedRecords();
     }
 
     @Test
     public void testGetOffsetsForTimesTimeout() {
         try {
-            fetcher.offsetsByTimes(Collections.singletonMap(new TopicPartition(topicName, 2), 1000L), time.timer(100L));
+            buildFetcher();
+            fetcher.offsetsForTimes(Collections.singletonMap(new TopicPartition(topicName, 2), 1000L), time.timer(100L));
             fail("Should throw timeout exception.");
         } catch (TimeoutException e) {
             // let it go.
@@ -2021,8 +2131,10 @@ public class FetcherTest {
 
     @Test
     public void testGetOffsetsForTimes() {
+        buildFetcher();
+
         // Empty map
-        assertTrue(fetcher.offsetsByTimes(new HashMap<TopicPartition, Long>(), time.timer(100L)).isEmpty());
+        assertTrue(fetcher.offsetsForTimes(new HashMap<TopicPartition, Long>(), time.timer(100L)).isEmpty());
         // Unknown Offset
         testGetOffsetsForTimesWithUnknownOffset();
         // Error code none with unknown offset
@@ -2042,9 +2154,10 @@ public class FetcherTest {
 
     @Test
     public void testGetOffsetsFencedLeaderEpoch() {
+        buildFetcher();
+        subscriptions.assignFromUser(singleton(tp0));
         client.updateMetadata(initialUpdateResponse);
 
-        subscriptions.assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
         client.prepareResponse(listOffsetResponse(Errors.FENCED_LEADER_EPOCH, 1L, 5L));
@@ -2059,6 +2172,7 @@ public class FetcherTest {
 
     @Test
     public void testGetOffsetsUnknownLeaderEpoch() {
+        buildFetcher();
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
 
@@ -2074,6 +2188,9 @@ public class FetcherTest {
 
     @Test
     public void testGetOffsetsIncludesLeaderEpoch() {
+        buildFetcher();
+        subscriptions.assignFromUser(singleton(tp0));
+
         client.updateMetadata(initialUpdateResponse);
 
         // Metadata update with leader epochs
@@ -2083,7 +2200,6 @@ public class FetcherTest {
         client.updateMetadata(metadataResponse);
 
         // Request latest offset
-        subscriptions.assignFromUser(singleton(tp0));
         subscriptions.requestOffsetReset(tp0);
         fetcher.resetOffsetsIfNeeded();
 
@@ -2107,6 +2223,8 @@ public class FetcherTest {
 
     @Test
     public void testGetOffsetsForTimesWhenSomeTopicPartitionLeadersNotKnownInitially() {
+        buildFetcher();
+
         final String anotherTopic = "another-topic";
         final TopicPartition t2p0 = new TopicPartition(anotherTopic, 0);
 
@@ -2117,7 +2235,7 @@ public class FetcherTest {
         client.updateMetadata(initialMetadata);
 
         // The first metadata refresh should contain one topic
-        client.prepareMetadataUpdate(initialMetadata, false);
+        client.prepareMetadataUpdate(initialMetadata);
         client.prepareResponseFrom(listOffsetResponse(tp0, Errors.NONE, 1000L, 11L),
                 metadata.fetch().leaderFor(tp0));
         client.prepareResponseFrom(listOffsetResponse(tp1, Errors.NONE, 1000L, 32L),
@@ -2128,7 +2246,7 @@ public class FetcherTest {
         partitionNumByTopic.put(topicName, 2);
         partitionNumByTopic.put(anotherTopic, 1);
         MetadataResponse updatedMetadata = TestUtils.metadataUpdateWith(3, partitionNumByTopic);
-        client.prepareMetadataUpdate(updatedMetadata, false);
+        client.prepareMetadataUpdate(updatedMetadata);
         client.prepareResponseFrom(listOffsetResponse(t2p0, Errors.NONE, 1000L, 54L),
                 metadata.fetch().leaderFor(t2p0));
 
@@ -2137,13 +2255,13 @@ public class FetcherTest {
         timestampToSearch.put(tp1, ListOffsetRequest.LATEST_TIMESTAMP);
         timestampToSearch.put(t2p0, ListOffsetRequest.LATEST_TIMESTAMP);
         Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap =
-            fetcher.offsetsByTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
+            fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
 
-        assertNotNull("Expect Fetcher.offsetsByTimes() to return non-null result for " + tp0,
+        assertNotNull("Expect Fetcher.offsetsForTimes() to return non-null result for " + tp0,
                       offsetAndTimestampMap.get(tp0));
-        assertNotNull("Expect Fetcher.offsetsByTimes() to return non-null result for " + tp1,
+        assertNotNull("Expect Fetcher.offsetsForTimes() to return non-null result for " + tp1,
                       offsetAndTimestampMap.get(tp1));
-        assertNotNull("Expect Fetcher.offsetsByTimes() to return non-null result for " + t2p0,
+        assertNotNull("Expect Fetcher.offsetsForTimes() to return non-null result for " + t2p0,
                       offsetAndTimestampMap.get(t2p0));
         assertEquals(11L, offsetAndTimestampMap.get(tp0).offset());
         assertEquals(32L, offsetAndTimestampMap.get(tp1).offset());
@@ -2152,6 +2270,8 @@ public class FetcherTest {
 
     @Test(expected = TimeoutException.class)
     public void testBatchedListOffsetsMetadataErrors() {
+        buildFetcher();
+
         Map<TopicPartition, ListOffsetResponse.PartitionData> partitionData = new HashMap<>();
         partitionData.put(tp0, new ListOffsetResponse.PartitionData(Errors.NOT_LEADER_FOR_PARTITION,
                 ListOffsetResponse.UNKNOWN_TIMESTAMP, ListOffsetResponse.UNKNOWN_OFFSET,
@@ -2165,12 +2285,12 @@ public class FetcherTest {
         offsetsToSearch.put(tp0, ListOffsetRequest.EARLIEST_TIMESTAMP);
         offsetsToSearch.put(tp1, ListOffsetRequest.EARLIEST_TIMESTAMP);
 
-        fetcher.offsetsByTimes(offsetsToSearch, time.timer(0));
+        fetcher.offsetsForTimes(offsetsToSearch, time.timer(0));
     }
 
     @Test
     public void testSkippingAbortedTransactions() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         int currentOffset = 0;
@@ -2186,7 +2306,7 @@ public class FetcherTest {
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         abortedTransactions.add(new FetchResponse.AbortedTransaction(1, 0));
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2198,13 +2318,13 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp0));
     }
 
     @Test
     public void testReturnCommittedTransactions() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         int currentOffset = 0;
@@ -2218,7 +2338,7 @@ public class FetcherTest {
 
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2237,14 +2357,14 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.containsKey(tp0));
         assertEquals(fetchedRecords.get(tp0).size(), 2);
     }
 
     @Test
     public void testReadCommittedWithCommittedAndAbortedTransactions() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
 
@@ -2295,7 +2415,7 @@ public class FetcherTest {
         buffer.flip();
 
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2307,7 +2427,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.containsKey(tp0));
         // There are only 3 committed records
         List<ConsumerRecord<byte[], byte[]>> fetchedConsumerRecords = fetchedRecords.get(tp0);
@@ -2320,7 +2440,7 @@ public class FetcherTest {
 
     @Test
     public void testMultipleAbortMarkers() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         int currentOffset = 0;
@@ -2342,7 +2462,7 @@ public class FetcherTest {
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         abortedTransactions.add(new FetchResponse.AbortedTransaction(1, 0));
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2354,7 +2474,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.containsKey(tp0));
         assertEquals(fetchedRecords.get(tp0).size(), 2);
         List<ConsumerRecord<byte[], byte[]>> fetchedConsumerRecords = fetchedRecords.get(tp0);
@@ -2368,7 +2488,7 @@ public class FetcherTest {
 
     @Test
     public void testReadCommittedAbortMarkerWithNoData() {
-        Fetcher<String, String> fetcher = createFetcher(subscriptions, new Metrics(), new StringDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new StringDeserializer(),
                 new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
 
@@ -2386,7 +2506,7 @@ public class FetcherTest {
         buffer.flip();
 
         // send the fetch
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
 
@@ -2399,7 +2519,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetchedRecords();
         assertTrue(allFetchedRecords.containsKey(tp0));
         List<ConsumerRecord<String, String>> fetchedRecords = allFetchedRecords.get(tp0);
         assertEquals(3, fetchedRecords.size());
@@ -2408,6 +2528,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdatePositionWithLastRecordMissingFromBatch() {
+        buildFetcher();
+
         MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
                 new SimpleRecord("0".getBytes(), "v".getBytes()),
                 new SimpleRecord("1".getBytes(), "v".getBytes()),
@@ -2429,14 +2551,14 @@ public class FetcherTest {
         result.outputBuffer().flip();
         MemoryRecords compactedRecords = MemoryRecords.readableRecords(result.outputBuffer());
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetchedRecords();
         assertTrue(allFetchedRecords.containsKey(tp0));
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = allFetchedRecords.get(tp0);
         assertEquals(3, fetchedRecords.size());
@@ -2451,6 +2573,8 @@ public class FetcherTest {
 
     @Test
     public void testUpdatePositionOnEmptyBatch() {
+        buildFetcher();
+
         long producerId = 1;
         short producerEpoch = 0;
         int sequence = 1;
@@ -2464,14 +2588,14 @@ public class FetcherTest {
         buffer.flip();
         MemoryRecords recordsWithEmptyBatch = MemoryRecords.readableRecords(buffer);
 
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> allFetchedRecords = fetchedRecords();
         assertTrue(allFetchedRecords.isEmpty());
 
         // The next offset should point to the next batch
@@ -2480,7 +2604,7 @@ public class FetcherTest {
 
     @Test
     public void testReadCommittedWithCompactedTopic() {
-        Fetcher<String, String> fetcher = createFetcher(subscriptions, new Metrics(), new StringDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new StringDeserializer(),
                 new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
 
@@ -2519,7 +2643,7 @@ public class FetcherTest {
         buffer.flip();
 
         // send the fetch
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
 
@@ -2533,7 +2657,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetchedRecords();
         assertTrue(allFetchedRecords.containsKey(tp0));
         List<ConsumerRecord<String, String>> fetchedRecords = allFetchedRecords.get(tp0);
         assertEquals(5, fetchedRecords.size());
@@ -2542,7 +2666,7 @@ public class FetcherTest {
 
     @Test
     public void testReturnAbortedTransactionsinUncommittedMode() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         int currentOffset = 0;
@@ -2558,7 +2682,7 @@ public class FetcherTest {
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         abortedTransactions.add(new FetchResponse.AbortedTransaction(1, 0));
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2570,13 +2694,13 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.containsKey(tp0));
     }
 
     @Test
     public void testConsumerPositionUpdatedWhenSkippingAbortedTransactions() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         long currentOffset = 0;
@@ -2591,7 +2715,7 @@ public class FetcherTest {
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         abortedTransactions.add(new FetchResponse.AbortedTransaction(1, 0));
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2603,7 +2727,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
 
         // Ensure that we don't return any of the aborted records, but yet advance the consumer position.
         assertFalse(fetchedRecords.containsKey(tp0));
@@ -2612,10 +2736,10 @@ public class FetcherTest {
 
     @Test
     public void testConsumingViaIncrementalFetchRequests() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time), 2);
+        buildFetcher(2);
 
         List<ConsumerRecord<byte[], byte[]>> records;
-        subscriptions.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
+        assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
         subscriptions.seek(tp0, 0);
         subscriptions.seek(tp1, 1);
 
@@ -2631,7 +2755,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
         assertEquals(2, records.size());
@@ -2642,7 +2766,7 @@ public class FetcherTest {
 
         // There is still a buffered record.
         assertEquals(0, fetcher.sendFetches());
-        fetchedRecords = fetcher.fetchedRecords();
+        fetchedRecords = fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
         assertEquals(1, records.size());
@@ -2655,7 +2779,7 @@ public class FetcherTest {
         client.prepareResponse(resp2);
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        fetchedRecords = fetcher.fetchedRecords();
+        fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.isEmpty());
         assertEquals(4L, subscriptions.position(tp0).longValue());
         assertEquals(1L, subscriptions.position(tp1).longValue());
@@ -2668,7 +2792,7 @@ public class FetcherTest {
         client.prepareResponse(resp3);
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
-        fetchedRecords = fetcher.fetchedRecords();
+        fetchedRecords = fetchedRecords();
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
         assertEquals(2, records.size());
@@ -2685,13 +2809,9 @@ public class FetcherTest {
         for (int i = 0; i < numPartitions; i++)
             topicPartitions.add(new TopicPartition(topicName, i));
 
-        MetadataResponse initialMetadataResponse = TestUtils.metadataUpdateWith(1,
-                singletonMap(topicName, numPartitions));
-        client.updateMetadata(initialMetadataResponse);
-        node = metadata.fetch().nodes().get(0);
-        fetchSize = 10000;
+        buildDependencies(new MetricConfig(), OffsetResetStrategy.EARLIEST);
 
-        Fetcher<byte[], byte[]> fetcher = new Fetcher<byte[], byte[]>(
+        fetcher = new Fetcher<byte[], byte[]>(
                 new LogContext(),
                 consumerClient,
                 minBytes,
@@ -2756,7 +2876,12 @@ public class FetcherTest {
             }
         };
 
-        subscriptions.assignFromUser(topicPartitions);
+        MetadataResponse initialMetadataResponse = TestUtils.metadataUpdateWith(1,
+                singletonMap(topicName, numPartitions));
+        client.updateMetadata(initialMetadataResponse);
+        fetchSize = 10000;
+
+        assignFromUser(topicPartitions);
         topicPartitions.forEach(tp -> subscriptions.seek(tp, 0L));
 
         AtomicInteger fetchesRemaining = new AtomicInteger(1000);
@@ -2790,7 +2915,7 @@ public class FetcherTest {
                 }
             }
             if (fetcher.hasCompletedFetches()) {
-                Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+                Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
                 if (!fetchedRecords.isEmpty()) {
                     fetchesRemaining.decrementAndGet();
                     fetchedRecords.entrySet().forEach(entry -> {
@@ -2810,7 +2935,7 @@ public class FetcherTest {
 
     @Test
     public void testEmptyControlBatch() {
-        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         int currentOffset = 1;
@@ -2830,7 +2955,7 @@ public class FetcherTest {
 
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         MemoryRecords records = MemoryRecords.readableRecords(buffer);
-        subscriptions.assignFromUser(singleton(tp0));
+        assignFromUser(singleton(tp0));
 
         subscriptions.seek(tp0, 0);
 
@@ -2849,7 +2974,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
-        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.containsKey(tp0));
         assertEquals(fetchedRecords.get(tp0).size(), 2);
     }
@@ -2930,7 +3055,7 @@ public class FetcherTest {
         timestampToSearch.put(t2p0, 0L);
         timestampToSearch.put(tp1, 0L);
         Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap =
-                fetcher.offsetsByTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
+                fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
 
         if (expectedOffsetForP0 == null)
             assertNull(offsetAndTimestampMap.get(t2p0));
@@ -2964,7 +3089,7 @@ public class FetcherTest {
         Map<TopicPartition, Long> timestampToSearch = new HashMap<>();
         timestampToSearch.put(tp0, 0L);
         Map<TopicPartition, OffsetAndTimestamp> offsetAndTimestampMap =
-                fetcher.offsetsByTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
+                fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE));
 
         assertTrue(offsetAndTimestampMap.containsKey(tp0));
         assertNull(offsetAndTimestampMap.get(tp0));
@@ -3041,32 +3166,43 @@ public class FetcherTest {
                 initialUpdateResponse.controller().id(), Collections.singletonList(topicMetadata));
     }
 
-    private Fetcher<byte[], byte[]> createFetcher(SubscriptionState subscriptions,
-                                                  Metrics metrics,
-                                                  int maxPollRecords) {
-        return createFetcher(subscriptions, metrics, new ByteArrayDeserializer(), new ByteArrayDeserializer(),
+    @SuppressWarnings("unchecked")
+    private <K, V> Map<TopicPartition, List<ConsumerRecord<K, V>>> fetchedRecords() {
+        return (Map) fetcher.fetchedRecords();
+    }
+
+    private void buildFetcher(int maxPollRecords) {
+        buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), new ByteArrayDeserializer(),
                 maxPollRecords, IsolationLevel.READ_UNCOMMITTED);
     }
 
-    private Fetcher<byte[], byte[]> createFetcher(SubscriptionState subscriptions, Metrics metrics) {
-        return createFetcher(subscriptions, metrics, Integer.MAX_VALUE);
+    private void buildFetcher() {
+        buildFetcher(Integer.MAX_VALUE);
+    }
+
+    private void buildFetcher(Deserializer<?> keyDeserializer,
+                              Deserializer<?> valueDeserializer) {
+        buildFetcher(OffsetResetStrategy.EARLIEST, keyDeserializer, valueDeserializer,
+                Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
     }
 
-    private <K, V> Fetcher<K, V> createFetcher(SubscriptionState subscriptions,
-                                               Metrics metrics,
-                                               Deserializer<K> keyDeserializer,
-                                               Deserializer<V> valueDeserializer) {
-        return createFetcher(subscriptions, metrics, keyDeserializer, valueDeserializer, Integer.MAX_VALUE,
-                IsolationLevel.READ_UNCOMMITTED);
+    private <K, V> void buildFetcher(OffsetResetStrategy offsetResetStrategy,
+                                     Deserializer<K> keyDeserializer,
+                                     Deserializer<V> valueDeserializer,
+                                     int maxPollRecords,
+                                     IsolationLevel isolationLevel) {
+        buildFetcher(new MetricConfig(), offsetResetStrategy, keyDeserializer, valueDeserializer,
+                maxPollRecords, isolationLevel);
     }
 
-    private <K, V> Fetcher<K, V> createFetcher(SubscriptionState subscriptions,
-                                               Metrics metrics,
-                                               Deserializer<K> keyDeserializer,
-                                               Deserializer<V> valueDeserializer,
-                                               int maxPollRecords,
-                                               IsolationLevel isolationLevel) {
-        return new Fetcher<>(
+    private <K, V> void buildFetcher(MetricConfig metricConfig,
+                                     OffsetResetStrategy offsetResetStrategy,
+                                     Deserializer<K> keyDeserializer,
+                                     Deserializer<V> valueDeserializer,
+                                     int maxPollRecords,
+                                     IsolationLevel isolationLevel) {
+        buildDependencies(metricConfig, offsetResetStrategy);
+        fetcher = new Fetcher<>(
                 new LogContext(),
                 consumerClient,
                 minBytes,
@@ -3087,10 +3223,20 @@ public class FetcherTest {
                 isolationLevel);
     }
 
+    /**
+     * Initializes the shared test fixtures (time, subscriptions, metadata, mock client,
+     * metrics, network client) that the fetcher under test depends on. Fields are
+     * assigned rather than returned so individual tests can inspect/override them.
+     */
+    private void buildDependencies(MetricConfig metricConfig, OffsetResetStrategy offsetResetStrategy) {
+        LogContext logContext = new LogContext();
+        // Auto-tick of 1ms per access so time-based logic makes progress without explicit sleeps.
+        time = new MockTime(1);
+        subscriptions = new SubscriptionState(logContext, offsetResetStrategy);
+        // refreshBackoffMs=0 and metadataExpireMs=MAX_VALUE: metadata never goes stale on its own in tests.
+        metadata = new ConsumerMetadata(0, Long.MAX_VALUE, false,
+                subscriptions, logContext, new ClusterResourceListeners());
+        client = new MockClient(time, metadata);
+        metrics = new Metrics(metricConfig, time);
+        consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time,
+                100, 1000, Integer.MAX_VALUE);
+        metricsRegistry = new FetcherMetricsRegistry(metricConfig.tags().keySet(), "consumer" + groupId);
+    }
+
     private <T> List<Long> collectRecordOffsets(List<ConsumerRecord<T, T>> records) {
-        List<Long> res = new ArrayList<>(records.size());
-        for (ConsumerRecord<?, ?> record : records)
-            res.add(record.offset());
-        return res;
+        // Extract each record's offset, preserving the input order.
+        return records.stream().map(ConsumerRecord::offset).collect(Collectors.toList());
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index f3faf26..638cb7b 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -18,9 +18,9 @@ package org.apache.kafka.clients.producer;
 
 import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.KafkaClient;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.producer.internals.ProducerInterceptors;
+import org.apache.kafka.clients.producer.internals.ProducerMetadata;
 import org.apache.kafka.clients.producer.internals.Sender;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
@@ -64,6 +64,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -241,7 +242,7 @@ public class KafkaProducerTest {
 
         Time time = new MockTime();
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("topic", 1));
-        Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
         MockClient client = new MockClient(time, metadata);
         client.updateMetadata(initialUpdateResponse);
 
@@ -312,7 +313,7 @@ public class KafkaProducerTest {
     public void testMetadataFetch() throws InterruptedException {
         Map<String, Object> configs = new HashMap<>();
         configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
-        Metadata metadata = mock(Metadata.class);
+        ProducerMetadata metadata = mock(ProducerMetadata.class);
 
         // Return empty cluster 4 times and cluster from then on
         when(metadata.fetch()).thenReturn(emptyCluster, emptyCluster, emptyCluster, emptyCluster, onePartitionCluster);
@@ -320,9 +321,9 @@ public class KafkaProducerTest {
         KafkaProducer<String, String> producer = new KafkaProducer<String, String>(configs, new StringSerializer(),
                 new StringSerializer(), metadata, new MockClient(Time.SYSTEM, metadata), null, Time.SYSTEM) {
             @Override
-            Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
+            Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) {
                 // give Sender its own Metadata instance so that we can isolate Metadata calls from KafkaProducer
-                return super.newSender(logContext, kafkaClient, new Metadata(0, 100_000, true));
+                return super.newSender(logContext, kafkaClient, newMetadata(0, 100_000));
             }
         };
         ProducerRecord<String, String> record = new ProducerRecord<>(topic, "value");
@@ -356,7 +357,7 @@ public class KafkaProducerTest {
 
         // Create a record with a partition higher than the initial (outdated) partition range
         ProducerRecord<String, String> record = new ProducerRecord<>(topic, 2, null, "value");
-        Metadata metadata = mock(Metadata.class);
+        ProducerMetadata metadata = mock(ProducerMetadata.class);
 
         MockTime mockTime = new MockTime();
         AtomicInteger invocationCount = new AtomicInteger(0);
@@ -372,9 +373,9 @@ public class KafkaProducerTest {
         KafkaProducer<String, String> producer = new KafkaProducer<String, String>(configs, new StringSerializer(),
                 new StringSerializer(), metadata, new MockClient(Time.SYSTEM, metadata), null, mockTime) {
             @Override
-            Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
+            Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) {
                 // give Sender its own Metadata instance so that we can isolate Metadata calls from KafkaProducer
-                return super.newSender(logContext, kafkaClient, new Metadata(0, 100_000, true));
+                return super.newSender(logContext, kafkaClient, newMetadata(0, 100_000));
             }
         };
 
@@ -401,7 +402,7 @@ public class KafkaProducerTest {
 
         // Create a record with a partition higher than the initial (outdated) partition range
         ProducerRecord<String, String> record = new ProducerRecord<>(topic, 2, null, "value");
-        Metadata metadata = mock(Metadata.class);
+        ProducerMetadata metadata = mock(ProducerMetadata.class);
 
         MockTime mockTime = new MockTime();
 
@@ -410,9 +411,9 @@ public class KafkaProducerTest {
         KafkaProducer<String, String> producer = new KafkaProducer<String, String>(configs, new StringSerializer(),
                 new StringSerializer(), metadata, new MockClient(Time.SYSTEM, metadata), null, mockTime) {
             @Override
-            Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
+            Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) {
                 // give Sender its own Metadata instance so that we can isolate Metadata calls from KafkaProducer
-                return super.newSender(logContext, kafkaClient, new Metadata(0, 100_000, true));
+                return super.newSender(logContext, kafkaClient, newMetadata(0, 100_000));
             }
         };
         // One request update if metadata is available but outdated for the given record
@@ -432,7 +433,7 @@ public class KafkaProducerTest {
 
         // Create a record with a partition higher than the initial (outdated) partition range
         ProducerRecord<String, String> record = new ProducerRecord<>(topic, 2, null, "value");
-        Metadata metadata = mock(Metadata.class);
+        ProducerMetadata metadata = mock(ProducerMetadata.class);
 
         MockTime mockTime = new MockTime();
         AtomicInteger invocationCount = new AtomicInteger(0);
@@ -448,9 +449,9 @@ public class KafkaProducerTest {
         KafkaProducer<String, String> producer = new KafkaProducer<String, String>(configs, new StringSerializer(),
                 new StringSerializer(), metadata, new MockClient(Time.SYSTEM, metadata), null, mockTime) {
             @Override
-            Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
+            Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) {
                 // give Sender its own Metadata instance so that we can isolate Metadata calls from KafkaProducer
-                return super.newSender(logContext, kafkaClient, new Metadata(0, 100_000, true));
+                return super.newSender(logContext, kafkaClient, newMetadata(0, 100_000));
             }
         };
 
@@ -476,17 +477,18 @@ public class KafkaProducerTest {
         configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, "600000");
         long refreshBackoffMs = 500L;
         long metadataExpireMs = 60000L;
-        final Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, true,
-                true, new ClusterResourceListeners());
         final Time time = new MockTime();
+        final ProducerMetadata metadata = new ProducerMetadata(refreshBackoffMs, metadataExpireMs,
+                new LogContext(), new ClusterResourceListeners(), time);
         final String topic = "topic";
         try (KafkaProducer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(),
-                new StringSerializer(), metadata, null, null, time)) {
+                new StringSerializer(), metadata, new MockClient(time, metadata), null, time)) {
 
+            AtomicBoolean running = new AtomicBoolean(true);
             Thread t = new Thread(() -> {
                 long startTimeMs = System.currentTimeMillis();
-                for (int i = 0; i < 10; i++) {
-                    while (!metadata.updateRequested() && System.currentTimeMillis() - startTimeMs < 1000)
+                while (running.get()) {
+                    while (!metadata.updateRequested() && System.currentTimeMillis() - startTimeMs < 100)
                         Thread.yield();
                     MetadataResponse updateResponse = TestUtils.metadataUpdateWith("kafka-cluster", 1,
                             singletonMap(topic, Errors.UNKNOWN_TOPIC_OR_PARTITION), emptyMap());
@@ -501,9 +503,9 @@ public class KafkaProducerTest {
             } catch (TimeoutException e) {
                 // skip
             }
+            running.set(false);
             t.join();
         }
-        assertTrue("Topic should still exist in metadata", metadata.containsTopic(topic));
     }
 
     @SuppressWarnings("unchecked")
@@ -528,7 +530,9 @@ public class KafkaProducerTest {
         Serializer<String> valueSerializer = mock(serializerClassToMock);
 
         String topic = "topic";
-        Metadata metadata = new Metadata(0, 90000, true);
+        ProducerMetadata metadata = newMetadata(0, 90000);
+        metadata.add(topic);
+
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap(topic, 1));
         metadata.update(initialUpdateResponse, Time.SYSTEM.milliseconds());
 
@@ -596,7 +600,8 @@ public class KafkaProducerTest {
         String topic = "topic";
         ProducerRecord<String, String> record = new ProducerRecord<>(topic, "value");
 
-        Metadata metadata = new Metadata(0, 90000, true);
+        ProducerMetadata metadata = newMetadata(0, 90000);
+        metadata.add(topic);
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap(topic, 1));
         metadata.update(initialUpdateResponse, Time.SYSTEM.milliseconds());
 
@@ -636,7 +641,7 @@ public class KafkaProducerTest {
 
         Time time = new MockTime(1);
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("topic", 1));
-        Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
         metadata.update(initialUpdateResponse, time.milliseconds());
 
         MockClient client = new MockClient(time, metadata);
@@ -657,7 +662,7 @@ public class KafkaProducerTest {
 
         Time time = new MockTime();
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("topic", 1));
-        Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
         metadata.update(initialUpdateResponse, time.milliseconds());
 
         MockClient client = new MockClient(time, metadata);
@@ -685,7 +690,7 @@ public class KafkaProducerTest {
 
         Time time = new MockTime();
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, emptyMap());
-        Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
         metadata.update(initialUpdateResponse, time.milliseconds());
 
         MockClient client = new MockClient(time, metadata);
@@ -725,9 +730,10 @@ public class KafkaProducerTest {
         // block in Metadata#awaitUpdate for the configured max.block.ms. When close() is invoked, KafkaProducer#send should
         // return with a KafkaException.
         String topicName = "test";
-        Time time = new MockTime();
+        Time time = Time.SYSTEM;
         MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, emptyMap());
-        Metadata metadata = new Metadata(0, Long.MAX_VALUE, true);
+        ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE,
+                new LogContext(), new ClusterResourceListeners(), time);
         metadata.update(initialUpdateResponse, time.milliseconds());
         MockClient client = new MockClient(time, metadata);
 
@@ -750,7 +756,8 @@ public class KafkaProducerTest {
             });
 
             // Wait until metadata update for the topic has been requested
-            TestUtils.waitForCondition(() -> metadata.containsTopic(topicName), "Timeout when waiting for topic to be added to metadata");
+            TestUtils.waitForCondition(() -> metadata.containsTopic(topicName),
+                    "Timeout when waiting for topic to be added to metadata");
             producer.close(Duration.ofMillis(0));
             TestUtils.waitForCondition(() -> sendException.get() != null, "No producer exception within timeout");
             assertEquals(KafkaException.class, sendException.get().getClass());
@@ -758,4 +765,10 @@ public class KafkaProducerTest {
             executor.shutdownNow();
         }
     }
+
+    /**
+     * Creates a {@link ProducerMetadata} with the given refresh backoff and expiration,
+     * wired to system time and no-op cluster listeners — the common setup for these tests.
+     */
+    private ProducerMetadata newMetadata(long refreshBackoffMs, long expirationMs) {
+        return new ProducerMetadata(refreshBackoffMs, expirationMs,
+                new LogContext(), new ClusterResourceListeners(), Time.SYSTEM);
+    }
+
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java
new file mode 100644
index 0000000..aaf3857
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.producer.internals;
+
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.test.TestUtils;
+import org.junit.After;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Unit tests for {@link ProducerMetadata}: update/backoff timing, behavior of
+ * blocking waiters (including after close), and producer-side topic expiry.
+ */
+public class ProducerMetadataTest {
+
+    private long refreshBackoffMs = 100;
+    private long metadataExpireMs = 1000;
+    private ProducerMetadata metadata = new ProducerMetadata(refreshBackoffMs, metadataExpireMs, new LogContext(),
+            new ClusterResourceListeners(), Time.SYSTEM);
+    // Records an exception thrown on an asyncFetch background thread; checked in tearDown.
+    private AtomicReference<Exception> backgroundError = new AtomicReference<>();
+
+    @After
+    public void tearDown() {
+        // Fail the test if a background fetch thread hit an unexpected exception.
+        assertNull("Exception in background thread : " + backgroundError.get(), backgroundError.get());
+    }
+
+    @Test
+    public void testMetadata() throws Exception {
+        String topic = "my-topic";
+        metadata.add(topic);
+
+        long time = Time.SYSTEM.milliseconds();
+        // An update response with no topics still counts as a fresh update.
+        metadata.update(responseWithTopics(Collections.emptySet()), time);
+        assertTrue("No update needed.", metadata.timeToNextUpdate(time) > 0);
+        metadata.requestUpdate();
+        assertTrue("Still no updated needed due to backoff", metadata.timeToNextUpdate(time) > 0);
+        time += refreshBackoffMs;
+        assertEquals("Update needed now that backoff time expired", 0, metadata.timeToNextUpdate(time));
+        Thread t1 = asyncFetch(topic, 500);
+        Thread t2 = asyncFetch(topic, 500);
+        assertTrue("Awaiting update", t1.isAlive());
+        assertTrue("Awaiting update", t2.isAlive());
+        // Perform metadata update when an update is requested on the async fetch thread
+        // This simulates the metadata update sequence in KafkaProducer
+        while (t1.isAlive() || t2.isAlive()) {
+            if (metadata.timeToNextUpdate(time) == 0) {
+                metadata.update(responseWithCurrentTopics(), time);
+                time += refreshBackoffMs;
+            }
+            Thread.sleep(1);
+        }
+        t1.join();
+        t2.join();
+        assertTrue("No update needed.", metadata.timeToNextUpdate(time) > 0);
+        time += metadataExpireMs;
+        assertEquals("Update needed due to stale metadata.", 0, metadata.timeToNextUpdate(time));
+    }
+
+    @Test
+    public void testMetadataAwaitAfterClose() throws InterruptedException {
+        long time = 0;
+        metadata.update(responseWithCurrentTopics(), time);
+        assertTrue("No update needed.", metadata.timeToNextUpdate(time) > 0);
+        metadata.requestUpdate();
+        assertTrue("Still no updated needed due to backoff", metadata.timeToNextUpdate(time) > 0);
+        time += refreshBackoffMs;
+        assertEquals("Update needed now that backoff time expired", 0, metadata.timeToNextUpdate(time));
+        String topic = "my-topic";
+        // After close, a waiter must fail fast with a KafkaException rather than block.
+        metadata.close();
+        Thread t1 = asyncFetch(topic, 500);
+        t1.join();
+        assertEquals(KafkaException.class, backgroundError.get().getClass());
+        assertTrue(backgroundError.get().toString().contains("Requested metadata update after close"));
+        // The failure above is expected; clear it so tearDown does not fail the test.
+        clearBackgroundError();
+    }
+
+    /**
+     * Tests that {@link org.apache.kafka.clients.producer.internals.ProducerMetadata#awaitUpdate(int, long)} doesn't
+     * wait forever with a max timeout value of 0
+     *
+     * @throws Exception if the wait is interrupted (not expected in this test)
+     * @see <a href=https://issues.apache.org/jira/browse/KAFKA-1836>KAFKA-1836</a>
+     */
+    @Test
+    public void testMetadataUpdateWaitTime() throws Exception {
+        long time = 0;
+        metadata.update(responseWithCurrentTopics(), time);
+        assertTrue("No update needed.", metadata.timeToNextUpdate(time) > 0);
+        // first try with a max wait time of 0 and ensure that this returns back without waiting forever
+        try {
+            metadata.awaitUpdate(metadata.requestUpdate(), 0);
+            fail("Wait on metadata update was expected to timeout, but it didn't");
+        } catch (TimeoutException te) {
+            // expected
+        }
+        // now try with a higher timeout value once
+        final long twoSecondWait = 2000;
+        try {
+            metadata.awaitUpdate(metadata.requestUpdate(), twoSecondWait);
+            fail("Wait on metadata update was expected to timeout, but it didn't");
+        } catch (TimeoutException te) {
+            // expected
+        }
+    }
+
+    @Test
+    public void testTimeToNextUpdateOverwriteBackoff() {
+        long now = 10000;
+
+        // New topic added to fetch set and update requested. It should allow immediate update.
+        metadata.update(responseWithCurrentTopics(), now);
+        metadata.add("new-topic");
+        assertEquals(0, metadata.timeToNextUpdate(now));
+
+        // Even though add is called, an immediate update isn't necessary if the
+        // topic set doesn't actually gain a new topic.
+        metadata.update(responseWithCurrentTopics(), now);
+        metadata.add("new-topic");
+        assertEquals(metadataExpireMs, metadata.timeToNextUpdate(now));
+
+        // If the new set of topics contains a new topic then it should allow immediate update.
+        metadata.add("another-new-topic");
+        assertEquals(0, metadata.timeToNextUpdate(now));
+    }
+
+    @Test
+    public void testTopicExpiry() {
+        // Test that topic is expired if not used within the expiry interval
+        long time = 0;
+        String topic1 = "topic1";
+        metadata.add(topic1);
+        metadata.update(responseWithCurrentTopics(), time);
+        assertTrue(metadata.containsTopic(topic1));
+
+        time += ProducerMetadata.TOPIC_EXPIRY_MS;
+        metadata.update(responseWithCurrentTopics(), time);
+        assertFalse("Unused topic not expired", metadata.containsTopic(topic1));
+
+        // Test that topic is not expired if used within the expiry interval
+        metadata.add("topic2");
+        metadata.update(responseWithCurrentTopics(), time);
+        for (int i = 0; i < 3; i++) {
+            time += ProducerMetadata.TOPIC_EXPIRY_MS / 2;
+            metadata.update(responseWithCurrentTopics(), time);
+            assertTrue("Topic expired even though in use", metadata.containsTopic("topic2"));
+            // Re-adding refreshes the topic's last-used time, keeping it alive.
+            metadata.add("topic2");
+        }
+    }
+
+    // Builds a metadata response covering exactly the topics currently tracked by the metadata.
+    private MetadataResponse responseWithCurrentTopics() {
+        return responseWithTopics(metadata.topics());
+    }
+
+    // Builds a single-broker metadata response with one partition per given topic.
+    private MetadataResponse responseWithTopics(Set<String> topics) {
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        for (String topic : topics)
+            partitionCounts.put(topic, 1);
+        return TestUtils.metadataUpdateWith(1, partitionCounts);
+    }
+
+    private void clearBackgroundError() {
+        backgroundError.set(null);
+    }
+
+    /**
+     * Starts (and returns) a thread that repeatedly requests and awaits metadata updates
+     * until the topic has at least one partition in the fetched cluster. Any exception is
+     * stored in {@code backgroundError} (last one wins) for inspection by the test.
+     */
+    private Thread asyncFetch(final String topic, final long maxWaitMs) {
+        Thread thread = new Thread(() -> {
+            try {
+                while (metadata.fetch().partitionsForTopic(topic).isEmpty())
+                    metadata.awaitUpdate(metadata.requestUpdate(), maxWaitMs);
+            } catch (Exception e) {
+                backgroundError.set(e);
+            }
+        });
+        thread.start();
+        return thread;
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index 0aa752e..4cbdaa2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -16,27 +16,9 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
-import java.nio.ByteBuffer;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.IdentityHashMap;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.OptionalInt;
-import java.util.OptionalLong;
-import java.util.Set;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Future;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientRequest;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
@@ -91,10 +73,28 @@ import org.junit.Before;
 import org.junit.Test;
 import org.mockito.InOrder;
 
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.AdditionalMatchers.geq;
 import static org.mockito.ArgumentMatchers.any;
@@ -120,7 +120,8 @@ public class SenderTest {
     private TopicPartition tp1 = new TopicPartition("test", 1);
     private MockTime time = new MockTime();
     private int batchSize = 16 * 1024;
-    private Metadata metadata = new Metadata(0, Long.MAX_VALUE, true, true, new ClusterResourceListeners());
+    private ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE,
+            new LogContext(), new ClusterResourceListeners(), time);
     private MockClient client = new MockClient(time, metadata);
     private ApiVersions apiVersions = new ApiVersions();
     private Metrics metrics = null;
@@ -486,7 +487,7 @@ public class SenderTest {
     @Test
     public void testMetadataTopicExpiry() throws Exception {
         long offset = 0;
-        client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.emptyMap()));
+        client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2)));
 
         Future<RecordMetadata> future = accumulator.append(tp0, time.milliseconds(), "key".getBytes(), "value".getBytes(), null, null, MAX_BLOCK_TIMEOUT).future;
         sender.run(time.milliseconds());
@@ -502,8 +503,8 @@ public class SenderTest {
         assertTrue("Request should be completed", future.isDone());
 
         assertTrue("Topic not retained in metadata list", metadata.containsTopic(tp0.topic()));
-        time.sleep(Metadata.TOPIC_EXPIRY_MS);
-        client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.emptyMap()));
+        time.sleep(ProducerMetadata.TOPIC_EXPIRY_MS);
+        client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2)));
         assertFalse("Unused topic has not been expired", metadata.containsTopic(tp0.topic()));
         future = accumulator.append(tp0, time.milliseconds(), "key".getBytes(), "value".getBytes(), null, null, MAX_BLOCK_TIMEOUT).future;
         sender.run(time.milliseconds());
@@ -2298,6 +2299,7 @@ public class SenderTest {
         this.sender = new Sender(logContext, this.client, this.metadata, this.accumulator, guaranteeOrder, MAX_REQUEST_SIZE, ACKS_ALL,
                 Integer.MAX_VALUE, this.senderMetricsRegistry, this.time, REQUEST_TIMEOUT, 50, transactionManager, apiVersions);
 
+        metadata.add("test");
         this.client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2)));
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index b476961..97f7f5d 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.clients.producer.internals;
 
 import org.apache.kafka.clients.ApiVersions;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.producer.RecordMetadata;
@@ -102,7 +101,8 @@ public class TransactionManagerTest {
     private TopicPartition tp0 = new TopicPartition(topic, 0);
     private TopicPartition tp1 = new TopicPartition(topic, 1);
     private MockTime time = new MockTime();
-    private Metadata metadata = new Metadata(0, Long.MAX_VALUE, true, true, new ClusterResourceListeners());
+    private ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE, new LogContext(),
+            new ClusterResourceListeners(), time);
     private MockClient client = new MockClient(time, metadata);
 
     private ApiVersions apiVersions = new ApiVersions();
@@ -132,6 +132,7 @@ public class TransactionManagerTest {
                 new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
         this.sender = new Sender(logContext, this.client, this.metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL,
                 MAX_RETRIES, senderMetrics, this.time, REQUEST_TIMEOUT, 50, transactionManager, apiVersions);
+        this.metadata.add("test");
         this.client.updateMetadata(TestUtils.metadataUpdateWith(1, singletonMap("test", 2)));
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
index 4dbd03c..be9e1a3 100644
--- a/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java
@@ -16,9 +16,12 @@
  */
 package org.apache.kafka.common.utils;
 
+import org.apache.kafka.common.errors.TimeoutException;
+
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
 
 /**
  * A clock that you can manually advance by calling sleep
@@ -83,6 +86,27 @@ public class MockTime implements Time {
         tick();
     }
 
+    /**
+     * Waits on {@code obj} until {@code condition} becomes true or the mock clock reaches
+     * {@code deadlineMs}. A listener is registered so that any advance of the mock clock
+     * notifies waiters, letting them re-check the condition and the deadline.
+     *
+     * @throws TimeoutException     if the deadline passes before the condition holds
+     * @throws InterruptedException if the waiting thread is interrupted
+     */
+    @Override
+    public void waitObject(Object obj, Supplier<Boolean> condition, long deadlineMs) throws InterruptedException {
+        MockTimeListener listener = () -> {
+            synchronized (obj) {
+                obj.notify();
+            }
+        };
+        listeners.add(listener);
+        try {
+            synchronized (obj) {
+                // Loop guards against spurious wakeups; re-reads the (possibly auto-ticking) clock each pass.
+                while (milliseconds() < deadlineMs && !condition.get()) {
+                    obj.wait();
+                }
+                if (!condition.get())
+                    throw new TimeoutException("Condition not satisfied before deadline");
+            }
+        } finally {
+            // Always deregister so the listener does not leak across calls.
+            listeners.remove(listener);
+        }
+    }
+
     public void setCurrentTimeMs(long newMs) {
         long oldMs = timeMs.getAndSet(newMs);
 
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java b/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java
index 7bd302b..d8101ac 100644
--- a/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java
@@ -16,15 +16,13 @@
  */
 package org.apache.kafka.common.utils;
 
-import org.junit.Assert;
 import org.junit.Rule;
-import org.junit.rules.Timeout;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.junit.Test;
+import org.junit.rules.Timeout;
+
+import static org.junit.Assert.assertEquals;
 
-public class MockTimeTest {
-    private static final Logger log = LoggerFactory.getLogger(MockTimeTest.class);
+public class MockTimeTest extends TimeTest {
 
     @Rule
     final public Timeout globalTimeout = Timeout.millis(120000);
@@ -32,19 +30,24 @@ public class MockTimeTest {
     @Test
     public void testAdvanceClock() {
         MockTime time = new MockTime(0, 100, 200);
-        Assert.assertEquals(100, time.milliseconds());
-        Assert.assertEquals(200, time.nanoseconds());
+        assertEquals(100, time.milliseconds());
+        assertEquals(200, time.nanoseconds());
         time.sleep(1);
-        Assert.assertEquals(101, time.milliseconds());
-        Assert.assertEquals(1000200, time.nanoseconds());
+        assertEquals(101, time.milliseconds());
+        assertEquals(1000200, time.nanoseconds());
     }
 
     @Test
     public void testAutoTickMs() {
         MockTime time = new MockTime(1, 100, 200);
-        Assert.assertEquals(101, time.milliseconds());
-        Assert.assertEquals(2000200, time.nanoseconds());
-        Assert.assertEquals(103, time.milliseconds());
-        Assert.assertEquals(104, time.milliseconds());
+        assertEquals(101, time.milliseconds());
+        assertEquals(2000200, time.nanoseconds());
+        assertEquals(103, time.milliseconds());
+        assertEquals(104, time.milliseconds());
+    }
+
+    @Override
+    protected Time createTime() {
+        return new MockTime();
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java b/clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java
similarity index 59%
copy from clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
copy to clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java
index c8b79ab..edc53d2 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java
@@ -16,30 +16,10 @@
  */
 package org.apache.kafka.common.utils;
 
-/**
- * A time implementation that uses the system clock and sleep call. Use `Time.SYSTEM` instead of creating an instance
- * of this class.
- */
-public class SystemTime implements Time {
-
-    @Override
-    public long milliseconds() {
-        return System.currentTimeMillis();
-    }
+public class SystemTimeTest extends TimeTest {
 
     @Override
-    public long nanoseconds() {
-        return System.nanoTime();
+    protected Time createTime() {
+        return Time.SYSTEM;
     }
-
-    @Override
-    public void sleep(long ms) {
-        try {
-            Thread.sleep(ms);
-        } catch (InterruptedException e) {
-            // just wake up early
-            Thread.currentThread().interrupt();
-        }
-    }
-
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java b/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java
new file mode 100644
index 0000000..98878e2
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.apache.kafka.common.errors.TimeoutException;
+import org.junit.Test;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+public abstract class TimeTest {
+
+    protected abstract Time createTime();
+
+    @Test
+    public void testWaitObjectTimeout() throws InterruptedException {
+        Object obj = new Object();
+        Time time = createTime();
+        long timeoutMs = 100;
+        long deadlineMs = time.milliseconds() + timeoutMs;
+        AtomicReference<Exception> caughtException = new AtomicReference<>();
+        Thread t = new Thread(() -> {
+            try {
+                time.waitObject(obj, () -> false, deadlineMs);
+            } catch (Exception e) {
+                caughtException.set(e);
+            }
+        });
+
+        t.start();
+        time.sleep(timeoutMs);
+        t.join();
+
+        assertEquals(TimeoutException.class, caughtException.get().getClass());
+    }
+
+    @Test
+    public void testWaitObjectConditionSatisfied() throws InterruptedException {
+        Object obj = new Object();
+        Time time = createTime();
+        long timeoutMs = 1000000000;
+        long deadlineMs = time.milliseconds() + timeoutMs;
+        AtomicBoolean condition = new AtomicBoolean(false);
+        AtomicReference<Exception> caughtException = new AtomicReference<>();
+        Thread t = new Thread(() -> {
+            try {
+                time.waitObject(obj, condition::get, deadlineMs);
+            } catch (Exception e) {
+                caughtException.set(e);
+            }
+        });
+
+        t.start();
+
+        synchronized (obj) {
+            condition.set(true);
+            obj.notify();
+        }
+
+        t.join();
+
+        assertTrue(time.milliseconds() < deadlineMs);
+        assertNull(caughtException.get());
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
index 900357c..3f9a1b7 100644
--- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
@@ -83,10 +83,6 @@ public class TestUtils {
         return clusterWith(1);
     }
 
-    public static Cluster singletonCluster(final Map<String, Integer> topicPartitionCounts) {
-        return clusterWith(1, topicPartitionCounts);
-    }
-
     public static Cluster singletonCluster(final String topic, final int partitions) {
         return clusterWith(1, topic, partitions);
     }
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
index ed48d57..6591be0 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
@@ -24,6 +24,7 @@ import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient;
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.JmxReporter;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
@@ -94,7 +95,8 @@ public class WorkerGroupMember {
             reporters.add(new JmxReporter(JMX_PREFIX));
             this.metrics = new Metrics(metricConfig, reporters, time);
             this.retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG);
-            this.metadata = new Metadata(retryBackoffMs, config.getLong(CommonClientConfigs.METADATA_MAX_AGE_CONFIG), true);
+            this.metadata = new Metadata(retryBackoffMs, config.getLong(CommonClientConfigs.METADATA_MAX_AGE_CONFIG),
+                    logContext, new ClusterResourceListeners());
             List<InetSocketAddress> addresses = ClientUtils.parseAndValidateAddresses(
                     config.getList(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG),
                     config.getString(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG));
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
index 7ccb68c..68d236f 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
@@ -20,6 +20,7 @@ import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient;
 import org.apache.kafka.common.Node;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.AbstractRequest;
@@ -87,20 +88,20 @@ public class WorkerCoordinatorTest {
 
     @Before
     public void setup() {
-        LogContext loggerFactory = new LogContext();
+        LogContext logContext = new LogContext();
 
         this.time = new MockTime();
-        this.metadata = new Metadata(0, Long.MAX_VALUE, true);
+        this.metadata = new Metadata(0, Long.MAX_VALUE, logContext, new ClusterResourceListeners());
         this.client = new MockClient(time, metadata);
         this.client.updateMetadata(TestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)));
         this.node = metadata.fetch().nodes().get(0);
-        this.consumerClient = new ConsumerNetworkClient(loggerFactory, client, metadata, time, 100, 1000, heartbeatIntervalMs);
+        this.consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, 100, 1000, heartbeatIntervalMs);
         this.metrics = new Metrics(time);
         this.rebalanceListener = new MockRebalanceListener();
         this.configStorage = PowerMock.createMock(KafkaConfigBackingStore.class);
 
         this.coordinator = new WorkerCoordinator(
-                loggerFactory,
+                logContext,
                 consumerClient,
                 groupId,
                 rebalanceTimeoutMs,
diff --git a/core/src/main/scala/kafka/admin/AdminClient.scala b/core/src/main/scala/kafka/admin/AdminClient.scala
index bd09db4..17716d3 100644
--- a/core/src/main/scala/kafka/admin/AdminClient.scala
+++ b/core/src/main/scala/kafka/admin/AdminClient.scala
@@ -27,6 +27,7 @@ import org.apache.kafka.common.config.ConfigDef.ValidString._
 import org.apache.kafka.common.config.ConfigDef.{Importance, Type}
 import org.apache.kafka.common.config.{AbstractConfig, ConfigDef}
 import org.apache.kafka.common.errors.{AuthenticationException, TimeoutException}
+import org.apache.kafka.common.internals.ClusterResourceListeners
 import org.apache.kafka.common.message.{DescribeGroupsRequestData, DescribeGroupsResponseData}
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network.Selector
@@ -350,7 +351,7 @@ class CompositeFuture[T](time: Time,
     val timeoutMs = unit.toMillis(timeout)
     var remaining: Long = timeoutMs
 
-    val observedResults = futures.flatMap{ future =>
+    val observedResults = futures.flatMap { future =>
       val elapsed = time.milliseconds() - start
       remaining = if (timeoutMs - elapsed > 0) timeoutMs - elapsed else 0L
 
@@ -429,9 +430,12 @@ object AdminClient {
   def create(props: Map[String, _]): AdminClient = create(new AdminConfig(props))
 
   def create(config: AdminConfig): AdminClient = {
+    val clientId = "admin-" + AdminClientIdSequence.getAndIncrement()
+    val logContext = new LogContext(s"[LegacyAdminClient clientId=$clientId] ")
     val time = Time.SYSTEM
     val metrics = new Metrics(time)
-    val metadata = new Metadata(100L, 60 * 60 * 1000L, true)
+    val metadata = new Metadata(100L, 60 * 60 * 1000L, logContext,
+      new ClusterResourceListeners)
     val channelBuilder = ClientUtils.createChannelBuilder(config, time)
     val requestTimeoutMs = config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG)
     val retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG)
@@ -441,15 +445,13 @@ object AdminClient {
     val brokerAddresses = ClientUtils.parseAndValidateAddresses(brokerUrls, clientDnsLookup)
     metadata.bootstrap(brokerAddresses, time.milliseconds())
 
-    val clientId = "admin-" + AdminClientIdSequence.getAndIncrement()
-
     val selector = new Selector(
       DefaultConnectionMaxIdleMs,
       metrics,
       time,
       "admin",
       channelBuilder,
-      new LogContext(String.format("[Producer clientId=%s] ", clientId)))
+      logContext)
 
     val networkClient = new NetworkClient(
       selector,
@@ -465,10 +467,10 @@ object AdminClient {
       time,
       true,
       new ApiVersions,
-      new LogContext(String.format("[NetworkClient clientId=%s] ", clientId)))
+      logContext)
 
     val highLevelClient = new ConsumerNetworkClient(
-      new LogContext(String.format("[ConsumerNetworkClient clientId=%s] ", clientId)),
+      logContext,
       networkClient,
       metadata,
       time,
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index b00d229..bceead2 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -90,6 +90,7 @@ import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
 import java.util.regex.Pattern;
 
 /**
@@ -806,6 +807,12 @@ public class TopologyTestDriver implements Closeable {
             timeMs.addAndGet(ms);
             highResTimeNs.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms));
         }
+
+        @Override
+        public void waitObject(final Object obj, final Supplier<Boolean> condition, final long timeoutMs) {
+            throw new UnsupportedOperationException();
+        }
+
     }
 
     private MockConsumer<byte[], byte[]> createRestoreConsumer(final Map<String, String> storeToChangelogTopic) {


Mime
View raw message