kafka-commits mailing list archives

From j...@apache.org
Subject [kafka] branch 2.5 updated: KAFKA-8805; Bump producer epoch on recoverable errors (#7389)
Date Sun, 16 Feb 2020 06:59:31 GMT
This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new 5a073fc  KAFKA-8805; Bump producer epoch on recoverable errors (#7389)
5a073fc is described below

commit 5a073fca4cca64b2acf5029615a0d91ed97c894c
Author: Bob Barrett <bob.barrett@confluent.io>
AuthorDate: Sat Feb 15 22:47:10 2020 -0800

    KAFKA-8805; Bump producer epoch on recoverable errors (#7389)
    
    This change is the client-side part of KIP-360. It identifies cases where it is safe to abort a transaction, bump the producer epoch, and allow the application to continue without closing the producer. In these cases, when KafkaProducer.abortTransaction() is called, the producer sends an InitProducerId request following the transaction abort, which causes the producer epoch to be bumped. The application can then start a new transaction and continue processing.
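    
    As an illustrative, minimal sketch of the application-side pattern this
    enables (not part of this commit; the bootstrap server, transactional id,
    topic, key, and value below are placeholder assumptions):
    
        import java.util.Properties;
        import org.apache.kafka.clients.producer.KafkaProducer;
        import org.apache.kafka.clients.producer.ProducerConfig;
        import org.apache.kafka.clients.producer.ProducerRecord;
        import org.apache.kafka.common.KafkaException;
        import org.apache.kafka.common.errors.AuthorizationException;
        import org.apache.kafka.common.errors.OutOfOrderSequenceException;
        import org.apache.kafka.common.errors.ProducerFencedException;
        import org.apache.kafka.common.serialization.StringSerializer;
    
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "my-transactional-id");
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
    
        KafkaProducer<String, String> producer = new KafkaProducer<>(props);
        producer.initTransactions();
        producer.beginTransaction();
        try {
            producer.send(new ProducerRecord<>("my-topic", "key", "value"));
            producer.commitTransaction();
        } catch (ProducerFencedException | OutOfOrderSequenceException | AuthorizationException e) {
            // Fatal errors: the producer cannot recover and must be closed.
            producer.close();
        } catch (KafkaException e) {
            // Abortable error: with this change, abortTransaction() also sends
            // InitProducerId to bump the epoch, so the same producer instance
            // can begin a new transaction instead of being closed.
            producer.abortTransaction();
        }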
    
    For recoverable errors in the idempotent producer, the epoch is bumped locally. In-flight requests for partitions with an error are rewritten to reflect the new epoch, and in-flight requests for all other partitions are allowed to complete using the old epoch.
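    
    The rewrite can be illustrated in isolation with a simplified model (a
    sketch only; BatchState is a hypothetical stand-in for ProducerBatch and
    does not reflect the actual TransactionManager bookkeeping):
    
        import java.util.Queue;
    
        final class EpochBumpSketch {
            static final class BatchState {
                short epoch;
                int baseSequence;
                final int recordCount;
    
                BatchState(short epoch, int baseSequence, int recordCount) {
                    this.epoch = epoch;
                    this.baseSequence = baseSequence;
                    this.recordCount = recordCount;
                }
            }
    
            // Rewrite each in-flight batch of the failed partition to the bumped
            // epoch, renumbering sequences contiguously from 0. Returns the next
            // sequence to assign to new batches on this partition.
            static int rewriteInFlight(Queue<BatchState> inFlight, short bumpedEpoch) {
                int next = 0;
                for (BatchState batch : inFlight) {
                    batch.epoch = bumpedEpoch;
                    batch.baseSequence = next;
                    next += batch.recordCount;
                }
                return next;
            }
        }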
    
    Reviewers: Boyang Chen <boyang@confluent.io>, Jason Gustafson <jason@confluent.io>
---
 checkstyle/suppressions.xml                        |   4 +-
 .../kafka/clients/producer/KafkaProducer.java      |   8 +-
 .../clients/producer/internals/ProducerBatch.java  |   6 +
 .../producer/internals/RecordAccumulator.java      |  23 +-
 .../kafka/clients/producer/internals/Sender.java   |  42 +-
 .../producer/internals/TransactionManager.java     | 468 +++++++++----
 .../producer/internals/RecordAccumulatorTest.java  |   3 +-
 .../clients/producer/internals/SenderTest.java     | 315 ++++++---
 .../producer/internals/TransactionManagerTest.java | 739 ++++++++++++++++++---
 .../kafka/api/TransactionsBounceTest.scala         |   3 +-
 .../kafka/api/TransactionsExpirationTest.scala     | 122 ++++
 .../integration/kafka/api/TransactionsTest.scala   | 129 +++-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |  19 +-
 13 files changed, 1449 insertions(+), 432 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 7ebe7fc..7b85455 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -20,7 +20,7 @@
 
     <!-- Clients -->
     <suppress checks="ClassFanOutComplexity"
-              files="(Fetcher|Sender|SenderTest|ConsumerCoordinator|KafkaConsumer|KafkaProducer|Utils|TransactionManagerTest|KafkaAdminClient|NetworkClient|Admin).java"/>
+              files="(Fetcher|Sender|SenderTest|ConsumerCoordinator|KafkaConsumer|KafkaProducer|Utils|TransactionManager|TransactionManagerTest|KafkaAdminClient|NetworkClient|Admin).java"/>
     <suppress checks="ClassFanOutComplexity"
               files="(SaslServerAuthenticator|SaslAuthenticatorTest).java"/>
     <suppress checks="ClassFanOutComplexity"
@@ -59,7 +59,7 @@
               files="(Utils|Topic|KafkaLZ4BlockOutputStream|AclData|JoinGroupRequest).java"/>
 
     <suppress checks="CyclomaticComplexity"
-              files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator).java"/>
+              files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator|TransactionManager).java"/>
 
     <suppress checks="JavaNCSS"
               files="(AbstractRequest|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest).java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index 12d68b2..e1d85de 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -388,10 +388,10 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
             this.compressionType = CompressionType.forName(config.getString(ProducerConfig.COMPRESSION_TYPE_CONFIG));
 
             this.maxBlockTimeMs = config.getLong(ProducerConfig.MAX_BLOCK_MS_CONFIG);
-            this.transactionManager = configureTransactionState(config, logContext, log);
             int deliveryTimeoutMs = configureDeliveryTimeout(config, log);
 
             this.apiVersions = new ApiVersions();
+            this.transactionManager = configureTransactionState(config, logContext);
             this.accumulator = new RecordAccumulator(logContext,
                     config.getInt(ProducerConfig.BATCH_SIZE_CONFIG),
                     this.compressionType,
@@ -504,7 +504,8 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
         return deliveryTimeoutMs;
     }
 
-    private static TransactionManager configureTransactionState(ProducerConfig config, LogContext logContext, Logger log) {
+    private TransactionManager configureTransactionState(ProducerConfig config,
+                                                         LogContext logContext) {
 
         TransactionManager transactionManager = null;
 
@@ -518,7 +519,8 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
             String transactionalId = config.getString(ProducerConfig.TRANSACTIONAL_ID_CONFIG);
             int transactionTimeoutMs = config.getInt(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG);
             long retryBackoffMs = config.getLong(ProducerConfig.RETRY_BACKOFF_MS_CONFIG);
-            transactionManager = new TransactionManager(logContext, transactionalId, transactionTimeoutMs, retryBackoffMs);
+            transactionManager = new TransactionManager(logContext, transactionalId, transactionTimeoutMs,
+                    retryBackoffMs, apiVersions);
             if (transactionManager.isTransactional())
                 log.info("Instantiated a transactional producer.");
             else
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
index f4c171e..9323a61 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
@@ -389,6 +389,8 @@ public final class ProducerBatch {
     }
 
     public void resetProducerState(ProducerIdAndEpoch producerIdAndEpoch, int baseSequence, boolean isTransactional) {
+        log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}",
+                this.baseSequence(), this.topicPartition, baseSequence);
         reopened = true;
         recordsBuilder.reopenAndRewriteProducerState(producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, baseSequence, isTransactional);
     }
@@ -454,6 +456,10 @@ public final class ProducerBatch {
         return recordsBuilder.baseSequence();
     }
 
+    public int lastSequence() {
+        return recordsBuilder.baseSequence() + recordsBuilder.numRecords() - 1;
+    }
+
     public boolean hasSequence() {
         return baseSequence() != RecordBatch.NO_SEQUENCE;
     }
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index 58a5b3f..2afc15c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -525,11 +525,24 @@ public final class RecordAccumulator {
                 // we cannot send the batch until we have refreshed the producer id
                 return true;
 
-            if (!first.hasSequence() && transactionManager.hasUnresolvedSequence(first.topicPartition))
-                // Don't drain any new batches while the state of previous sequence numbers
-                // is unknown. The previous batches would be unknown if they were aborted
-                // on the client after being sent to the broker at least once.
-                return true;
+            if (!first.hasSequence()) {
+                if (transactionManager.hasInflightBatches(tp)) {
+                    // Don't drain any new batches while the partition has in-flight batches with a different epoch
+                    // and/or producer ID. Otherwise, a batch with a new epoch and sequence number
+                    // 0 could be written before earlier batches complete, which would cause out of sequence errors
+                    ProducerBatch firstInFlightBatch = transactionManager.nextBatchBySequence(tp);
+
+                    if (firstInFlightBatch != null && !transactionManager.matchesProducerIdAndEpoch(firstInFlightBatch)) {
+                        return true;
+                    }
+                }
+
+                if (transactionManager.hasUnresolvedSequence(first.topicPartition))
+                    // Don't drain any new batches while the state of previous sequence numbers
+                    // is unknown. The previous batches would be unknown if they were aborted
+                    // on the client after being sent to the broker at least once.
+                    return true;
+            }
 
             int firstInFlightSequence = transactionManager.firstInFlightSequence(first.topicPartition);
             if (firstInFlightSequence != RecordBatch.NO_SEQUENCE && first.hasSequence()
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index ceddce3..970d7a8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.ClusterAuthorizationException;
 import org.apache.kafka.common.errors.InvalidMetadataException;
-import org.apache.kafka.common.errors.OutOfOrderSequenceException;
 import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
@@ -295,13 +294,7 @@ public class Sender implements Runnable {
     void runOnce() {
         if (transactionManager != null) {
             try {
-                if (transactionManager.isTransactional()
-                        && transactionManager.hasUnresolvedSequences()
-                        && !transactionManager.hasFatalError()) {
-                    transactionManager.transitionToFatalError(
-                            new KafkaException("The client hasn't received acknowledgment for " +
-                                    "some previously sent messages and can no longer retry them. It isn't safe to continue."));
-                }
+                transactionManager.maybeResolveSequences();
 
                 // do not continue sending if the transaction manager is in a failed state
                 if (transactionManager.hasFatalError()) {
@@ -314,14 +307,14 @@ public class Sender implements Runnable {
 
                 // Check whether we need a new producerId. If so, we will enqueue an InitProducerId
                 // request which will be sent below
-                transactionManager.resetIdempotentProducerIdIfNeeded();
+                transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
 
                 if (maybeSendAndPollTransactionalRequest()) {
                     return;
                 }
             } catch (AuthenticationException e) {
                 // This is already logged as error, but propagated here to perform any clean ups.
-                log.trace("Authentication exception while processing transactional request: {}", e);
+                log.trace("Authentication exception while processing transactional request", e);
                 transactionManager.authenticationFailed(e);
             }
         }
@@ -387,7 +380,7 @@ public class Sender implements Runnable {
             failBatch(expiredBatch, -1, NO_TIMESTAMP, new TimeoutException(errorMessage), false);
             if (transactionManager != null && expiredBatch.inRetry()) {
                 // This ensures that no new batches are drained until the current in flight batches are fully resolved.
-                transactionManager.markSequenceUnresolved(expiredBatch.topicPartition);
+                transactionManager.markSequenceUnresolved(expiredBatch);
             }
         }
         sensors.updateProduceRequestMetrics(batches);
@@ -459,7 +452,7 @@ public class Sender implements Runnable {
             long currentTimeMs = time.milliseconds();
             ClientRequest clientRequest = client.newClientRequest(
                 targetNode.idString(), requestBuilder, currentTimeMs, true, requestTimeoutMs, nextRequestHandler);
-            log.debug("Sending transactional request {} to node {}", requestBuilder, targetNode);
+            log.debug("Sending transactional request {} to node {} with correlation ID {}", requestBuilder, targetNode, clientRequest.correlationId());
             client.send(clientRequest, currentTimeMs);
             transactionManager.setInFlightCorrelationId(clientRequest.correlationId());
             client.poll(retryBackoffMs, time.milliseconds());
@@ -521,6 +514,12 @@ public class Sender implements Runnable {
                 client.leastLoadedNode(time.milliseconds());
 
         if (node != null && NetworkClientUtils.awaitReady(client, node, time, requestTimeoutMs)) {
+            if (coordinatorType == FindCoordinatorRequest.CoordinatorType.TRANSACTION) {
+                // Indicate to the transaction manager that the coordinator is ready, allowing it to check ApiVersions
+                // This allows us to bump transactional epochs even if the coordinator is temporarily unavailable at
+                // the time when the abortable error is handled
+                transactionManager.handleCoordinatorReady();
+            }
             return node;
         }
         return null;
@@ -599,19 +598,7 @@ public class Sender implements Runnable {
                     batch.topicPartition,
                     this.retries - batch.attempts() - 1,
                     error);
-                if (transactionManager == null) {
-                    reenqueueBatch(batch, now);
-                } else if (transactionManager.hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
-                    // If idempotence is enabled only retry the request if the current producer id is the same as
-                    // the producer id of the batch.
-                    log.debug("Retrying batch to topic-partition {}. ProducerId: {}; Sequence number : {}",
-                            batch.topicPartition, batch.producerId(), batch.baseSequence());
-                    reenqueueBatch(batch, now);
-                } else {
-                    failBatch(batch, response, new OutOfOrderSequenceException("Attempted to retry sending a " +
-                            "batch but the producer id changed from " + batch.producerId() + " to " +
-                            transactionManager.producerIdAndEpoch().producerId + " in the mean time. This batch will be dropped."), false);
-                }
+                reenqueueBatch(batch, now);
             } else if (error == Errors.DUPLICATE_SEQUENCE_NUMBER) {
                 // If we have received a duplicate sequence error, it means that the sequence number has advanced beyond
                 // the sequence of the current batch, and we haven't retained batch metadata on the broker to return
@@ -700,8 +687,9 @@ public class Sender implements Runnable {
         return !batch.hasReachedDeliveryTimeout(accumulator.getDeliveryTimeoutMs(), now) &&
             batch.attempts() < this.retries &&
             !batch.isDone() &&
-            ((response.error.exception() instanceof RetriableException) ||
-                (transactionManager != null && transactionManager.canRetry(response, batch)));
+            (transactionManager == null ?
+                    response.error.exception() instanceof RetriableException :
+                    transactionManager.canRetry(response, batch));
     }
 
     /**
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index 18ab408..cd7dcbe 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -16,11 +16,19 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import org.apache.kafka.clients.ApiVersion;
+import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.RequestCompletionHandler;
 import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.errors.InvalidPidMappingException;
+import org.apache.kafka.common.errors.RetriableException;
+import org.apache.kafka.common.errors.UnknownProducerIdException;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.utils.ProducerIdAndEpoch;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
@@ -58,7 +66,6 @@ import org.apache.kafka.common.requests.TxnOffsetCommitRequest.CommittedOffset;
 import org.apache.kafka.common.requests.TxnOffsetCommitResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.PrimitiveRef;
-import org.apache.kafka.common.utils.ProducerIdAndEpoch;
 import org.slf4j.Logger;
 
 import java.util.ArrayList;
@@ -87,47 +94,58 @@ public class TransactionManager {
     private final Logger log;
     private final String transactionalId;
     private final int transactionTimeoutMs;
+    private final ApiVersions apiVersions;
 
     private static class TopicPartitionBookkeeper {
 
-        private final Map<TopicPartition, TopicPartitionEntry> topicPartitionBookkeeping = new HashMap<>();
+        private final Map<TopicPartition, TopicPartitionEntry> topicPartitions = new HashMap<>();
 
-        public TopicPartitionEntry getPartition(TopicPartition topic) {
-            TopicPartitionEntry ent = topicPartitionBookkeeping.get(topic);
+        private TopicPartitionEntry getPartition(TopicPartition topicPartition) {
+            TopicPartitionEntry ent = topicPartitions.get(topicPartition);
             if (ent == null)
-                throw new IllegalStateException("Trying to get the sequence number for " + topic +
+                throw new IllegalStateException("Trying to get the sequence number for " + topicPartition +
                         ", but the sequence number was never set for this partition.");
             return ent;
         }
 
-        public void addPartition(TopicPartition topic) {
-            if (!topicPartitionBookkeeping.containsKey(topic))
-                topicPartitionBookkeeping.put(topic, new TopicPartitionEntry());
+        private void addPartition(TopicPartition topicPartition) {
+            this.topicPartitions.putIfAbsent(topicPartition, new TopicPartitionEntry());
         }
 
-        boolean contains(TopicPartition partition) {
-            return topicPartitionBookkeeping.containsKey(partition);
+        private boolean contains(TopicPartition topicPartition) {
+            return topicPartitions.containsKey(topicPartition);
         }
 
-        public void reset() {
-            topicPartitionBookkeeping.clear();
+        private void reset() {
+            topicPartitions.clear();
         }
 
-        OptionalLong lastAckedOffset(TopicPartition partition) {
-            TopicPartitionEntry entry = topicPartitionBookkeeping.get(partition);
+        private OptionalLong lastAckedOffset(TopicPartition topicPartition) {
+            TopicPartitionEntry entry = topicPartitions.get(topicPartition);
             if (entry != null && entry.lastAckedOffset != ProduceResponse.INVALID_OFFSET)
                 return OptionalLong.of(entry.lastAckedOffset);
             else
                 return OptionalLong.empty();
         }
 
-        OptionalInt lastAckedSequence(TopicPartition partition) {
-            TopicPartitionEntry entry = topicPartitionBookkeeping.get(partition);
+        private OptionalInt lastAckedSequence(TopicPartition topicPartition) {
+            TopicPartitionEntry entry = topicPartitions.get(topicPartition);
             if (entry != null && entry.lastAckedSequence != NO_LAST_ACKED_SEQUENCE_NUMBER)
                 return OptionalInt.of(entry.lastAckedSequence);
             else
                 return OptionalInt.empty();
         }
+
+        private void startSequencesAtBeginning(TopicPartition topicPartition, ProducerIdAndEpoch newProducerIdAndEpoch) {
+            final PrimitiveRef.IntRef sequence = PrimitiveRef.ofInt(0);
+            TopicPartitionEntry topicPartitionEntry = getPartition(topicPartition);
+            topicPartitionEntry.resetSequenceNumbers(inFlightBatch -> {
+                inFlightBatch.resetProducerState(newProducerIdAndEpoch, sequence.value, inFlightBatch.isTransactional());
+                sequence.value += inFlightBatch.recordCount;
+            });
+            topicPartitionEntry.nextSequence = sequence.value;
+            topicPartitionEntry.lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER;
+        }
     }
 
     private static class TopicPartitionEntry {
@@ -164,7 +182,6 @@ public class TransactionManager {
             }
             inflightBatchesBySequence = newInflights;
         }
-
     }
 
     private final TopicPartitionBookkeeper topicPartitionBookkeeper;
@@ -177,7 +194,14 @@ public class TransactionManager {
     // successfully (indicating that the expired batch actually made it to the broker). If we don't get any successful
     // responses for the partition once the inflight request count falls to zero, we reset the producer id and
     // consequently clear this data structure as well.
-    private final Set<TopicPartition> partitionsWithUnresolvedSequences;
+    // The value of the map is the sequence number of the batch following the expired one, computed by adding its
+    // record count to its sequence number. This is used to tell if a subsequent batch is the one immediately following
+    // the expired one.
+    private final Map<TopicPartition, Integer> partitionsWithUnresolvedSequences;
+
+    // The partitions that have received an error that triggers an epoch bump. When the epoch is bumped, these
+    // partitions will have the sequences of their in-flight batches rewritten
+    private final Set<TopicPartition> partitionsToRewriteSequences;
 
     private final PriorityQueue<TxnRequestHandler> pendingRequests;
     private final Set<TopicPartition> newPartitionsInTransaction;
@@ -197,11 +221,13 @@ public class TransactionManager {
     private int inFlightRequestCorrelationId = NO_INFLIGHT_REQUEST_CORRELATION_ID;
     private Node transactionCoordinator;
     private Node consumerGroupCoordinator;
+    private boolean coordinatorSupportsBumpingEpoch;
 
     private volatile State currentState = State.UNINITIALIZED;
     private volatile RuntimeException lastError = null;
     private volatile ProducerIdAndEpoch producerIdAndEpoch;
     private volatile boolean transactionStarted = false;
+    private volatile boolean epochBumpRequired = false;
 
     private enum State {
         UNINITIALIZED,
@@ -218,7 +244,7 @@ public class TransactionManager {
                 case UNINITIALIZED:
                     return source == READY;
                 case INITIALIZING:
-                    return source == UNINITIALIZED;
+                    return source == UNINITIALIZED || source == ABORTING_TRANSACTION;
                 case READY:
                     return source == INITIALIZING || source == COMMITTING_TRANSACTION || source == ABORTING_TRANSACTION;
                 case IN_TRANSACTION:
@@ -241,12 +267,14 @@ public class TransactionManager {
 
     // We use the priority to determine the order in which requests need to be sent out. For instance, if we have
     // a pending FindCoordinator request, that must always go first. Next, If we need a producer id, that must go second.
-    // The endTxn request must always go last.
+    // The endTxn request must always go last, unless we are bumping the epoch (a special case of InitProducerId) as
+    // part of ending the transaction.
     private enum Priority {
         FIND_COORDINATOR(0),
         INIT_PRODUCER_ID(1),
         ADD_PARTITIONS_OR_OFFSETS(2),
-        END_TXN(3);
+        END_TXN(3),
+        EPOCH_BUMP(4);
 
         final int priority;
 
@@ -255,7 +283,11 @@ public class TransactionManager {
         }
     }
 
-    public TransactionManager(LogContext logContext, String transactionalId, int transactionTimeoutMs, long retryBackoffMs) {
+    public TransactionManager(LogContext logContext,
+                              String transactionalId,
+                              int transactionTimeoutMs,
+                              long retryBackoffMs,
+                              ApiVersions apiVersions) {
         this.producerIdAndEpoch = ProducerIdAndEpoch.NONE;
         this.transactionalId = transactionalId;
         this.log = logContext.logger(TransactionManager.class);
@@ -267,22 +299,34 @@ public class TransactionManager {
         this.partitionsInTransaction = new HashSet<>();
         this.pendingRequests = new PriorityQueue<>(10, Comparator.comparingInt(o -> o.priority().priority));
         this.pendingTxnOffsetCommits = new HashMap<>();
-        this.partitionsWithUnresolvedSequences = new HashSet<>();
+        this.partitionsWithUnresolvedSequences = new HashMap<>();
+        this.partitionsToRewriteSequences = new HashSet<>();
         this.retryBackoffMs = retryBackoffMs;
         this.topicPartitionBookkeeper = new TopicPartitionBookkeeper();
+        this.apiVersions = apiVersions;
     }
 
-    TransactionManager() {
-        this(new LogContext(), null, 0, 100L);
+    public synchronized TransactionalRequestResult initializeTransactions() {
+        return initializeTransactions(ProducerIdAndEpoch.NONE);
     }
 
-    public synchronized TransactionalRequestResult initializeTransactions() {
+    synchronized TransactionalRequestResult initializeTransactions(ProducerIdAndEpoch producerIdAndEpoch) {
+        boolean isEpochBump = producerIdAndEpoch != ProducerIdAndEpoch.NONE;
         return handleCachedTransactionRequestResult(() -> {
-            transitionTo(State.INITIALIZING);
+            // If this is an epoch bump, we will transition the state as part of handling the EndTxnRequest
+            if (!isEpochBump) {
+                transitionTo(State.INITIALIZING);
+                log.info("Invoking InitProducerId for the first time in order to acquire a producer ID");
+            } else {
+                log.info("Invoking InitProducerId with current producer ID and epoch {} in order to bump the epoch", producerIdAndEpoch);
+            }
             InitProducerIdRequestData requestData = new InitProducerIdRequestData()
                     .setTransactionalId(transactionalId)
-                    .setTransactionTimeoutMs(transactionTimeoutMs);
-            InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData));
+                    .setTransactionTimeoutMs(transactionTimeoutMs)
+                    .setProducerId(producerIdAndEpoch.producerId)
+                    .setProducerEpoch(producerIdAndEpoch.epoch);
+            InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData),
+                    isEpochBump);
             enqueueRequest(handler);
             return handler.result;
         }, State.INITIALIZING);
@@ -317,15 +361,26 @@ public class TransactionManager {
     private TransactionalRequestResult beginCompletingTransaction(TransactionResult transactionResult) {
         if (!newPartitionsInTransaction.isEmpty())
             enqueueRequest(addPartitionsToTransactionHandler());
-        EndTxnRequest.Builder builder = new EndTxnRequest.Builder(
-            new EndTxnRequestData()
-                .setTransactionalId(transactionalId)
-                .setProducerId(producerIdAndEpoch.producerId)
-                .setProducerEpoch(producerIdAndEpoch.epoch)
-                .setCommitted(transactionResult.id));
-        EndTxnHandler handler = new EndTxnHandler(builder);
-        enqueueRequest(handler);
-        return handler.result;
+
+        // If the error is an INVALID_PRODUCER_ID_MAPPING error, the server will not accept an EndTxnRequest, so skip
+        // directly to InitProducerId. Otherwise, we must first abort the transaction, because the producer will be
+        // fenced if we directly call InitProducerId.
+        if (!(lastError instanceof InvalidPidMappingException)) {
+            EndTxnRequest.Builder builder = new EndTxnRequest.Builder(
+                    new EndTxnRequestData()
+                            .setTransactionalId(transactionalId)
+                            .setProducerId(producerIdAndEpoch.producerId)
+                            .setProducerEpoch(producerIdAndEpoch.epoch)
+                            .setCommitted(transactionResult.id));
+
+            EndTxnHandler handler = new EndTxnHandler(builder);
+            enqueueRequest(handler);
+            if (!shouldBumpEpoch()) {
+                return handler.result;
+            }
+        }
+
+        return initializeTransactions(this.producerIdAndEpoch);
     }
 
     public synchronized TransactionalRequestResult sendOffsetsToTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets,
@@ -338,7 +393,7 @@ public class TransactionManager {
 
         log.debug("Begin adding offsets {} for consumer group {} to transaction", offsets, groupMetadata);
         AddOffsetsToTxnRequest.Builder builder = new AddOffsetsToTxnRequest.Builder(transactionalId,
-            producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, groupMetadata.groupId());
+                producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, groupMetadata.groupId());
         AddOffsetsToTxnHandler handler = new AddOffsetsToTxnHandler(builder, offsets, groupMetadata);
         enqueueRequest(handler);
         return handler.result;
@@ -412,6 +467,7 @@ public class TransactionManager {
                     "aborted. Underlying exception: ", exception);
             return;
         }
+
         transitionTo(State.ABORTABLE_ERROR, exception);
     }
 
@@ -447,9 +503,9 @@ public class TransactionManager {
         return producerIdAndEpoch.producerId == producerId;
     }
 
-    boolean hasProducerIdAndEpoch(long producerId, short producerEpoch) {
+    boolean matchesProducerIdAndEpoch(ProducerBatch batch) {
         ProducerIdAndEpoch idAndEpoch = this.producerIdAndEpoch;
-        return idAndEpoch.producerId == producerId && idAndEpoch.epoch == producerEpoch;
+        return idAndEpoch.producerId == batch.producerId() && idAndEpoch.epoch == batch.producerEpoch();
     }
 
     /**
@@ -461,45 +517,67 @@ public class TransactionManager {
     }
 
     /**
-     * This method is used when the producer needs to reset its internal state because of an irrecoverable exception
-     * from the broker.
-     *
-     * We need to reset the producer id and associated state when we have sent a batch to the broker, but we either get
-     * a non-retriable exception or we run out of retries, or the batch expired in the producer queue after it was already
-     * sent to the broker.
-     *
-     * In all of these cases, we don't know whether batch was actually committed on the broker, and hence whether the
-     * sequence number was actually updated. If we don't reset the producer state, we risk the chance that all future
-     * messages will return an OutOfOrderSequenceException.
-     *
-     * Note that we can't reset the producer state for the transactional producer as this would mean bumping the epoch
-     * for the same producer id. This might involve aborting the ongoing transaction during the initPidRequest, and the user
-     * would not have any way of knowing this happened. So for the transactional producer, it's best to return the
-     * produce error to the user and let them abort the transaction and close the producer explicitly.
+     * This method resets the producer ID and epoch and sets the state to UNINITIALIZED, which will trigger a new
+     * InitProducerId request. This method is only called when the producer epoch is exhausted; otherwise, we bump
+     * the epoch instead.
      */
-    synchronized void resetIdempotentProducerId() {
+    private void resetIdempotentProducerId() {
         if (isTransactional())
             throw new IllegalStateException("Cannot reset producer state for a transactional producer. " +
                     "You must either abort the ongoing transaction or reinitialize the transactional producer instead");
+        log.debug("Resetting idempotent producer ID. ID and epoch before reset are {}", this.producerIdAndEpoch);
         setProducerIdAndEpoch(ProducerIdAndEpoch.NONE);
+        transitionTo(State.UNINITIALIZED);
+    }
+
+    private void resetSequenceForPartition(TopicPartition topicPartition) {
+        topicPartitionBookkeeper.topicPartitions.remove(topicPartition);
+        this.partitionsWithUnresolvedSequences.remove(topicPartition);
+    }
+
+    private void resetSequenceNumbers() {
         topicPartitionBookkeeper.reset();
         this.partitionsWithUnresolvedSequences.clear();
-        transitionTo(State.UNINITIALIZED);
     }
 
-    synchronized void resetIdempotentProducerIdIfNeeded() {
+    synchronized void requestEpochBumpForPartition(TopicPartition tp) {
+        epochBumpRequired = true;
+        this.partitionsToRewriteSequences.add(tp);
+    }
+
+    private boolean shouldBumpEpoch() {
+        return epochBumpRequired;
+    }
+
+    private void bumpIdempotentProducerEpoch() {
+        if (this.producerIdAndEpoch.epoch == Short.MAX_VALUE) {
+            resetIdempotentProducerId();
+        } else {
+            setProducerIdAndEpoch(new ProducerIdAndEpoch(this.producerIdAndEpoch.producerId, (short) (this.producerIdAndEpoch.epoch + 1)));
+            log.debug("Incremented producer epoch, current producer ID and epoch are now {}", this.producerIdAndEpoch);
+        }
+
+        // When the epoch is bumped, rewrite all in-flight sequences for the partition(s) that triggered the epoch bump
+        for (TopicPartition topicPartition : this.partitionsToRewriteSequences) {
+            this.topicPartitionBookkeeper.startSequencesAtBeginning(topicPartition, this.producerIdAndEpoch);
+            this.partitionsWithUnresolvedSequences.remove(topicPartition);
+        }
+
+        this.partitionsToRewriteSequences.clear();
+        epochBumpRequired = false;
+    }
+
+    synchronized void bumpIdempotentEpochAndResetIdIfNeeded() {
         if (!isTransactional()) {
-            if (shouldResetProducerStateAfterResolvingSequences()) {
-                // Check if the previous run expired batches which requires a reset of the producer state.
-                resetIdempotentProducerId();
+            if (shouldBumpEpoch()) {
+                bumpIdempotentProducerEpoch();
             }
-
             if (currentState != State.INITIALIZING && !hasProducerId()) {
                 transitionTo(State.INITIALIZING);
                 InitProducerIdRequestData requestData = new InitProducerIdRequestData()
                         .setTransactionalId(null)
                         .setTransactionTimeoutMs(Integer.MAX_VALUE);
-                InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData));
+                InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData), false);
                 enqueueRequest(handler);
             }
         }
@@ -557,9 +635,14 @@ public class TransactionManager {
         }
     }
 
-    private void maybeUpdateLastAckedSequence(TopicPartition topicPartition, int sequence) {
-        if (sequence > lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER))
+    private int maybeUpdateLastAckedSequence(TopicPartition topicPartition, int sequence) {
+        int lastAckedSequence = lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER);
+        if (sequence > lastAckedSequence) {
             topicPartitionBookkeeper.getPartition(topicPartition).lastAckedSequence = sequence;
+            return sequence;
+        }
+
+        return lastAckedSequence;
     }
 
     synchronized OptionalInt lastAckedSequence(TopicPartition topicPartition) {
@@ -589,21 +672,20 @@ public class TransactionManager {
     }
 
     public synchronized void handleCompletedBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) {
-        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
-            log.debug("Ignoring completed batch {} with producer id {}, epoch {}, and sequence number {} " +
-                            "since the producerId has been reset internally", batch, batch.producerId(),
-                    batch.producerEpoch(), batch.baseSequence());
-            return;
-        }
-
-        maybeUpdateLastAckedSequence(batch.topicPartition, batch.baseSequence() + batch.recordCount - 1);
+        int lastAckedSequence = maybeUpdateLastAckedSequence(batch.topicPartition, batch.lastSequence());
         log.debug("ProducerId: {}; Set last ack'd sequence number for topic-partition {} to {}",
                 batch.producerId(),
                 batch.topicPartition,
-                lastAckedSequence(batch.topicPartition).orElse(-1));
+                lastAckedSequence);
 
         updateLastAckedOffset(response, batch);
         removeInFlightBatch(batch);
+
+        if (!matchesProducerIdAndEpoch(batch) && !hasInflightBatches(batch.topicPartition)) {
+            // If the batch was on a different ID and/or epoch (due to an epoch bump) and all its in-flight batches
+            // have completed, reset the partition sequence so that the next batch (with the new epoch) starts from 0
+            topicPartitionBookkeeper.startSequencesAtBeginning(batch.topicPartition, this.producerIdAndEpoch);
+        }
     }
 
     private void maybeTransitionToErrorState(RuntimeException exception) {
@@ -613,16 +695,20 @@ public class TransactionManager {
                 || exception instanceof UnsupportedVersionException) {
             transitionToFatalError(exception);
         } else if (isTransactional()) {
+            if (canBumpEpoch() && !isCompleting()) {
+                epochBumpRequired = true;
+            }
             transitionToAbortableError(exception);
         }
     }
 
-    public synchronized void handleFailedBatch(ProducerBatch batch, RuntimeException exception, boolean adjustSequenceNumbers) {
+    synchronized void handleFailedBatch(ProducerBatch batch, RuntimeException exception, boolean adjustSequenceNumbers) {
         maybeTransitionToErrorState(exception);
+        removeInFlightBatch(batch);
 
-        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
+        if (!matchesProducerIdAndEpoch(batch)) {
             log.debug("Ignoring failed batch {} with producer id {}, epoch {}, and sequence number {} " +
-                    "since the producerId has been reset internally", batch, batch.producerId(),
+                            "since the producerId has been reset internally", batch, batch.producerId(),
                     batch.producerEpoch(), batch.baseSequence(), exception);
             return;
         }
@@ -631,14 +717,24 @@ public class TransactionManager {
             log.error("The broker returned {} for topic-partition {} with producerId {}, epoch {}, and sequence number {}",
                     exception, batch.topicPartition, batch.producerId(), batch.producerEpoch(), batch.baseSequence());
 
-            // Reset the producerId since we have hit an irrecoverable exception and cannot make any guarantees
-            // about the previously committed message. Note that this will discard the producer id and sequence
-            // numbers for all existing partitions.
-            resetIdempotentProducerId();
+            // If we fail with an OutOfOrderSequenceException, we have a gap in the log. Bump the epoch for this
+            // partition, which will reset the sequence number to 0 and allow us to continue
+            requestEpochBumpForPartition(batch.topicPartition);
+        } else if (exception instanceof UnknownProducerIdException) {
+            // If we get an UnknownProducerId for a partition, then the broker has no state for that producer. It will
+            // therefore accept a write with sequence number 0. We reset the sequence number for the partition here so
+            // that the producer can continue after aborting the transaction. All inflight-requests to this partition
+            // will also fail with an UnknownProducerId error, so the sequence will remain at 0. Note that if the
+            // broker supports bumping the epoch, we will later reset all sequence numbers after calling InitProducerId
+            resetSequenceForPartition(batch.topicPartition);
         } else {
-            removeInFlightBatch(batch);
-            if (adjustSequenceNumbers)
-                adjustSequencesDueToFailedBatch(batch);
+            if (adjustSequenceNumbers) {
+                if (!isTransactional()) {
+                    requestEpochBumpForPartition(batch.topicPartition);
+                } else {
+                    adjustSequencesDueToFailedBatch(batch);
+                }
+            }
         }
     }
 
@@ -672,24 +768,10 @@ public class TransactionManager {
 
             log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}", inFlightBatch.baseSequence(), batch.topicPartition, newSequence);
             inFlightBatch.resetProducerState(new ProducerIdAndEpoch(inFlightBatch.producerId(), inFlightBatch.producerEpoch()), newSequence, inFlightBatch.isTransactional());
-
         });
     }
 
-    private void startSequencesAtBeginning(TopicPartition topicPartition) {
-        final PrimitiveRef.IntRef sequence = PrimitiveRef.ofInt(0);
-        topicPartitionBookkeeper.getPartition(topicPartition).resetSequenceNumbers(inFlightBatch -> {
-            log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}",
-                    inFlightBatch.baseSequence(), inFlightBatch.topicPartition, sequence.value);
-            inFlightBatch.resetProducerState(new ProducerIdAndEpoch(inFlightBatch.producerId(),
-                    inFlightBatch.producerEpoch()), sequence.value, inFlightBatch.isTransactional());
-            sequence.value += inFlightBatch.recordCount;
-        });
-        setNextSequence(topicPartition, sequence.value);
-        topicPartitionBookkeeper.getPartition(topicPartition).lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER;
-    }
-
-    private boolean hasInflightBatches(TopicPartition topicPartition) {
+    synchronized boolean hasInflightBatches(TopicPartition topicPartition) {
         return topicPartitionBookkeeper.contains(topicPartition)
                 && !topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence.isEmpty();
     }
@@ -699,21 +781,24 @@ public class TransactionManager {
     }
 
     synchronized boolean hasUnresolvedSequence(TopicPartition topicPartition) {
-        return partitionsWithUnresolvedSequences.contains(topicPartition);
+        return partitionsWithUnresolvedSequences.containsKey(topicPartition);
     }
 
-    synchronized void markSequenceUnresolved(TopicPartition topicPartition) {
-        log.debug("Marking partition {} unresolved", topicPartition);
-        partitionsWithUnresolvedSequences.add(topicPartition);
+    synchronized void markSequenceUnresolved(ProducerBatch batch) {
+        int nextSequence = batch.lastSequence() + 1;
+        partitionsWithUnresolvedSequences.compute(batch.topicPartition,
+            (k, v) -> v == null ? nextSequence : Math.max(v, nextSequence));
+        log.debug("Marking partition {} unresolved with next sequence number {}", batch.topicPartition,
+                partitionsWithUnresolvedSequences.get(batch.topicPartition));
     }
 
-    // Checks if there are any partitions with unresolved partitions which may now be resolved. Returns true if
-    // the producer id needs a reset, false otherwise.
-    private boolean shouldResetProducerStateAfterResolvingSequences() {
-        for (Iterator<TopicPartition> iter = partitionsWithUnresolvedSequences.iterator(); iter.hasNext(); ) {
+    // Attempts to resolve unresolved sequences. If all in-flight requests are complete and some partitions are still
+    // unresolved, either bump the epoch if possible, or transition to a fatal error
+    synchronized void maybeResolveSequences() {
+        for (Iterator<TopicPartition> iter = partitionsWithUnresolvedSequences.keySet().iterator(); iter.hasNext(); ) {
             TopicPartition topicPartition = iter.next();
             if (!hasInflightBatches(topicPartition)) {
-                // The partition has been fully drained. At this point, the last ack'd sequence should be once less than
+                // The partition has been fully drained. At this point, the last ack'd sequence should be one less than
                 // next sequence destined for the partition. If so, the partition is fully resolved. If not, we should
                 // reset the sequence number if necessary.
                 if (isNextSequence(topicPartition, sequenceNumber(topicPartition))) {
@@ -721,14 +806,31 @@ public class TransactionManager {
                     iter.remove();
                 } else {
                     // We would enter this branch if all in flight batches were ultimately expired in the producer.
-                    log.info("No inflight batches remaining for {}, last ack'd sequence for partition is {}, next sequence is {}. " +
-                            "Going to reset producer state.", topicPartition,
-                            lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER), sequenceNumber(topicPartition));
-                    return true;
+                    if (isTransactional()) {
+                        // For the transactional producer, we bump the epoch if possible, otherwise we transition to a fatal error
+                        String unackedMessagesErr = "The client hasn't received acknowledgment for some previously " +
+                                "sent messages and can no longer retry them. ";
+                        if (canBumpEpoch()) {
+                            epochBumpRequired = true;
+                            KafkaException exception = new KafkaException(unackedMessagesErr + "It is safe to abort " +
+                                    "the transaction and continue.");
+                            transitionToAbortableError(exception);
+                        } else {
+                            KafkaException exception = new KafkaException(unackedMessagesErr + "It isn't safe to continue.");
+                            transitionToFatalError(exception);
+                        }
+                    } else {
+                        // For the idempotent producer, bump the epoch
+                        log.info("No inflight batches remaining for {}, last ack'd sequence for partition is {}, next sequence is {}. " +
+                                        "Going to bump epoch and reset sequence numbers.", topicPartition,
+                                lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER), sequenceNumber(topicPartition));
+                        requestEpochBumpForPartition(topicPartition);
+                    }
+
+                    iter.remove();
                 }
             }
         }
-        return false;
     }
 
     private boolean isNextSequence(TopicPartition topicPartition, int sequence) {
@@ -739,6 +841,11 @@ public class TransactionManager {
         topicPartitionBookkeeper.getPartition(topicPartition).nextSequence = sequence;
     }
 
+    private boolean isNextSequenceForUnresolvedPartition(TopicPartition topicPartition, int sequence) {
+        return this.hasUnresolvedSequence(topicPartition) &&
+                sequence == this.partitionsWithUnresolvedSequences.get(topicPartition);
+    }
+
     synchronized TxnRequestHandler nextRequest(boolean hasIncompleteBatches) {
         if (!newPartitionsInTransaction.isEmpty())
             enqueueRequest(addPartitionsToTransactionHandler());
@@ -851,23 +958,16 @@ public class TransactionManager {
     }
 
     synchronized boolean canRetry(ProduceResponse.PartitionResponse response, ProducerBatch batch) {
-        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch()))
-            return false;
-
         Errors error = response.error;
-        if (error == Errors.OUT_OF_ORDER_SEQUENCE_NUMBER && !hasUnresolvedSequence(batch.topicPartition) &&
-                (batch.sequenceHasBeenReset() || !isNextSequence(batch.topicPartition, batch.baseSequence())))
-            // We should retry the OutOfOrderSequenceException if the batch is _not_ the next batch, ie. its base
-            // sequence isn't the lastAckedSequence + 1. However, if the first in flight batch fails fatally, we will
-            // adjust the sequences of the other inflight batches to account for the 'loss' of the sequence range in
-            // the batch which failed. In this case, an inflight batch will have a base sequence which is
-            // the lastAckedSequence + 1 after adjustment. When this batch fails with an OutOfOrderSequence, we want to retry it.
-            // To account for the latter case, we check whether the sequence has been reset since the last drain.
-            // If it has, we will retry it anyway.
-            return true;
 
+        // An UNKNOWN_PRODUCER_ID means that we have lost the producer state on the broker. Depending on the log start
+        // offset, we may want to retry these, as described for each case below. If none of those apply, then for the
+        // idempotent producer, we will locally bump the epoch and reset the sequence numbers of in-flight batches from
+        // sequence 0, then retry the failed batch, which should now succeed. For the transactional producer, allow the
+        // batch to fail. When processing the failed batch, we will transition to an abortable error and set a flag
+        // indicating that we need to bump the epoch (if supported by the broker).
         if (error == Errors.UNKNOWN_PRODUCER_ID) {
-            if (response.logStartOffset == -1)
+            if (response.logStartOffset == -1) {
                 // We don't know the log start offset with this response. We should just retry the request until we get it.
                 // The UNKNOWN_PRODUCER_ID error code was added along with the new ProduceResponse which includes the
                 // logStartOffset. So the '-1' sentinel is not for backward compatibility. Instead, it is possible for
@@ -876,6 +976,7 @@ public class TransactionManager {
                 // response was being constructed. In these cases, we should just retry the request: we are guaranteed
                 // to eventually get a logStartOffset once things settle down.
                 return true;
+            }
 
             if (batch.sequenceHasBeenReset()) {
                 // When the first inflight batch fails due to the truncation case, then the sequences of all the other
@@ -885,13 +986,44 @@ public class TransactionManager {
                 return true;
             } else if (lastAckedOffset(batch.topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER) < response.logStartOffset) {
                 // The head of the log has been removed, probably due to the retention time elapsing. In this case,
-                // we expect to lose the producer state. Reset the sequences of all inflight batches to be from the beginning
-                // and retry them.
-                startSequencesAtBeginning(batch.topicPartition);
+                // we expect to lose the producer state. For the transactional producer, reset the sequences of all
+                // inflight batches to be from the beginning and retry them, so that the transaction does not need to
+                // be aborted. For the idempotent producer, bump the epoch to avoid reusing (sequence, epoch) pairs
+                if (isTransactional()) {
+                    topicPartitionBookkeeper.startSequencesAtBeginning(batch.topicPartition, this.producerIdAndEpoch);
+                } else {
+                    requestEpochBumpForPartition(batch.topicPartition);
+                }
+                return true;
+            }
+
+            if (!isTransactional()) {
+                // For the idempotent producer, always retry UNKNOWN_PRODUCER_ID errors. If the batch has the current
+                // producer ID and epoch, request a bump of the epoch. Otherwise just retry.
+                requestEpochBumpForPartition(batch.topicPartition);
+                return true;
+            }
+        } else if (error == Errors.OUT_OF_ORDER_SEQUENCE_NUMBER) {
+            if (!hasUnresolvedSequence(batch.topicPartition) &&
+                    (batch.sequenceHasBeenReset() || !isNextSequence(batch.topicPartition, batch.baseSequence()))) {
+                // We should retry the OutOfOrderSequenceException if the batch is _not_ the next batch, i.e. its base
+                // sequence isn't the lastAckedSequence + 1.
+                return true;
+            } else if (!isTransactional()) {
+                // For the idempotent producer, retry all OUT_OF_ORDER_SEQUENCE_NUMBER errors. If there are no
+                // unresolved sequences, or this batch is the one immediately following an unresolved sequence, we know
+                // there is actually a gap in the sequences, and we bump the epoch. Otherwise, retry without bumping
+                // and wait to see if the sequence resolves
+                if (!hasUnresolvedSequence(batch.topicPartition) ||
+                        isNextSequenceForUnresolvedPartition(batch.topicPartition, batch.baseSequence())) {
+                    requestEpochBumpForPartition(batch.topicPartition);
+                }
                 return true;
             }
         }
-        return false;
+
+        // If neither of the above cases is true, retry if the exception is retriable
+        return error.exception() instanceof RetriableException;
     }
 
     // visible for testing
@@ -899,6 +1031,17 @@ public class TransactionManager {
         return isTransactional() && currentState == State.READY;
     }
 
+    void handleCoordinatorReady() {
+        NodeApiVersions nodeApiVersions = transactionCoordinator != null ?
+                apiVersions.get(transactionCoordinator.idString()) :
+                null;
+        ApiVersion initProducerIdVersion = nodeApiVersions != null ?
+                nodeApiVersions.apiVersion(ApiKeys.INIT_PRODUCER_ID) :
+                null;
+        this.coordinatorSupportsBumpingEpoch = initProducerIdVersion != null &&
+                initProducerIdVersion.maxVersion >= 3;
+    }
+
     private void transitionTo(State target) {
         transitionTo(target, null);
     }
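
handleCoordinatorReady() gates epoch bumping on the coordinator advertising InitProducerId v3 or later, the first version able to bump an existing epoch. A standalone sketch of the same check (the node id "0" is an illustrative assumption):

    import org.apache.kafka.clients.ApiVersion;
    import org.apache.kafka.clients.ApiVersions;
    import org.apache.kafka.clients.NodeApiVersions;
    import org.apache.kafka.common.protocol.ApiKeys;

    class EpochBumpSupportCheck {
        static boolean supportsEpochBump(ApiVersions apiVersions, String coordinatorId) {
            NodeApiVersions nodeVersions = apiVersions.get(coordinatorId);
            ApiVersion initProducerId = nodeVersions == null ?
                    null : nodeVersions.apiVersion(ApiKeys.INIT_PRODUCER_ID);
            // InitProducerId v3+ means the coordinator can bump an existing epoch.
            return initProducerId != null && initProducerId.maxVersion >= 3;
        }

        public static void main(String[] args) {
            ApiVersions apiVersions = new ApiVersions();
            apiVersions.update("0", NodeApiVersions.create(
                    ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
            System.out.println(supportsEpochBump(apiVersions, "0")); // true
        }
    }
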
@@ -980,14 +1123,7 @@ public class TransactionManager {
         enqueueRequest(new FindCoordinatorHandler(builder));
     }
 
-    private void completeTransaction() {
-        transitionTo(State.READY);
-        lastError = null;
-        transactionStarted = false;
-        newPartitionsInTransaction.clear();
-        pendingPartitionsInTransaction.clear();
-        partitionsInTransaction.clear();
-    }
+
 
     private TxnRequestHandler addPartitionsToTransactionHandler() {
         pendingPartitionsInTransaction.addAll(newPartitionsInTransaction);
@@ -1035,6 +1171,29 @@ public class TransactionManager {
         return pendingResult;
     }
 
+    // package-private for testing
+    boolean canBumpEpoch() {
+        if (!isTransactional()) {
+            return true;
+        }
+
+        return coordinatorSupportsBumpingEpoch;
+    }
+
+    private void completeTransaction() {
+        if (epochBumpRequired) {
+            transitionTo(State.INITIALIZING);
+        } else {
+            transitionTo(State.READY);
+        }
+        lastError = null;
+        epochBumpRequired = false;
+        transactionStarted = false;
+        newPartitionsInTransaction.clear();
+        pendingPartitionsInTransaction.clear();
+        partitionsInTransaction.clear();
+    }
+
     abstract class TxnRequestHandler implements RequestCompletionHandler {
         protected final TransactionalRequestResult result;
         private boolean isRetry = false;
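
Because completeTransaction() now transitions to INITIALIZING when a bump is pending, an aborted transaction no longer forces the application to recreate the producer. A usage sketch of the pattern this enables (broker address, topic, transactional id, and serializers are illustrative assumptions; fatal errors such as ProducerFencedException still require closing the producer):

    import java.util.Properties;
    import org.apache.kafka.clients.producer.KafkaProducer;
    import org.apache.kafka.clients.producer.ProducerRecord;
    import org.apache.kafka.common.KafkaException;

    public class AbortAndContinueExample {
        public static void main(String[] args) {
            Properties props = new Properties();
            props.put("bootstrap.servers", "localhost:9092"); // assumed broker
            props.put("transactional.id", "example-txn");     // assumed id
            props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
            props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
            try (KafkaProducer<String, String> producer = new KafkaProducer<>(props)) {
                producer.initTransactions();
                producer.beginTransaction();
                try {
                    producer.send(new ProducerRecord<>("events", "key", "value"));
                    producer.commitTransaction();
                } catch (KafkaException e) {
                    // With this change, abortTransaction() sends an InitProducerId
                    // request when an epoch bump is required, so the application
                    // can begin a new transaction on the same producer instance
                    // instead of closing it.
                    producer.abortTransaction();
                }
            }
        }
    }
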
@@ -1057,6 +1216,15 @@ public class TransactionManager {
             transitionToAbortableError(e);
         }
 
+        void abortableErrorIfPossible(RuntimeException e) {
+            if (canBumpEpoch()) {
+                epochBumpRequired = true;
+                abortableError(e);
+            } else {
+                fatalError(e);
+            }
+        }
+
         void fail(RuntimeException e) {
             result.fail(e);
         }
@@ -1130,10 +1298,12 @@ public class TransactionManager {
 
     private class InitProducerIdHandler extends TxnRequestHandler {
         private final InitProducerIdRequest.Builder builder;
+        private final boolean isEpochBump;
 
-        private InitProducerIdHandler(InitProducerIdRequest.Builder builder) {
+        private InitProducerIdHandler(InitProducerIdRequest.Builder builder, boolean isEpochBump) {
             super("InitProducerId");
             this.builder = builder;
+            this.isEpochBump = isEpochBump;
         }
 
         @Override
@@ -1143,7 +1313,7 @@ public class TransactionManager {
 
         @Override
         Priority priority() {
-            return Priority.INIT_PRODUCER_ID;
+            return this.isEpochBump ? Priority.EPOCH_BUMP : Priority.INIT_PRODUCER_ID;
         }
 
         @Override
@@ -1166,6 +1336,9 @@ public class TransactionManager {
                 setProducerIdAndEpoch(producerIdAndEpoch);
                 transitionTo(State.READY);
                 lastError = null;
+                if (this.isEpochBump) {
+                    resetSequenceNumbers();
+                }
                 result.done();
             } else if (error == Errors.NOT_COORDINATOR || error == Errors.COORDINATOR_NOT_AVAILABLE) {
                 lookupCoordinator(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId);
@@ -1232,8 +1405,7 @@ public class TransactionManager {
                 } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) {
                     fatalError(error.exception());
                     return;
-                } else if (error == Errors.INVALID_PRODUCER_ID_MAPPING
-                        || error == Errors.INVALID_TXN_STATE) {
+                } else if (error == Errors.INVALID_TXN_STATE) {
                     fatalError(new KafkaException(error.exception()));
                     return;
                 } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) {
@@ -1242,6 +1414,9 @@ public class TransactionManager {
                     log.debug("Did not attempt to add partition {} to transaction because other partitions in the " +
                             "batch had errors.", topicPartition);
                     hasPartitionErrors = true;
+                } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) {
+                    abortableErrorIfPossible(error.exception());
+                    return;
                 } else {
                     log.error("Could not add partition {} due to unexpected error {}", topicPartition, error);
                     hasPartitionErrors = true;
@@ -1328,6 +1503,7 @@ public class TransactionManager {
                         break;
                     case TRANSACTION:
                         transactionCoordinator = node;
+
                 }
                 result.done();
                 log.info("Discovered {} coordinator {}", coordinatorType.toString().toLowerCase(Locale.ROOT), node);
@@ -1387,6 +1563,8 @@ public class TransactionManager {
                 fatalError(error.exception());
             } else if (error == Errors.INVALID_TXN_STATE) {
                 fatalError(error.exception());
+            } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) {
+                abortableErrorIfPossible(error.exception());
             } else {
                 fatalError(new KafkaException("Unhandled error in EndTxnResponse: " + error.message()));
             }
@@ -1433,6 +1611,8 @@ public class TransactionManager {
                 reenqueue();
             } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.CONCURRENT_TRANSACTIONS) {
                 reenqueue();
+            } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) {
+                abortableErrorIfPossible(error.exception());
             } else if (error == Errors.INVALID_PRODUCER_EPOCH) {
                 fatalError(error.exception());
             } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) {
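
On the Sender side, a requested bump is applied between poll iterations rather than inline with the failing response. A condensed sketch of that interaction, using only methods referenced in this patch (control flow heavily simplified, not the literal Sender code):

    // Sketch: each Sender loop gives the idempotent producer a chance to apply
    // a pending epoch bump before draining the accumulator.
    void runOnceSketch(TransactionManager transactionManager) {
        if (!transactionManager.isTransactional()) {
            // Applies any bump requested via requestEpochBumpForPartition();
            // if the epoch is already at Short.MAX_VALUE, this falls back to
            // requesting a whole new producer id instead.
            transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
        }
        // ... drain the accumulator and send ProduceRequests ...
    }
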
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index 08b29b0..a7a9f09 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -705,8 +705,9 @@ public class RecordAccumulatorTest {
         String metricGrpName = "producer-metrics";
 
         apiVersions.update("foobar", NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2));
+        TransactionManager transactionManager = new TransactionManager(new LogContext(), null, 0, 100L, new ApiVersions());
         RecordAccumulator accum = new RecordAccumulator(logContext, batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD,
-            CompressionType.NONE, lingerMs, retryBackoffMs, deliveryTimeoutMs, metrics, metricGrpName, time, apiVersions, new TransactionManager(),
+            CompressionType.NONE, lingerMs, retryBackoffMs, deliveryTimeoutMs, metrics, metricGrpName, time, apiVersions, transactionManager,
             new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
         accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds());
     }
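
With the no-arg constructor gone, tests construct the manager explicitly in both modes. For reference, the two configurations used throughout the updated tests (a fragment; "my-txn-id" is an illustrative id, and TransactionManager is package-private, so this only compiles inside org.apache.kafka.clients.producer.internals):

    import org.apache.kafka.clients.ApiVersions;
    import org.apache.kafka.common.utils.LogContext;

    // Idempotent (non-transactional) manager: null transactional id.
    TransactionManager idempotent =
            new TransactionManager(new LogContext(), null, 0, 100L, new ApiVersions());

    // Transactional manager: id, transaction timeout ms, retry backoff ms, api versions.
    TransactionManager transactional =
            new TransactionManager(new LogContext(), "my-txn-id", 60000, 100, new ApiVersions());
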
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index 86b4291..5f127e8 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -33,7 +33,6 @@ import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.ClusterAuthorizationException;
 import org.apache.kafka.common.errors.NetworkException;
-import org.apache.kafka.common.errors.OutOfOrderSequenceException;
 import org.apache.kafka.common.errors.RecordTooLargeException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
@@ -527,9 +526,9 @@ public class SenderTest {
     }
 
     @Test
-    public void testInitProducerIdRequest() throws Exception {
+    public void testInitProducerIdRequest() {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -540,7 +539,7 @@ public class SenderTest {
     @Test
     public void testClusterAuthorizationExceptionInInitProducerIdRequest() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.CLUSTER_AUTHORIZATION_FAILED);
         assertFalse(transactionManager.hasProducerId());
@@ -585,7 +584,7 @@ public class SenderTest {
     @Test
     public void testIdempotenceWithMultipleInflights() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -635,7 +634,7 @@ public class SenderTest {
     public void testIdempotenceWithMultipleInflightsRetriedInOrder() throws Exception {
         // Send multiple in flight requests, retry them all one at a time, in the correct order.
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -735,7 +734,7 @@ public class SenderTest {
     @Test
     public void testIdempotenceWithMultipleInflightsWhereFirstFailsFatallyAndSequenceOfFutureBatchesIsAdjusted() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -792,9 +791,9 @@ public class SenderTest {
     }
 
     @Test
-    public void testMustNotRetryOutOfOrderSequenceForNextBatch() throws Exception {
+    public void testEpochBumpOnOutOfOrderSequenceForNextBatch() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -827,17 +826,22 @@ public class SenderTest {
         assertFalse(request2.isDone());
         assertTrue(client.isReady(node, time.milliseconds()));
 
-        // This OutOfOrderSequence is fatal since it is returned for the batch succeeding the last acknowledged batch.
+        // This OutOfOrderSequence triggers an epoch bump since it is returned for the batch succeeding the last acknowledged batch.
         sendIdempotentProducerResponse(2, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1L);
 
         sender.runOnce();
-        assertFutureFailure(request2, OutOfOrderSequenceException.class);
+        sender.runOnce();
+
+        // epoch should be bumped and sequence numbers reset
+        assertEquals(1, transactionManager.producerIdAndEpoch().epoch);
+        assertEquals(1, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(0, transactionManager.firstInFlightSequence(tp0));
     }
 
     @Test
     public void testCorrectHandlingOfOutOfOrderResponses() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -918,7 +922,7 @@ public class SenderTest {
     @Test
     public void testCorrectHandlingOfOutOfOrderResponsesWhenSecondSucceeds() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -986,7 +990,7 @@ public class SenderTest {
     @Test
     public void testExpiryOfUnsentBatchesShouldNotCauseUnresolvedSequences() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1009,7 +1013,7 @@ public class SenderTest {
     @Test
     public void testExpiryOfFirstBatchShouldNotCauseUnresolvedSequencesIfFutureBatchesSucceed() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager, false, null);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1074,9 +1078,9 @@ public class SenderTest {
     }
 
     @Test
-    public void testExpiryOfFirstBatchShouldCauseResetIfFutureBatchesFail() throws Exception {
+    public void testExpiryOfFirstBatchShouldCauseEpochBumpIfFutureBatchesFail() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1111,24 +1115,64 @@ public class SenderTest {
         sender.runOnce();  // send second request
         sendIdempotentProducerResponse(1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, 1);
         sender.runOnce(); // receive second response, the third request shouldn't be sent since we are in an unresolved state.
-        assertFutureFailure(request2, OutOfOrderSequenceException.class);
 
         Deque<ProducerBatch> batches = accumulator.batches().get(tp0);
 
-        // The second request should not be requeued.
-        assertEquals(1, batches.size());
-        assertFalse(batches.peekFirst().hasSequence());
-        assertFalse(client.hasInFlightRequests());
+        // The epoch should be bumped and the second request should be requeued
+        assertEquals(2, batches.size());
 
-        // The producer state should be reset.
-        assertFalse(transactionManager.hasProducerId());
+        sender.runOnce();
+        assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch);
+        assertEquals(1, transactionManager.sequenceNumber(tp0).longValue());
         assertFalse(transactionManager.hasUnresolvedSequence(tp0));
     }
 
     @Test
+    public void testUnresolvedSequencesAreNotFatal() throws Exception {
+        ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
+        apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
+        TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions);
+
+        setupWithTransactionState(txnManager);
+        doInitTransactions(txnManager, producerIdAndEpoch);
+
+        txnManager.beginTransaction();
+        txnManager.failIfNotReadyForSend();
+        txnManager.maybeAddPartitionToTransaction(tp0);
+        client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
+        sender.runOnce();
+
+        // Send first ProduceRequest
+        Future<RecordMetadata> request1 = appendToAccumulator(tp0);
+        sender.runOnce();  // send request
+
+        time.sleep(1000L);
+        appendToAccumulator(tp0);
+        sender.runOnce();  // send request
+
+        assertEquals(2, client.inFlightRequestCount());
+
+        sendIdempotentProducerResponse(0, tp0, Errors.NOT_LEADER_FOR_PARTITION, -1);
+        sender.runOnce();  // receive first response
+
+        Node node = metadata.fetch().nodes().get(0);
+        time.sleep(1000L);
+        client.disconnect(node.idString());
+        client.blackout(node, 10);
+
+        sender.runOnce(); // now expire the first batch.
+        assertFutureFailure(request1, TimeoutException.class);
+        assertTrue(txnManager.hasUnresolvedSequence(tp0));
+
+        // Loop once and confirm that the transaction manager does not enter a fatal error state
+        sender.runOnce();
+        assertTrue(txnManager.hasAbortableError());
+    }
+
+    @Test
     public void testExpiryOfAllSentBatchesShouldCauseUnresolvedSequences() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1157,18 +1201,18 @@ public class SenderTest {
         assertEquals(0, batches.size());
         assertTrue(transactionManager.hasProducerId(producerId));
 
-        // We should now clear the old producerId and get a new one in a single run loop.
-        time.sleep(10);
-        prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE);
-        assertTrue(transactionManager.hasProducerId(producerId + 1));
+        // In the next run loop, we bump the epoch and clear the unresolved sequences
+        sender.runOnce();
+        assertEquals(1, transactionManager.producerIdAndEpoch().epoch);
+        assertFalse(transactionManager.hasUnresolvedSequence(tp0));
     }
 
     @Test
     public void testResetOfProducerStateShouldAllowQueuedBatchesToDrain() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
-        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
 
         int maxRetries = 10;
@@ -1189,11 +1233,9 @@ public class SenderTest {
         responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER));
         client.respond(produceResponse(responses));
 
-        sender.runOnce();
-        assertTrue(failedResponse.isDone());
-        assertFalse("Expected transaction state to be reset upon receiving an OutOfOrderSequenceException", transactionManager.hasProducerId());
+        sender.runOnce(); // trigger epoch bump
         prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE); // also send request to tp1
-        sender.runOnce();
+        sender.runOnce(); // reset producer ID because epoch is maxed out
         assertEquals(producerId + 1, transactionManager.producerIdAndEpoch().producerId);
 
         assertFalse(successfulResponse.isDone());
@@ -1210,9 +1252,9 @@ public class SenderTest {
     @Test
     public void testCloseWithProducerIdReset() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
-        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
 
         Metrics m = new Metrics();
@@ -1232,9 +1274,7 @@ public class SenderTest {
         responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER));
         client.respond(produceResponse(responses));
         sender.initiateClose(); // initiate close
-        sender.runOnce();
-        assertTrue(failedResponse.isDone());
-        assertFalse("Expected transaction state to be reset upon receiving an OutOfOrderSequenceException", transactionManager.hasProducerId());
+        sender.runOnce(); // out of order sequence error triggers producer ID reset because epoch is maxed out
 
         TestUtils.waitForCondition(new TestCondition() {
             @Override
@@ -1248,9 +1288,9 @@ public class SenderTest {
 
     @Test
     public void testForceCloseWithProducerIdReset() throws Exception {
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
-        prepareAndReceiveInitProducerId(1L, Errors.NONE);
+        prepareAndReceiveInitProducerId(1L, Short.MAX_VALUE, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
 
         Metrics m = new Metrics();
@@ -1269,9 +1309,7 @@ public class SenderTest {
         responses.put(tp1, new OffsetAndError(-1, Errors.NOT_LEADER_FOR_PARTITION));
         responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER));
         client.respond(produceResponse(responses));
-        sender.runOnce();
-        assertTrue(failedResponse.isDone());
-        assertFalse("Expected transaction state to be reset upon receiving an OutOfOrderSequenceException", transactionManager.hasProducerId());
+        sender.runOnce(); // out of order sequence error triggers producer ID reset because epoch is maxed out
         sender.forceClose(); // initiate force close
         sender.runOnce(); // this should not block
         sender.run(); // run main loop to test forceClose flag
@@ -1280,9 +1318,9 @@ public class SenderTest {
     }
 
     @Test
-    public void testBatchesDrainedWithOldProducerIdShouldFailWithOutOfOrderSequenceOnSubsequentRetry() throws Exception {
+    public void testBatchesDrainedWithOldProducerIdShouldSucceedOnSubsequentRetry() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1294,7 +1332,7 @@ public class SenderTest {
         Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries,
                 senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions);
 
-        Future<RecordMetadata> failedResponse = appendToAccumulator(tp0);
+        Future<RecordMetadata> outOfOrderResponse = appendToAccumulator(tp0);
         Future<RecordMetadata> successfulResponse = appendToAccumulator(tp1);
         sender.runOnce();  // connect.
         sender.runOnce();  // send.
@@ -1306,32 +1344,29 @@ public class SenderTest {
         responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER));
         client.respond(produceResponse(responses));
         sender.runOnce();
-        assertTrue(failedResponse.isDone());
-        assertFalse("Expected transaction state to be reset upon receiving an OutOfOrderSequenceException", transactionManager.hasProducerId());
-        prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE);
-        assertEquals(producerId + 1, transactionManager.producerIdAndEpoch().producerId);
-        sender.runOnce();  // send request to tp1 with the old producerId
+        assertFalse(outOfOrderResponse.isDone());
+
+        sender.runOnce();  // bump the epoch and send the request to tp1 with the old producerId
+        assertEquals(1, transactionManager.producerIdAndEpoch().epoch);
 
         assertFalse(successfulResponse.isDone());
         // The response comes back with a retriable error.
         client.respond(produceResponse(tp1, 0, Errors.NOT_LEADER_FOR_PARTITION, -1));
         sender.runOnce();
 
+        // The retriable error should cause the batch to be retried rather than failed.
+        assertFalse(successfulResponse.isDone());
+        sender.runOnce(); // retry one more time
+        client.respond(produceResponse(tp1, 0, Errors.NONE, -1));
+        sender.runOnce();
         assertTrue(successfulResponse.isDone());
-        // Since the batch has an old producerId, it will not be retried yet again, but will be failed with a Fatal
-        // exception.
-        try {
-            successfulResponse.get();
-            fail("Should have raised an OutOfOrderSequenceException");
-        } catch (Exception e) {
-            assertTrue(e.getCause() instanceof OutOfOrderSequenceException);
-        }
+        assertEquals(0, transactionManager.sequenceNumber(tp1).intValue());
     }
 
     @Test
     public void testCorrectHandlingOfDuplicateSequenceError() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1382,13 +1417,18 @@ public class SenderTest {
     }
 
     @Test
-    public void testUnknownProducerHandlingWhenRetentionLimitReached() throws Exception {
+    public void testTransactionalUnknownProducerHandlingWhenRetentionLimitReached() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions);
+
         setupWithTransactionState(transactionManager);
-        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        doInitTransactions(transactionManager, new ProducerIdAndEpoch(producerId, (short) 0));
         assertTrue(transactionManager.hasProducerId());
 
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
+        sender.runOnce(); // Receive AddPartitions response
+
         assertEquals(0, transactionManager.sequenceNumber(tp0).longValue());
 
         // Send first ProduceRequest
@@ -1440,9 +1480,67 @@ public class SenderTest {
     }
 
     @Test
+    public void testIdempotentUnknownProducerHandlingWhenRetentionLimitReached() throws Exception {
+        final long producerId = 343434L;
+        TransactionManager transactionManager = createTransactionManager();
+        setupWithTransactionState(transactionManager);
+        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        assertTrue(transactionManager.hasProducerId());
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).longValue());
+
+        // Send first ProduceRequest
+        Future<RecordMetadata> request1 = appendToAccumulator(tp0);
+        sender.runOnce();
+
+        assertEquals(1, client.inFlightRequestCount());
+        assertEquals(1, transactionManager.sequenceNumber(tp0).longValue());
+        assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0));
+
+        sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L);
+
+        sender.runOnce();  // receive the response.
+
+        assertTrue(request1.isDone());
+        assertEquals(1000L, request1.get().offset());
+        assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0));
+        assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0));
+
+        // Send second ProduceRequest, a single batch with 2 records.
+        appendToAccumulator(tp0);
+        Future<RecordMetadata> request2 = appendToAccumulator(tp0);
+        sender.runOnce();
+        assertEquals(3, transactionManager.sequenceNumber(tp0).longValue());
+        assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0));
+
+        assertFalse(request2.isDone());
+
+        sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 1010L);
+        sender.runOnce(); // receive the UNKNOWN_PRODUCER_ID response; the batch should be retried since the logStartOffset > lastAckedOffset.
+        sender.runOnce(); // bump epoch and retry request
+
+        // We should have reset the sequence number state of the partition because the state was lost on the broker.
+        assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0));
+        assertEquals(2, transactionManager.sequenceNumber(tp0).longValue());
+        assertFalse(request2.isDone());
+        assertTrue(client.hasInFlightRequests());
+        assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch);
+
+        // resend the request. Note that the expected sequence is 0, since we have lost producer state on the broker.
+        sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1011L, 1010L);
+        sender.runOnce(); // receive response 1
+        assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0));
+        assertEquals(2, transactionManager.sequenceNumber(tp0).longValue());
+        assertFalse(client.hasInFlightRequests());
+        assertTrue(request2.isDone());
+        assertEquals(1012L, request2.get().offset());
+        assertEquals(OptionalLong.of(1012L), transactionManager.lastAckedOffset(tp0));
+    }
+
+    @Test
     public void testUnknownProducerErrorShouldBeRetriedWhenLogStartOffsetIsUnknown() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1500,7 +1598,7 @@ public class SenderTest {
     @Test
     public void testUnknownProducerErrorShouldBeRetriedForFutureBatchesWhenFirstFails() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1541,20 +1639,17 @@ public class SenderTest {
         assertFalse(request3.isDone());
         assertEquals(2, client.inFlightRequestCount());
 
-
         sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 1010L);
         sender.runOnce(); // receive response 2, should reset the sequence numbers and be retried.
+        sender.runOnce(); // bump epoch and retry request 2
 
         // We should have reset the sequence number state of the partition because the state was lost on the broker.
         assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0));
         assertEquals(2, transactionManager.sequenceNumber(tp0).longValue());
         assertFalse(request2.isDone());
         assertFalse(request3.isDone());
-        assertEquals(1, client.inFlightRequestCount());
-
-        sender.runOnce(); // resend request 2.
-
         assertEquals(2, client.inFlightRequestCount());
+        assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch);
 
         // receive the original response 3. note the expected sequence is still the originally assigned sequence.
         sendIdempotentProducerResponse(2, tp0, Errors.UNKNOWN_PRODUCER_ID, -1, 1010L);
@@ -1589,7 +1684,7 @@ public class SenderTest {
     @Test
     public void testShouldRaiseOutOfOrderSequenceExceptionToUserIfLogWasNotTruncated() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1622,9 +1717,11 @@ public class SenderTest {
         assertFalse(request2.isDone());
 
         sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 10L);
-        sender.runOnce(); // receive response 0, should cause a producerId reset since the logStartOffset < lastAckedOffset
-        assertFutureFailure(request2, OutOfOrderSequenceException.class);
-
+        sender.runOnce(); // receive the UNKNOWN_PRODUCER_ID response, which should request an epoch bump
+        sender.runOnce(); // bump epoch
+        assertEquals(1, transactionManager.producerIdAndEpoch().epoch);
+        assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0));
+        assertFalse(request2.isDone());
     }
     void sendIdempotentProducerResponse(int expectedSequence, TopicPartition tp, Errors responseError, long responseOffset) {
         sendIdempotentProducerResponse(expectedSequence, tp, responseError, responseOffset, -1L);
@@ -1650,7 +1747,7 @@ public class SenderTest {
     @Test
     public void testClusterAuthorizationExceptionInProduceRequest() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
 
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
@@ -1676,7 +1773,7 @@ public class SenderTest {
     @Test
     public void testCancelInFlightRequestAfterFatalError() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
 
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
@@ -1716,7 +1813,7 @@ public class SenderTest {
     @Test
     public void testUnsupportedForMessageFormatInProduceRequest() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
 
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
@@ -1740,7 +1837,7 @@ public class SenderTest {
     @Test
     public void testUnsupportedVersionInProduceRequest() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
 
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
@@ -1765,7 +1862,7 @@ public class SenderTest {
     @Test
     public void testSequenceNumberIncrement() throws InterruptedException {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1807,11 +1904,11 @@ public class SenderTest {
     }
 
     @Test
-    public void testAbortRetryWhenProducerIdChanges() throws InterruptedException {
+    public void testRetryWhenProducerIdChanges() throws InterruptedException {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
-        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
 
         int maxRetries = 10;
@@ -1830,25 +1927,21 @@ public class SenderTest {
         client.disconnect(id);
         assertEquals(0, client.inFlightRequestCount());
         assertFalse("Client ready status should be false", client.isReady(node, time.milliseconds()));
+        sender.runOnce(); // receive error
+        sender.runOnce(); // reset producer ID because epoch is maxed out
 
-        transactionManager.resetIdempotentProducerId();
         prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE);
-        sender.runOnce(); // receive error
-        sender.runOnce(); // reconnect
         sender.runOnce(); // the batch is retried now that the pid has changed
-        assertEquals("Expected requests to be aborted after pid change", 0, client.inFlightRequestCount());
-
-        KafkaMetric recordErrors = m.metrics().get(senderMetrics.recordErrorRate);
-        assertTrue("Expected non-zero value for record send errors", (Double) recordErrors.metricValue() > 0);
+        assertEquals("Expected requests to be retried after pid change", 1, client.inFlightRequestCount());
 
-        assertTrue(responseFuture.isDone());
-        assertEquals(0, (long) transactionManager.sequenceNumber(tp0));
+        assertFalse(responseFuture.isDone());
+        assertEquals(1, (long) transactionManager.sequenceNumber(tp0));
     }
 
     @Test
-    public void testResetWhenOutOfOrderSequenceReceived() throws InterruptedException {
+    public void testBumpEpochWhenOutOfOrderSequenceReceived() throws InterruptedException {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager();
+        TransactionManager transactionManager = createTransactionManager();
         setupWithTransactionState(transactionManager);
         prepareAndReceiveInitProducerId(producerId, Errors.NONE);
         assertTrue(transactionManager.hasProducerId());
@@ -1869,16 +1962,17 @@ public class SenderTest {
 
         client.respond(produceResponse(tp0, 0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, 0));
 
-        sender.runOnce();
-        assertTrue(responseFuture.isDone());
-        assertEquals(0, sender.inFlightBatches(tp0).size());
-        assertFalse("Expected transaction state to be reset upon receiving an OutOfOrderSequenceException", transactionManager.hasProducerId());
+        sender.runOnce(); // receive the out of order sequence error
+        sender.runOnce(); // bump the epoch
+        assertFalse(responseFuture.isDone());
+        assertEquals(1, sender.inFlightBatches(tp0).size());
+        assertEquals(1, transactionManager.producerIdAndEpoch().epoch);
     }
 
     @Test
     public void testIdempotentSplitBatchAndSend() throws Exception {
         TopicPartition tp = new TopicPartition("testSplitBatchAndSend", 1);
-        TransactionManager txnManager = new TransactionManager();
+        TransactionManager txnManager = createTransactionManager();
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         setupWithTransactionState(txnManager);
         prepareAndReceiveInitProducerId(123456L, Errors.NONE);
@@ -1890,7 +1984,7 @@ public class SenderTest {
     public void testTransactionalSplitBatchAndSend() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         TopicPartition tp = new TopicPartition("testSplitBatchAndSend", 1);
-        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100);
+        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2199,7 +2293,7 @@ public class SenderTest {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100);
+            TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100, apiVersions);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2234,7 +2328,7 @@ public class SenderTest {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100);
+            TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100, apiVersions);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2268,7 +2362,7 @@ public class SenderTest {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100);
+            TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100, apiVersions);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2300,7 +2394,7 @@ public class SenderTest {
     public void testDoNotPollWhenNoRequestSent() {
         client = spy(new MockClient(time, metadata));
 
-        TransactionManager txnManager = new TransactionManager(logContext, "testDoNotPollWhenNoRequestSent", 6000, 100);
+        TransactionManager txnManager = new TransactionManager(logContext, "testDoNotPollWhenNoRequestSent", 6000, 100, apiVersions);
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2312,7 +2406,7 @@ public class SenderTest {
     @Test
     public void testTooLargeBatchesAreSafelyRemoved() throws InterruptedException {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
-        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100);
+        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions);
 
         setupWithTransactionState(txnManager, false, null);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2462,6 +2556,10 @@ public class SenderTest {
         return produceResponse(tp, offset, error, throttleTimeMs, -1L);
     }
 
+    private TransactionManager createTransactionManager() {
+        return new TransactionManager(new LogContext(), null, 0, 100L, new ApiVersions());
+    }
+
     private void setupWithTransactionState(TransactionManager transactionManager) {
         setupWithTransactionState(transactionManager, false, null);
     }
@@ -2497,7 +2595,10 @@ public class SenderTest {
     }
 
     private void prepareAndReceiveInitProducerId(long producerId, Errors error) {
-        short producerEpoch = 0;
+        prepareAndReceiveInitProducerId(producerId, (short) 0, error);
+    }
+
+    private void prepareAndReceiveInitProducerId(long producerId, short producerEpoch, Errors error) {
         if (error != Errors.NONE)
             producerEpoch = RecordBatch.NO_PRODUCER_EPOCH;
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index aec9c01..d2e7629 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -16,10 +16,12 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import org.apache.kafka.clients.ApiVersion;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
+import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.errors.FencedInstanceIdException;
 import org.apache.kafka.common.requests.JoinGroupRequest;
@@ -43,6 +45,7 @@ import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.message.EndTxnResponseData;
 import org.apache.kafka.common.message.InitProducerIdResponseData;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.CompressionType;
 import org.apache.kafka.common.record.MemoryRecords;
@@ -95,6 +98,7 @@ import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
@@ -138,14 +142,18 @@ public class TransactionManagerTest {
     private void initializeTransactionManager(Optional<String> transactionalId) {
         Metrics metrics = new Metrics(time);
 
+        apiVersions.update("0", new NodeApiVersions(Arrays.asList(
+                new ApiVersion(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3),
+                new ApiVersion(ApiKeys.PRODUCE.id, (short) 0, (short) 7))));
         this.transactionManager = new TransactionManager(logContext, transactionalId.orElse(null),
-                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS);
+                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions);
 
         int batchSize = 16 * 1024;
         int deliveryTimeoutMs = 3000;
         long totalSize = 1024 * 1024;
         String metricGrpName = "producer-metrics";
 
+        this.brokerNode = new Node(0, "localhost", 2211);
         this.accumulator = new RecordAccumulator(logContext, batchSize, CompressionType.NONE, 0, 0L,
                 deliveryTimeoutMs, metrics, metricGrpName, time, apiVersions, transactionManager,
                 new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
@@ -612,7 +620,7 @@ public class TransactionManagerTest {
     }
 
     @Test
-    public void testResetSequenceNumbersAfterUnknownProducerId() {
+    public void testBumpEpochAndResetSequenceNumbersAfterUnknownProducerId() {
         final long producerId = 13131L;
         final short epoch = 1;
 
@@ -633,58 +641,15 @@ public class TransactionManagerTest {
         b1.done(500L, b1AppendTime, null);
         transactionManager.handleCompletedBatch(b1, b1Response);
 
-        // Retention caused log start offset to jump forward. We set sequence numbers back to 0
+        // We get an UNKNOWN_PRODUCER_ID, so bump the epoch and set sequence numbers back to 0
         ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse(
-                Errors.UNKNOWN_PRODUCER_ID, -1, -1, 600L);
+                Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L);
         assertTrue(transactionManager.canRetry(b2Response, b2));
-        assertEquals(4, transactionManager.sequenceNumber(tp0).intValue());
-        assertEquals(0, b2.baseSequence());
-        assertEquals(1, b3.baseSequence());
-        assertEquals(2, b4.baseSequence());
-        assertEquals(3, b5.baseSequence());
-    }
-
-    @Test
-    public void testAdjustSequenceNumbersAfterFatalError() {
-        final long producerId = 13131L;
-        final short epoch = 1;
-
-        initializeTransactionManager(Optional.empty());
-        initializeIdempotentProducerId(producerId, epoch);
-
-        ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
-        ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
-        ProducerBatch b3 = writeIdempotentBatchWithValue(transactionManager, tp0, "3");
-        ProducerBatch b4 = writeIdempotentBatchWithValue(transactionManager, tp0, "4");
-        ProducerBatch b5 = writeIdempotentBatchWithValue(transactionManager, tp0, "5");
-        assertEquals(5, transactionManager.sequenceNumber(tp0).intValue());
-
-        // First batch succeeds
-        long b1AppendTime = time.milliseconds();
-        ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse(
-                Errors.NONE, 500L, b1AppendTime, 0L);
-        b1.done(500L, b1AppendTime, null);
-        transactionManager.handleCompletedBatch(b1, b1Response);
-
-        // Second batch fails with a fatal error. Sequence numbers are adjusted by one for remaining
-        // inflight batches.
-        ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse(
-                Errors.MESSAGE_TOO_LARGE, -1, -1, 0L);
-        assertFalse(transactionManager.canRetry(b2Response, b2));
-
-        b2.done(-1L, -1L, Errors.MESSAGE_TOO_LARGE.exception());
-        transactionManager.handleFailedBatch(b2, Errors.MESSAGE_TOO_LARGE.exception(), true);
-        assertEquals(4, transactionManager.sequenceNumber(tp0).intValue());
-        assertEquals(1, b3.baseSequence());
-        assertEquals(2, b4.baseSequence());
-        assertEquals(3, b5.baseSequence());
 
-        // The remaining batches are doomed to fail, but they can be retried. Expected
-        // sequence numbers should remain the same.
-        ProduceResponse.PartitionResponse b3Response = new ProduceResponse.PartitionResponse(
-                Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1, -1, 0L);
-        assertTrue(transactionManager.canRetry(b3Response, b3));
-        assertEquals(4, transactionManager.sequenceNumber(tp0).intValue());
+        // Run sender loop to trigger epoch bump
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2);
+        assertEquals(2, b2.producerEpoch());
+        assertEquals(0, b2.baseSequence());
         assertEquals(1, b3.baseSequence());
         assertEquals(2, b4.baseSequence());
         assertEquals(3, b5.baseSequence());
@@ -693,55 +658,82 @@ public class TransactionManagerTest {
     @Test
     public void testBatchFailureAfterProducerReset() {
         // This tests a scenario where the producerId is reset while pending requests are still inflight.
-        // The returned responses should not update internal state.
+        // The partition(s) that triggered the reset will have their sequence number reset, while any others will not
 
         final long producerId = 13131L;
-        final short epoch = 1;
+        final short epoch = Short.MAX_VALUE;
 
         initializeTransactionManager(Optional.empty());
         initializeIdempotentProducerId(producerId, epoch);
 
-        ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
+        ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
+        ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "1");
 
-        transactionManager.resetIdempotentProducerId();
-        initializeIdempotentProducerId(producerId + 1, epoch);
+        ProduceResponse.PartitionResponse tp0b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, -1, -1, 400L);
+        transactionManager.handleCompletedBatch(tp0b1, tp0b1Response);
 
-        ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
-        assertEquals(1, transactionManager.sequenceNumber(tp0).intValue());
+        ProduceResponse.PartitionResponse tp1b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, -1, -1, 400L);
+        transactionManager.handleCompletedBatch(tp1b1, tp1b1Response);
+
+        ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
+        ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "2");
+        assertEquals(2, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(2, transactionManager.sequenceNumber(tp1).intValue());
 
         ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse(
                 Errors.UNKNOWN_PRODUCER_ID, -1, -1, 400L);
-        assertFalse(transactionManager.canRetry(b1Response, b1));
-        transactionManager.handleFailedBatch(b1, Errors.UNKNOWN_PRODUCER_ID.exception(), true);
+        assertTrue(transactionManager.canRetry(b1Response, tp0b1));
+
+        ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, -1, -1, 400L);
+        transactionManager.handleCompletedBatch(tp1b1, b2Response);
+
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
 
         assertEquals(1, transactionManager.sequenceNumber(tp0).intValue());
-        assertEquals(b2, transactionManager.nextBatchBySequence(tp0));
+        assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0));
+        assertEquals(2, transactionManager.sequenceNumber(tp1).intValue());
+        assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1));
     }
 
     @Test
     public void testBatchCompletedAfterProducerReset() {
         final long producerId = 13131L;
-        final short epoch = 1;
+        final short epoch = Short.MAX_VALUE;
 
         initializeTransactionManager(Optional.empty());
         initializeIdempotentProducerId(producerId, epoch);
 
         ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
-
-        // The producerId might be reset due to a failure on another partition
-        transactionManager.resetIdempotentProducerId();
-        initializeIdempotentProducerId(producerId + 1, epoch);
+        writeIdempotentBatchWithValue(transactionManager, tp1, "1");
 
         ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
-        assertEquals(1, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(2, transactionManager.sequenceNumber(tp0).intValue());
 
-        // If the request returns successfully, we should ignore the response and not update any state
+        // The producerId might be reset due to a failure on another partition
+        transactionManager.requestEpochBumpForPartition(tp1);
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
+        initializeIdempotentProducerId(producerId + 1, (short) 0);
+
+        // We continue to track the state of tp0 until in-flight requests complete
         ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse(
                 Errors.NONE, 500L, time.milliseconds(), 0L);
         transactionManager.handleCompletedBatch(b1, b1Response);
 
-        assertEquals(1, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(2, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(0, transactionManager.lastAckedSequence(tp0).getAsInt());
         assertEquals(b2, transactionManager.nextBatchBySequence(tp0));
+        assertEquals(epoch, transactionManager.nextBatchBySequence(tp0).producerEpoch());
+
+        ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, time.milliseconds(), 0L);
+        transactionManager.handleCompletedBatch(b2, b2Response);
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
+        assertFalse(transactionManager.lastAckedSequence(tp0).isPresent());
+        assertNull(transactionManager.nextBatchBySequence(tp0));
     }
 
     private ProducerBatch writeIdempotentBatchWithValue(TransactionManager manager,
@@ -780,12 +772,18 @@ public class TransactionManagerTest {
     @Test
     public void testProducerIdReset() {
         initializeTransactionManager(Optional.empty());
-        initializeIdempotentProducerId(15L, (short) 0);
+        initializeIdempotentProducerId(15L, Short.MAX_VALUE);
         assertEquals((int) transactionManager.sequenceNumber(tp0), 0);
+        assertEquals((int) transactionManager.sequenceNumber(tp1), 0);
         transactionManager.incrementSequenceNumber(tp0, 3);
         assertEquals((int) transactionManager.sequenceNumber(tp0), 3);
-        transactionManager.resetIdempotentProducerId();
+        transactionManager.incrementSequenceNumber(tp1, 3);
+        assertEquals((int) transactionManager.sequenceNumber(tp1), 3);
+
+        transactionManager.requestEpochBumpForPartition(tp0);
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
         assertEquals((int) transactionManager.sequenceNumber(tp0), 0);
+        assertEquals((int) transactionManager.sequenceNumber(tp1), 3);
     }
 
     @Test
@@ -994,7 +992,6 @@ public class TransactionManagerTest {
             assertEquals(epoch, txnOffsetCommitRequest.data.producerEpoch());
             return !txnOffsetCommitRequest.data.memberId().equals(memberId);
         }, new TxnOffsetCommitResponse(0, singletonMap(tp, Errors.UNKNOWN_MEMBER_ID)));
-        sender.runOnce();  // TxnOffsetCommit Handled
 
         runUntil(transactionManager::hasError);
         assertTrue(transactionManager.lastError() instanceof CommitFailedException);
@@ -1034,7 +1031,6 @@ public class TransactionManagerTest {
             assertEquals(epoch, txnOffsetCommitRequest.data.producerEpoch());
             return txnOffsetCommitRequest.data.generationId() != generationId;
         }, new TxnOffsetCommitResponse(0, singletonMap(tp, Errors.ILLEGAL_GENERATION)));
-        sender.runOnce();  // TxnOffsetCommit Handled
 
         runUntil(transactionManager::hasError);
         assertTrue(transactionManager.lastError() instanceof CommitFailedException);
@@ -1724,7 +1720,7 @@ public class TransactionManagerTest {
         // Commit is not allowed, so let's abort and try again.
         TransactionalRequestResult abortResult = transactionManager.beginAbort();
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
-
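+        // With KIP-360, the abort is followed by an InitProducerId request that bumps the epoch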
+        prepareInitPidResponse(Errors.NONE, false, pid, (short) (epoch + 1));
         runUntil(abortResult::isCompleted);
         assertTrue(abortResult.isSuccessful());
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
@@ -1746,11 +1742,12 @@ public class TransactionManagerTest {
         assertFalse(responseFuture.isDone());
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, pid);
         prepareProduceResponse(Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, pid, epoch);
-        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
 
+        // Because this failure requires an epoch bump, the abort will be followed by an InitProducerId call
         runUntil(transactionManager::hasAbortableError);
         TransactionalRequestResult abortResult = transactionManager.beginAbort();
-
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
+        prepareInitPidResponse(Errors.NONE, false, pid, (short) (epoch + 1));
         runUntil(abortResult::isCompleted);
         assertTrue(abortResult.isSuccessful());
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
@@ -2480,6 +2477,7 @@ public class TransactionManagerTest {
         TransactionalRequestResult abortResult = transactionManager.beginAbort();
 
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
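+        // The abort now also triggers an epoch bump, so an InitProducerId response is needed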
+        prepareInitPidResponse(Errors.NONE, false, pid, (short) (epoch + 1));
         runUntil(abortResult::isCompleted);
         assertTrue(abortResult.isSuccessful());
         assertFalse(transactionManager.hasOngoingTransaction());
@@ -2491,6 +2489,10 @@ public class TransactionManagerTest {
         final long pid = 13131L;
         final short epoch = 1;
 
+        apiVersions.update("0", new NodeApiVersions(Arrays.asList(
+                new ApiVersion(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 1),
+                new ApiVersion(ApiKeys.PRODUCE.id, (short) 0, (short) 7))));
+
         doInitTransactions(pid, epoch);
 
         transactionManager.beginTransaction();
@@ -2540,7 +2542,7 @@ public class TransactionManagerTest {
     }
 
     @Test
-    public void testResetProducerIdAfterWithoutPendingInflightRequests() {
+    public void testBumpEpochAfterTimeoutWithoutPendingInflightRequests() {
         initializeTransactionManager(Optional.empty());
         long producerId = 15L;
         short epoch = 5;
@@ -2548,7 +2550,7 @@ public class TransactionManagerTest {
         initializeIdempotentProducerId(producerId, epoch);
 
         // Nothing to resolve, so no reset is needed
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
 
         TopicPartition tp0 = new TopicPartition("foo", 0);
@@ -2561,23 +2563,25 @@ public class TransactionManagerTest {
         assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0));
 
         // Marking sequence numbers unresolved without inflight requests is basically a no-op.
-        transactionManager.markSequenceUnresolved(tp0);
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.markSequenceUnresolved(b1);
+        transactionManager.maybeResolveSequences();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
         assertFalse(transactionManager.hasUnresolvedSequences());
 
         // We have a new batch which fails with a timeout
         ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
         assertEquals(Integer.valueOf(2), transactionManager.sequenceNumber(tp0));
-        transactionManager.markSequenceUnresolved(tp0);
+        transactionManager.markSequenceUnresolved(b2);
         transactionManager.handleFailedBatch(b2, new TimeoutException(), false);
         assertTrue(transactionManager.hasUnresolvedSequences());
 
         // We only had one inflight batch, so we should be able to clear the unresolved status
-        // and reset the producerId
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        // and bump the epoch
+        transactionManager.maybeResolveSequences();
         assertFalse(transactionManager.hasUnresolvedSequences());
-        assertFalse(transactionManager.hasProducerId());
+
+        // Run sender loop to trigger epoch bump
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 6);
     }
 
     @Test
@@ -2595,18 +2599,18 @@ public class TransactionManagerTest {
         assertEquals(3, transactionManager.sequenceNumber(tp0).intValue());
 
         // The first batch fails with a timeout
-        transactionManager.markSequenceUnresolved(tp0);
+        transactionManager.markSequenceUnresolved(b1);
         transactionManager.handleFailedBatch(b1, new TimeoutException(), false);
         assertTrue(transactionManager.hasUnresolvedSequences());
 
         // The reset should not occur until sequence numbers have been resolved
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
         assertTrue(transactionManager.hasUnresolvedSequences());
 
         // The second batch fails as well with a timeout
         transactionManager.handleFailedBatch(b2, new TimeoutException(), false);
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
         assertTrue(transactionManager.hasUnresolvedSequences());
 
@@ -2614,14 +2618,14 @@ public class TransactionManagerTest {
         // requiring a producerId reset.
         transactionManager.handleCompletedBatch(b3, new ProduceResponse.PartitionResponse(
                 Errors.NONE, 500L, time.milliseconds(), 0L));
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.maybeResolveSequences();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
         assertFalse(transactionManager.hasUnresolvedSequences());
         assertEquals(3, transactionManager.sequenceNumber(tp0).intValue());
     }
 
     @Test
-    public void testProducerIdResetAfterLastInFlightBatchFails() {
+    public void testEpochBumpAfterLastInflightBatchFails() {
         initializeTransactionManager(Optional.empty());
         long producerId = 15L;
         short epoch = 5;
@@ -2635,26 +2639,444 @@ public class TransactionManagerTest {
         assertEquals(Integer.valueOf(3), transactionManager.sequenceNumber(tp0));
 
         // The first batch fails with a timeout
-        transactionManager.markSequenceUnresolved(tp0);
+        transactionManager.markSequenceUnresolved(b1);
         transactionManager.handleFailedBatch(b1, new TimeoutException(), false);
         assertTrue(transactionManager.hasUnresolvedSequences());
 
         // The second batch succeeds, but sequence numbers are still not resolved
         transactionManager.handleCompletedBatch(b2, new ProduceResponse.PartitionResponse(
                 Errors.NONE, 500L, time.milliseconds(), 0L));
-        transactionManager.resetIdempotentProducerIdIfNeeded();
+        transactionManager.bumpIdempotentEpochAndResetIdIfNeeded();
         assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch());
         assertTrue(transactionManager.hasUnresolvedSequences());
 
-        // When the last inflight batch fails, we have to reset the producerId
+        // When the last inflight batch fails, we have to bump the epoch
         transactionManager.handleFailedBatch(b3, new TimeoutException(), false);
-        transactionManager.resetIdempotentProducerIdIfNeeded();
-        assertFalse(transactionManager.hasProducerId());
+
+        // Run sender loop to trigger epoch bump
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 6);
         assertFalse(transactionManager.hasUnresolvedSequences());
         assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
     }
 
     @Test
+    public void testAbortTransactionAndReuseSequenceNumberOnError() throws InterruptedException {
+        final long pid = 13131L;
+        final short epoch = 1;
+
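+        // Set the InitProducerId version such that epoch bumping is not supported; without a bump,
+        // the sequence numbers should be reused after the abort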
+        apiVersions.update("0", new NodeApiVersions(Arrays.asList(
+                new ApiVersion(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 1),
+                new ApiVersion(ApiKeys.PRODUCE.id, (short) 0, (short) 7))));
+
+        doInitTransactions(pid, epoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, pid);
+        prepareProduceResponse(Errors.NONE, pid, epoch);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));  // Send AddPartitionsRequest
+        runUntil(responseFuture0::isDone);
+
+        Future<RecordMetadata> responseFuture1 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, epoch);
+        runUntil(responseFuture1::isDone);
+
+        Future<RecordMetadata> responseFuture2 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.TOPIC_AUTHORIZATION_FAILED, pid, epoch);
+        runUntil(responseFuture2::isDone); // Receive abortable error
+
+        assertTrue(transactionManager.hasAbortableError());
+
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
+        runUntil(abortResult::isCompleted);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));  // Send AddPartitionsRequest
+
+        assertEquals(2, transactionManager.sequenceNumber(tp0).intValue());
+    }
+
+    @Test
+    public void testAbortTransactionAndResetSequenceNumberOnUnknownProducerId() throws InterruptedException {
+        final long pid = 13131L;
+        final short epoch = 1;
+
+        // Set the InitProducerId version such that bumping the epoch number is not supported. This will test the case
+        // where the sequence number is reset on an UnknownProducerId error, allowing subsequent transactions to
+        // append to the log successfully
+        apiVersions.update("0", new NodeApiVersions(Arrays.asList(
+                new ApiVersion(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 1),
+                new ApiVersion(ApiKeys.PRODUCE.id, (short) 0, (short) 7))));
+
+        doInitTransactions(pid, epoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+
+        transactionManager.maybeAddPartitionToTransaction(tp1);
+        Future<RecordMetadata> successPartitionResponseFuture = appendToAccumulator(tp1);
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp1, epoch, pid);
+        prepareProduceResponse(Errors.NONE, pid, epoch, tp1);
+        runUntil(successPartitionResponseFuture::isDone);
+        assertTrue(transactionManager.isPartitionAdded(tp1));
+
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, pid);
+        prepareProduceResponse(Errors.NONE, pid, epoch);
+        runUntil(responseFuture0::isDone);
+        assertTrue(transactionManager.isPartitionAdded(tp0));
+
+        Future<RecordMetadata> responseFuture1 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, epoch);
+        runUntil(responseFuture1::isDone);
+
+        Future<RecordMetadata> responseFuture2 = appendToAccumulator(tp0);
+        client.prepareResponse(produceRequestMatcher(pid, epoch, tp0),
+                produceResponse(tp0, 0, Errors.UNKNOWN_PRODUCER_ID, 0, 0));
+        runUntil(responseFuture2::isDone);
+
+        assertTrue(transactionManager.hasAbortableError());
+
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, epoch);
+        runUntil(abortResult::isCompleted);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(1, transactionManager.sequenceNumber(tp1).intValue());
+    }
+
+    @Test
+    public void testBumpTransactionalEpochOnAbortableError() throws InterruptedException {
+        final long pid = 13131L;
+        final short initialEpoch = 1;
+        final short bumpedEpoch = initialEpoch + 1;
+
+        doInitTransactions(pid, initialEpoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture0::isDone);
+
+        Future<RecordMetadata> responseFuture1 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture1::isDone);
+
+        Future<RecordMetadata> responseFuture2 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.TOPIC_AUTHORIZATION_FAILED, pid, initialEpoch);
+        runUntil(responseFuture2::isDone);
+
+        assertTrue(transactionManager.hasAbortableError());
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, initialEpoch);
+        prepareInitPidResponse(Errors.NONE, false, pid, bumpedEpoch);
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch);
+
+        assertTrue(abortResult.isCompleted());
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
+    }
+
+    @Test
+    public void testBumpTransactionalEpochOnUnknownProducerIdError() throws InterruptedException {
+        final long pid = 13131L;
+        final short initialEpoch = 1;
+        final short bumpedEpoch = 2;
+
+        doInitTransactions(pid, initialEpoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture0::isDone);
+
+        Future<RecordMetadata> responseFuture1 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture1::isDone);
+
+        Future<RecordMetadata> responseFuture2 = appendToAccumulator(tp0);
+        client.prepareResponse(produceRequestMatcher(pid, initialEpoch, tp0),
+                produceResponse(tp0, 0, Errors.UNKNOWN_PRODUCER_ID, 0, 0));
+        runUntil(responseFuture2::isDone);
+
+        assertTrue(transactionManager.hasAbortableError());
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, initialEpoch);
+        prepareInitPidResponse(Errors.NONE, false, pid, bumpedEpoch);
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch);
+
+        assertTrue(abortResult.isCompleted());
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
+    }
+
+    @Test
+    public void testBumpTransactionalEpochOnTimeout() throws InterruptedException {
+        final long pid = 13131L;
+        final short initialEpoch = 1;
+        final short bumpedEpoch = 2;
+
+        doInitTransactions(pid, initialEpoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture0::isDone);
+
+        Future<RecordMetadata> responseFuture1 = appendToAccumulator(tp0);
+        prepareProduceResponse(Errors.NONE, pid, initialEpoch);
+        runUntil(responseFuture1::isDone);
+
+        Future<RecordMetadata> responseFuture2 = appendToAccumulator(tp0);
+        runUntil(client::hasInFlightRequests); // Send Produce Request
+
+        // Sleep 10 seconds to make sure that the batches in the queue will be expired if they can't be drained.
+        time.sleep(10000);
+        // Disconnect the target node for the pending produce request. This ensures that the sender
+        // will try to expire the batch.
+        Node clusterNode = metadata.fetch().nodes().get(0);
+        client.disconnect(clusterNode.idString());
+        client.blackout(clusterNode, 100);
+
+        runUntil(responseFuture2::isDone); // We should try to flush the produce, but expire it instead without sending anything.
+
+        assertTrue(transactionManager.hasAbortableError());
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        sender.runOnce();  // handle the abort
+        time.sleep(110);  // Sleep to make sure the node blackout period has passed
+
+        prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId);
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, pid, initialEpoch);
+        prepareInitPidResponse(Errors.NONE, false, pid, bumpedEpoch);
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch);
+
+        assertTrue(abortResult.isCompleted());
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, pid);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        assertEquals(0, transactionManager.sequenceNumber(tp0).intValue());
+    }
+
+    @Test
+    public void testBumpTransactionalEpochOnRecoverableAddPartitionRequestError() {
+        final long producerId = 13131L;
+        final short initialEpoch = 1;
+        final short bumpedEpoch = 2;
+
+        doInitTransactions(producerId, initialEpoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        prepareAddPartitionsToTxnResponse(Errors.INVALID_PRODUCER_ID_MAPPING, tp0, initialEpoch, producerId);
+        runUntil(transactionManager::hasAbortableError);
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch);
+        runUntil(abortResult::isCompleted);
+        assertEquals(bumpedEpoch, transactionManager.producerIdAndEpoch().epoch);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+    }
+
+    @Test
+    public void testBumpTransactionalEpochOnRecoverableAddOffsetsRequestError() throws InterruptedException {
+        final long producerId = 13131L;
+        final short initialEpoch = 1;
+        final short bumpedEpoch = 2;
+        final String consumerGroupId = "myconsumergroup";
+
+        doInitTransactions(producerId, initialEpoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        assertFalse(responseFuture.isDone());
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId);
+        prepareProduceResponse(Errors.NONE, producerId, initialEpoch);
+        runUntil(responseFuture::isDone);
+
+        Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(tp0, new OffsetAndMetadata(1));
+        transactionManager.sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId));
+        assertFalse(transactionManager.hasPendingOffsetCommits());
+        prepareAddOffsetsToTxnResponse(Errors.INVALID_PRODUCER_ID_MAPPING, consumerGroupId, producerId, initialEpoch);
+        runUntil(transactionManager::hasAbortableError);  // Send AddOffsetsRequest
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, initialEpoch);
+        prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch);
+        runUntil(abortResult::isCompleted);
+        assertEquals(bumpedEpoch, transactionManager.producerIdAndEpoch().epoch);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+    }
+
+    @Test
+    public void testHealthyPartitionRetriesDuringEpochBump() throws InterruptedException {
+        final long producerId = 13131L;
+        final short epoch = 1;
+
+        // Use a custom Sender to allow multiple inflight requests
+        initializeTransactionManager(Optional.empty());
+        Sender sender = new Sender(logContext, this.client, this.metadata, this.accumulator, false,
+                MAX_REQUEST_SIZE, ACKS_ALL, MAX_RETRIES, new SenderMetricsRegistry(new Metrics(time)), this.time,
+                REQUEST_TIMEOUT, 50, transactionManager, apiVersions);
+        initializeIdempotentProducerId(producerId, epoch);
+
+        ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
+        ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
+        writeIdempotentBatchWithValue(transactionManager, tp0, "3");
+        ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "4");
+        ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "5");
+        assertEquals(3, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(2, transactionManager.sequenceNumber(tp1).intValue());
+
+        // First batch of each partition succeeds
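+        // (batches are completed directly, mimicking the Sender handling the produce responses)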
+        long b1AppendTime = time.milliseconds();
+        ProduceResponse.PartitionResponse t0b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp0b1.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp0b1, t0b1Response);
+
+        ProduceResponse.PartitionResponse t1b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp1b1.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp1b1, t1b1Response);
+
+        // We bump the epoch and set sequence numbers back to 0
+        ProduceResponse.PartitionResponse t0b2Response = new ProduceResponse.PartitionResponse(
+                Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L);
+        assertTrue(transactionManager.canRetry(t0b2Response, tp0b2));
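+        // canRetry should request an epoch bump for tp0 here, which the sender loop below applies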
+
+        // Run sender loop to trigger epoch bump
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2);
+
+        // tp0 batches should have had sequence and epoch rewritten, but tp1 batches should not
+        assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0));
+        assertEquals(0, transactionManager.firstInFlightSequence(tp0));
+        assertEquals(0, tp0b2.baseSequence());
+        assertTrue(tp0b2.sequenceHasBeenReset());
+        assertEquals(2, tp0b2.producerEpoch());
+
+        assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1));
+        assertEquals(1, transactionManager.firstInFlightSequence(tp1));
+        assertEquals(1, tp1b2.baseSequence());
+        assertFalse(tp1b2.sequenceHasBeenReset());
+        assertEquals(1, tp1b2.producerEpoch());
+
+        // New tp1 batches should not be drained from the accumulator while tp1 has in-flight requests using the old epoch
+        appendToAccumulator(tp1);
+        sender.runOnce();
+        assertEquals(1, accumulator.batches().get(tp1).size());
+
+        // Partition failover occurs and tp1 returns a NOT_LEADER_FOR_PARTITION error
+        // Despite having the old epoch, the batch should retry
+        ProduceResponse.PartitionResponse t1b2Response = new ProduceResponse.PartitionResponse(
+                Errors.NOT_LEADER_FOR_PARTITION, -1, -1, 600L);
+        assertTrue(transactionManager.canRetry(t1b2Response, tp1b2));
+        accumulator.reenqueue(tp1b2, time.milliseconds());
+
+        // The batch with the old epoch should be successfully drained, leaving the new one in the queue
+        sender.runOnce();
+        assertEquals(1, accumulator.batches().get(tp1).size());
+        assertNotEquals(tp1b2, accumulator.batches().get(tp1).peek());
+        assertEquals(epoch, tp1b2.producerEpoch());
+
+        // After successfully retrying, there should be no in-flight batches for tp1 and the sequence should be 0
+        t1b2Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp1b2.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp1b2, t1b2Response);
+
+        assertFalse(transactionManager.hasInflightBatches(tp1));
+        assertEquals(0, transactionManager.sequenceNumber(tp1).intValue());
+
+        // The last batch should now be drained and sent
+        runUntil(() -> transactionManager.hasInflightBatches(tp1));
+        assertTrue(accumulator.batches().get(tp1).isEmpty());
+        ProducerBatch tp1b3 = transactionManager.nextBatchBySequence(tp1);
+        assertEquals(epoch + 1, tp1b3.producerEpoch());
+
+        ProduceResponse.PartitionResponse t1b3Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp1b3.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp1b3, t1b3Response);
+
+        assertFalse(transactionManager.hasInflightBatches(tp1));
+        assertEquals(1, transactionManager.sequenceNumber(tp1).intValue());
+    }
+
+    @Test
     public void testRetryAbortTransaction() throws InterruptedException {
         verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, TransactionResult.ABORT);
     }
@@ -2674,6 +3096,111 @@ public class TransactionManagerTest {
         verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, TransactionResult.COMMIT);
     }
 
+    @Test
+    public void testCanBumpEpochDuringCoordinatorDisconnect() {
+        doInitTransactions(0, (short) 0);
+        runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null);
+        assertTrue(transactionManager.canBumpEpoch());
+
+        apiVersions.remove(transactionManager.coordinator(CoordinatorType.TRANSACTION).idString());
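+        // Removing the coordinator's ApiVersions simulates a disconnect; the epoch should still
+        // be considered bumpable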
+        assertTrue(transactionManager.canBumpEpoch());
+    }
+
+    @Test
+    public void testFailedInflightBatchAfterEpochBump() throws InterruptedException {
+        final long producerId = 13131L;
+        final short epoch = 1;
+
+        // Use a custom Sender to allow multiple inflight requests
+        initializeTransactionManager(Optional.empty());
+        Sender sender = new Sender(logContext, this.client, this.metadata, this.accumulator, false,
+                MAX_REQUEST_SIZE, ACKS_ALL, MAX_RETRIES, new SenderMetricsRegistry(new Metrics(time)), this.time,
+                REQUEST_TIMEOUT, 50, transactionManager, apiVersions);
+        initializeIdempotentProducerId(producerId, epoch);
+
+        ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1");
+        ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2");
+        writeIdempotentBatchWithValue(transactionManager, tp0, "3");
+        ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "4");
+        ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "5");
+        assertEquals(3, transactionManager.sequenceNumber(tp0).intValue());
+        assertEquals(2, transactionManager.sequenceNumber(tp1).intValue());
+
+        // First batch of each partition succeeds
+        long b1AppendTime = time.milliseconds();
+        ProduceResponse.PartitionResponse t0b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp0b1.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp0b1, t0b1Response);
+
+        ProduceResponse.PartitionResponse t1b1Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp1b1.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp1b1, t1b1Response);
+
+        // We bump the epoch and set sequence numbers back to 0
+        ProduceResponse.PartitionResponse t0b2Response = new ProduceResponse.PartitionResponse(
+                Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L);
+        assertTrue(transactionManager.canRetry(t0b2Response, tp0b2));
+
+        // Run sender loop to trigger epoch bump
+        runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2);
+
+        // tp0 batches should have had sequence and epoch rewritten, but tp1 batches should not
+        assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0));
+        assertEquals(0, transactionManager.firstInFlightSequence(tp0));
+        assertEquals(0, tp0b2.baseSequence());
+        assertTrue(tp0b2.sequenceHasBeenReset());
+        assertEquals(2, tp0b2.producerEpoch());
+
+        assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1));
+        assertEquals(1, transactionManager.firstInFlightSequence(tp1));
+        assertEquals(1, tp1b2.baseSequence());
+        assertFalse(tp1b2.sequenceHasBeenReset());
+        assertEquals(1, tp1b2.producerEpoch());
+
+        // New tp1 batches should not be drained from the accumulator while tp1 has in-flight requests using the old epoch
+        appendToAccumulator(tp1);
+        sender.runOnce();
+        assertEquals(1, accumulator.batches().get(tp1).size());
+
+        // Partition failover occurs and tp1 returns a NOT_LEADER_FOR_PARTITION error
+        // Despite having the old epoch, the batch should retry
+        ProduceResponse.PartitionResponse t1b2Response = new ProduceResponse.PartitionResponse(
+                Errors.NOT_LEADER_FOR_PARTITION, -1, -1, 600L);
+        assertTrue(transactionManager.canRetry(t1b2Response, tp1b2));
+        accumulator.reenqueue(tp1b2, time.milliseconds());
+
+        // The batch with the old epoch should be successfully drained, leaving the new one in the queue
+        sender.runOnce();
+        assertEquals(1, accumulator.batches().get(tp1).size());
+        assertNotEquals(tp1b2, accumulator.batches().get(tp1).peek());
+        assertEquals(epoch, tp1b2.producerEpoch());
+
+        // Unlike the healthy-retry case above, the retried batch with the old epoch now fails.
+        // With no more in-flight requests for tp1, its sequence numbers can be reset for the
+        // bumped epoch.
+        transactionManager.handleFailedBatch(tp1b2, new TimeoutException(), false);
+
+        assertFalse(transactionManager.hasInflightBatches(tp1));
+        assertEquals(0, transactionManager.sequenceNumber(tp1).intValue());
+
+        // The last batch should now be drained and sent
+        runUntil(() -> transactionManager.hasInflightBatches(tp1));
+        assertTrue(accumulator.batches().get(tp1).isEmpty());
+        ProducerBatch tp1b3 = transactionManager.nextBatchBySequence(tp1);
+        assertEquals(epoch + 1, tp1b3.producerEpoch());
+
+        ProduceResponse.PartitionResponse t1b3Response = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, b1AppendTime, 0L);
+        tp1b3.done(500L, b1AppendTime, null);
+        transactionManager.handleCompletedBatch(tp1b3, t1b3Response);
+
+        assertFalse(transactionManager.hasInflightBatches(tp1));
+        assertEquals(1, transactionManager.sequenceNumber(tp1).intValue());
+    }
+
     private FutureRecordMetadata appendToAccumulator(TopicPartition tp) throws InterruptedException {
         final long nowMs = time.milliseconds();
         return accumulator.append(tp, nowMs, "key".getBytes(), "value".getBytes(), Record.EMPTY_HEADERS,
@@ -2776,17 +3303,29 @@ public class TransactionManagerTest {
     }
 
     private void sendProduceResponse(Errors error, final long producerId, final short producerEpoch) {
-        client.respond(produceRequestMatcher(producerId, producerEpoch), produceResponse(tp0, 0, error, 0));
+        sendProduceResponse(error, producerId, producerEpoch, tp0);
+    }
+
+    private void sendProduceResponse(Errors error, final long producerId, final short producerEpoch, TopicPartition tp) {
+        client.respond(produceRequestMatcher(producerId, producerEpoch, tp), produceResponse(tp, 0, error, 0));
     }
 
     private void prepareProduceResponse(Errors error, final long producerId, final short producerEpoch) {
-        client.prepareResponse(produceRequestMatcher(producerId, producerEpoch), produceResponse(tp0, 0, error, 0));
+        prepareProduceResponse(error, producerId, producerEpoch, tp0);
+    }
+
+    private void prepareProduceResponse(Errors error, final long producerId, final short producerEpoch, TopicPartition tp) {
+        client.prepareResponse(produceRequestMatcher(producerId, producerEpoch, tp), produceResponse(tp, 0, error, 0));
     }
 
     private MockClient.RequestMatcher produceRequestMatcher(final long pid, final short epoch) {
+        return produceRequestMatcher(pid, epoch, tp0);
+    }
+
+    private MockClient.RequestMatcher produceRequestMatcher(final long pid, final short epoch, TopicPartition tp) {
         return body -> {
             ProduceRequest produceRequest = (ProduceRequest) body;
-            MemoryRecords records = produceRequest.partitionRecordsOrFail().get(tp0);
+            MemoryRecords records = produceRequest.partitionRecordsOrFail().get(tp);
             assertNotNull(records);
             Iterator<MutableRecordBatch> batchIterator = records.batches().iterator();
             assertTrue(batchIterator.hasNext());
@@ -2905,7 +3444,11 @@ public class TransactionManagerTest {
     }
 
     private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs) {
-        ProduceResponse.PartitionResponse resp = new ProduceResponse.PartitionResponse(error, offset, RecordBatch.NO_TIMESTAMP, 10);
+        return produceResponse(tp, offset, error, throttleTimeMs, 10);
+    }
+
+    private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs, int logStartOffset) {
+        ProduceResponse.PartitionResponse resp = new ProduceResponse.PartitionResponse(error, offset, RecordBatch.NO_TIMESTAMP, logStartOffset);
         Map<TopicPartition, ProduceResponse.PartitionResponse> partResp = singletonMap(tp, resp);
         return new ProduceResponse(partResp, throttleTimeMs);
     }
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala
index 89cd474..861f6e1 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala
@@ -113,8 +113,7 @@ class TransactionsBounceTest extends KafkaServerTestHarness {
         producer.beginTransaction()
         val shouldAbort = iteration % 3 == 0
         records.foreach { record =>
-          producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(outputTopic, record.key, record.value,
-            !shouldAbort), new ErrorLoggingCallback(outputTopic, record.key, record.value, true))
+          producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(outputTopic, null, record.key,
+            record.value, !shouldAbort), new ErrorLoggingCallback(outputTopic, record.key, record.value, true))
         }
         trace(s"Sent ${records.size} messages. Committing offsets.")
         commit(producer, consumerGroup, consumer)
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala
new file mode 100644
index 0000000..0492286
--- /dev/null
+++ b/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.api
+
+import java.util.Properties
+
+import kafka.integration.KafkaServerTestHarness
+import kafka.server.KafkaConfig
+import kafka.utils.TestUtils
+import kafka.utils.TestUtils.consumeRecords
+import org.apache.kafka.clients.consumer.KafkaConsumer
+import org.apache.kafka.clients.producer.KafkaProducer
+import org.apache.kafka.common.errors.InvalidPidMappingException
+import org.junit.{After, Before, Test}
+
+import scala.collection.JavaConverters._
+import scala.collection.Seq
+
+// Test class that uses a very short transactional ID expiration time to trigger InvalidPidMapping errors
+class TransactionsExpirationTest extends KafkaServerTestHarness {
+  val topic1 = "topic1"
+  val topic2 = "topic2"
+  val numPartitions = 4
+  val replicationFactor = 3
+
+  var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
+  var consumer: KafkaConsumer[Array[Byte], Array[Byte]] = _
+
+  override def generateConfigs: Seq[KafkaConfig] = {
+    TestUtils.createBrokerConfigs(3, zkConnect).map(KafkaConfig.fromProps(_, serverProps()))
+  }
+
+  @Before
+  override def setUp(): Unit = {
+    super.setUp()
+
+    producer = TestUtils.createTransactionalProducer("transactionalProducer", servers)
+    consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers),
+      enableAutoCommit = false,
+      readCommitted = true)
+
+    TestUtils.createTopic(zkClient, topic1, numPartitions, replicationFactor, servers, new Properties())
+    TestUtils.createTopic(zkClient, topic2, numPartitions, replicationFactor, servers, new Properties())
+  }
+
+  @After
+  override def tearDown(): Unit = {
+    producer.close()
+    consumer.close()
+
+    super.tearDown()
+  }
+
+  @Test
+  def testBumpTransactionalEpochAfterInvalidProducerIdMapping(): Unit = {
+    producer.initTransactions()
+
+    // Start and then abort a transaction to allow the transactional ID to expire
+    producer.beginTransaction()
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 0, "2", "2", willBeCommitted = false))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, 0, "4", "4", willBeCommitted = false))
+    producer.abortTransaction()
+
+    // Wait for the transactional ID to expire
+    Thread.sleep(3000)
+
+    // Start a new transaction and attempt to send. This will trigger an AddPartitionsToTxnRequest,
+    // which will fail due to the expired producer ID
+    producer.beginTransaction()
+    val failedFuture = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 3, "1", "1", willBeCommitted = false))
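+    // Give the AddPartitionsToTxnRequest time to be sent and to fail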
+    Thread.sleep(500)
+
+    org.apache.kafka.test.TestUtils.assertFutureThrows(failedFuture, classOf[InvalidPidMappingException])
+    producer.abortTransaction()
+
+    producer.beginTransaction()
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = true))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 2, "4", "4", willBeCommitted = true))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "1", "1", willBeCommitted = true))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 3, "3", "3", willBeCommitted = true))
+    producer.commitTransaction()
+
+    consumer.subscribe(List(topic1, topic2).asJava)
+
+    val records = consumeRecords(consumer, 4)
+    records.foreach { record =>
+      TestUtils.assertCommittedAndGetValue(record)
+    }
+  }
+
+  private def serverProps() = {
+    val serverProps = new Properties()
+    serverProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString)
+    // Use a single partition for the __consumer_offsets topic so that its creation and the
+    // subsequent leader assignment don't take too long
+    serverProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString)
+    serverProps.put(KafkaConfig.TransactionsTopicPartitionsProp, 3.toString)
+    serverProps.put(KafkaConfig.TransactionsTopicReplicationFactorProp, 2.toString)
+    serverProps.put(KafkaConfig.TransactionsTopicMinISRProp, 2.toString)
+    serverProps.put(KafkaConfig.ControlledShutdownEnableProp, true.toString)
+    serverProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString)
+    serverProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString)
+    serverProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0")
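+    // Use aggressive expiration settings so that the transactional ID expires quickly in the test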
+    serverProps.put(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, "200")
+    serverProps.put(KafkaConfig.TransactionalIdExpirationMsProp, "2000")
+    serverProps.put(KafkaConfig.TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, "500")
+    serverProps
+  }
+}
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index 2ea9f7c..7b2a945 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -20,24 +20,24 @@ package kafka.api
 import java.lang.{Long => JLong}
 import java.nio.charset.StandardCharsets
 import java.time.Duration
-import java.util.{Optional, Properties}
 import java.util.concurrent.TimeUnit
+import java.util.{Optional, Properties}
 
 import kafka.integration.KafkaServerTestHarness
 import kafka.server.KafkaConfig
 import kafka.utils.TestUtils
 import kafka.utils.TestUtils.consumeRecords
-import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerGroupMetadata, KafkaConsumer, OffsetAndMetadata}
+import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer, OffsetAndMetadata}
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
-import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.apache.kafka.common.errors.{ProducerFencedException, TimeoutException}
-import org.junit.{After, Before, Test}
+import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.junit.Assert._
+import org.junit.{After, Before, Test}
 import org.scalatest.Assertions.fail
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.Buffer
 import scala.collection.Seq
+import scala.collection.mutable.Buffer
 import scala.concurrent.ExecutionException
 
 class TransactionsTest extends KafkaServerTestHarness {
@@ -48,6 +48,7 @@ class TransactionsTest extends KafkaServerTestHarness {
 
   val topic1 = "topic1"
   val topic2 = "topic2"
+  val numPartitions = 4
 
   val transactionalProducers = Buffer[KafkaProducer[Array[Byte], Array[Byte]]]()
   val transactionalConsumers = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]()
@@ -60,7 +61,6 @@ class TransactionsTest extends KafkaServerTestHarness {
   @Before
   override def setUp(): Unit = {
     super.setUp()
-    val numPartitions = 4
     val topicConfig = new Properties()
     topicConfig.put(KafkaConfig.MinInSyncReplicasProp, 2.toString)
     createTopic(topic1, numPartitions, numServers, topicConfig)
@@ -91,14 +91,14 @@ class TransactionsTest extends KafkaServerTestHarness {
     producer.initTransactions()
 
     producer.beginTransaction()
-    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "2", willBeCommitted = false))
-    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "4", "4", willBeCommitted = false))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = false))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "4", "4", willBeCommitted = false))
     producer.flush()
     producer.abortTransaction()
 
     producer.beginTransaction()
-    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = true))
-    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "3", "3", willBeCommitted = true))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = true))
+    producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = true))
     producer.commitTransaction()
 
     consumer.subscribe(List(topic1, topic2).asJava)
@@ -274,7 +274,7 @@ class TransactionsTest extends KafkaServerTestHarness {
         records.foreach { record =>
           val key = new String(record.key(), StandardCharsets.UTF_8)
           val value = new String(record.value(), StandardCharsets.UTF_8)
-          producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, key, value, willBeCommitted = shouldCommit))
+          producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, key, value, willBeCommitted = shouldCommit))
         }
 
         commit(producer, consumerGroupId, consumer)
@@ -317,13 +317,13 @@ class TransactionsTest extends KafkaServerTestHarness {
     producer1.initTransactions()
 
     producer1.beginTransaction()
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false))
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "3", "3", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false))
 
     producer2.initTransactions()  // ok, will abort the open transaction.
     producer2.beginTransaction()
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true))
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true))
 
     try {
       producer1.commitTransaction()
@@ -354,13 +354,13 @@ class TransactionsTest extends KafkaServerTestHarness {
     producer1.initTransactions()
 
     producer1.beginTransaction()
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false))
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "3", "3", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false))
 
     producer2.initTransactions()  // ok, will abort the open transaction.
     producer2.beginTransaction()
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true))
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true))
 
     try {
       producer1.sendOffsetsToTransaction(Map(new TopicPartition("foobartopic", 0) -> new OffsetAndMetadata(110L)).asJava,
@@ -417,16 +417,16 @@ class TransactionsTest extends KafkaServerTestHarness {
     producer1.initTransactions()
 
     producer1.beginTransaction()
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false))
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "3", "3", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false))
 
     producer2.initTransactions()  // ok, will abort the open transaction.
     producer2.beginTransaction()
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true)).get()
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true)).get()
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true)).get()
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true)).get()
 
     try {
-      val result =  producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "5", willBeCommitted = false))
+      val result = producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "5", willBeCommitted = false))
       val recordMetadata = result.get()
       error(s"Missed a producer fenced exception when writing to ${recordMetadata.topic}-${recordMetadata.partition}. Grab the logs!!")
       servers.foreach { server =>
@@ -460,20 +460,20 @@ class TransactionsTest extends KafkaServerTestHarness {
 
     producer1.initTransactions()
     producer1.beginTransaction()
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false))
-    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "3", "3", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false))
+    producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false))
     producer1.abortTransaction()
 
     producer2.initTransactions()  // ok, will abort the open transaction.
     producer2.beginTransaction()
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true))
       .get(20, TimeUnit.SECONDS)
-    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true))
+    producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true))
       .get(20, TimeUnit.SECONDS)
 
     try {
       producer1.beginTransaction()
-      val result =  producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "5", willBeCommitted = false))
+      val result = producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "5", willBeCommitted = false))
       val recordMetadata = result.get()
       error(s"Missed a producer fenced exception when writing to ${recordMetadata.topic}-${recordMetadata.partition}. Grab the logs!!")
       servers.foreach { server =>
@@ -504,7 +504,7 @@ class TransactionsTest extends KafkaServerTestHarness {
     producer.beginTransaction()
 
     // The first message and hence the first AddPartitions request should be successfully sent.
-    val firstMessageResult = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false)).get()
+    val firstMessageResult = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)).get()
     assertTrue(firstMessageResult.hasOffset)
 
     // Wait for the expiration cycle to kick in.
@@ -512,7 +512,7 @@ class TransactionsTest extends KafkaServerTestHarness {
 
     try {
       // Now that the transaction has expired, the second send should fail with a ProducerFencedException.
-      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "2", willBeCommitted = false)).get()
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "2", willBeCommitted = false)).get()
       fail("should have raised a ProducerFencedException since the transaction has expired")
     } catch {
       case _: ProducerFencedException =>
@@ -603,10 +603,68 @@ class TransactionsTest extends KafkaServerTestHarness {
     }
   }
 
+  @Test
+  def testBumpTransactionalEpoch(): Unit = {
+    val producer = createTransactionalProducer("transactionalProducer", deliveryTimeoutMs = 5000)
+    val consumer = transactionalConsumers.head
+    try {
+      // Create a topic with RF=1 so that a single broker failure will render it unavailable
+      val testTopic = "test-topic"
+      createTopic(testTopic, numPartitions, 1, new Properties)
+      val partitionLeader = TestUtils.waitUntilLeaderIsKnown(servers, new TopicPartition(testTopic, 0))
+
+      producer.initTransactions()
+
+      producer.beginTransaction()
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "4", "4", willBeCommitted = true))
+      producer.commitTransaction()
+
+      var producerStateEntry =
+        servers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers.head._2
+      val producerId = producerStateEntry.producerId
+      val initialProducerEpoch = producerStateEntry.producerEpoch
+
+      producer.beginTransaction()
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "2", willBeCommitted = false))
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "4", "4", willBeCommitted = false))
+
+      killBroker(partitionLeader) // kill the partition leader to prevent the batch from being submitted
+      val failedFuture = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "3", "3", willBeCommitted = false))
+      Thread.sleep(6000) // Wait for the record to time out
+      restartDeadBrokers()
+
+      org.apache.kafka.test.TestUtils.assertFutureThrows(failedFuture, classOf[TimeoutException])
+      producer.abortTransaction()
+
+      producer.beginTransaction()
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = true))
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "4", "4", willBeCommitted = true))
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "1", "1", willBeCommitted = true))
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "3", "3", willBeCommitted = true))
+      producer.commitTransaction()
+
+      consumer.subscribe(List(topic1, topic2, testTopic).asJava)
+
+      val records = consumeRecords(consumer, 5)
+      records.foreach { record =>
+        TestUtils.assertCommittedAndGetValue(record)
+      }
+
+      // Producers can safely abort and continue after the last record of a transaction times out, so it is possible
+      // to get here without the epoch having been bumped. If bumping the epoch is possible, the producer will attempt
+      // it, so check here that the epoch has actually increased.
+      producerStateEntry =
+        servers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers(producerId)
+      assertTrue(producerStateEntry.producerEpoch > initialProducerEpoch)
+    } finally {
+      producer.close(Duration.ZERO)
+    }
+  }
+
   private def sendTransactionalMessagesWithValueRange(producer: KafkaProducer[Array[Byte], Array[Byte]], topic: String,
                                                       start: Int, end: Int, willBeCommitted: Boolean): Unit = {
     for (i <- start until end) {
-      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic, i.toString, i.toString, willBeCommitted))
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic, null, key = i.toString, value = i.toString, willBeCommitted = willBeCommitted))
     }
     producer.flush()
   }
@@ -650,12 +708,13 @@ class TransactionsTest extends KafkaServerTestHarness {
 
   private def createTransactionalProducer(transactionalId: String,
                                           transactionTimeoutMs: Long = 60000,
-                                          maxBlockMs: Long = 60000): KafkaProducer[Array[Byte], Array[Byte]] = {
+                                          maxBlockMs: Long = 60000,
+                                          deliveryTimeoutMs: Int = 120000): KafkaProducer[Array[Byte], Array[Byte]] = {
     val producer = TestUtils.createTransactionalProducer(transactionalId, servers,
       transactionTimeoutMs = transactionTimeoutMs,
-      maxBlockMs = maxBlockMs)
+      maxBlockMs = maxBlockMs,
+      deliveryTimeoutMs = deliveryTimeoutMs)
     transactionalProducers += producer
     producer
   }
-
 }
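
The new testBumpTransactionalEpoch above exercises the recovery path end to end: a
delivery timeout inside a transaction, an abort, and then a successful follow-up
transaction on the same producer instance. For reference, a minimal sketch of the
corresponding application-side pattern (illustrative only: "producer" is assumed to be
a transactional producer on which initTransactions() has already been called, and
"record" stands in for the application's data):

    import org.apache.kafka.common.KafkaException
    import org.apache.kafka.common.errors.{AuthorizationException, OutOfOrderSequenceException, ProducerFencedException}

    try {
      producer.beginTransaction()
      producer.send(record)
      producer.commitTransaction()
    } catch {
      case e @ (_: ProducerFencedException | _: OutOfOrderSequenceException | _: AuthorizationException) =>
        // Fatal errors: the producer cannot recover and must be closed.
        producer.close()
        throw e
      case _: KafkaException =>
        // Abortable error (e.g. a timed-out batch). With this change, aborting also
        // bumps the producer epoch when it is safe to do so, so the application can
        // begin a new transaction on the same producer instead of closing it.
        producer.abortTransaction()
    }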
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 6edff6e..1debb28 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -597,7 +597,8 @@ object TestUtils extends Logging {
                            trustStoreFile: Option[File] = None,
                            saslProperties: Option[Properties] = None,
                            keySerializer: Serializer[K] = new ByteArraySerializer,
-                           valueSerializer: Serializer[V] = new ByteArraySerializer): KafkaProducer[K, V] = {
+                           valueSerializer: Serializer[V] = new ByteArraySerializer,
+                           enableIdempotence: Boolean = false): KafkaProducer[K, V] = {
     val producerProps = new Properties
     producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList)
     producerProps.put(ProducerConfig.ACKS_CONFIG, acks.toString)
@@ -609,6 +610,7 @@ object TestUtils extends Logging {
     producerProps.put(ProducerConfig.LINGER_MS_CONFIG, lingerMs.toString)
     producerProps.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize.toString)
     producerProps.put(ProducerConfig.COMPRESSION_TYPE_CONFIG, compressionType)
+    producerProps.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, enableIdempotence.toString)
     producerProps ++= producerSecurityConfigs(securityProtocol, trustStoreFile, saslProperties)
     new KafkaProducer[K, V](producerProps, keySerializer, valueSerializer)
   }
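
The new enableIdempotence parameter defaults to false, so existing callers of this
factory are unaffected. A hypothetical call site opting in (assuming the enclosing
factory is TestUtils.createProducer, with acks passed explicitly since idempotence
requires acks=all):

    // Sketch only: creates an idempotent, non-transactional test producer.
    val producer = TestUtils.createProducer(brokerList, acks = -1, enableIdempotence = true)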
@@ -1376,7 +1378,8 @@ object TestUtils extends Logging {
                                   servers: Seq[KafkaServer],
                                   batchSize: Int = 16384,
                                   transactionTimeoutMs: Long = 60000,
-                                  maxBlockMs: Long = 60000) = {
+                                  maxBlockMs: Long = 60000,
+                                  deliveryTimeoutMs: Int = 120000) = {
     val props = new Properties()
     props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, TestUtils.getBrokerListStrFromServers(servers))
     props.put(ProducerConfig.ACKS_CONFIG, "all")
@@ -1385,6 +1388,8 @@ object TestUtils extends Logging {
     props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true")
     props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, transactionTimeoutMs.toString)
     props.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, maxBlockMs.toString)
+    props.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString)
+    props.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString)
     new KafkaProducer[Array[Byte], Array[Byte]](props, new ByteArraySerializer, new ByteArraySerializer)
   }
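
Pinning request.timeout.ms to the same value as delivery.timeout.ms is what makes a
short deliveryTimeoutMs usable here: the producer rejects an explicitly configured
delivery.timeout.ms that is smaller than linger.ms + request.timeout.ms, so the
5000 ms delivery timeout used by testBumpTransactionalEpoch would fail construction
against the default request.timeout.ms of 30000. In sketch form (assuming linger.ms
is left at its default of 0):

    // Constraint enforced at KafkaProducer construction time:
    //   delivery.timeout.ms >= linger.ms + request.timeout.ms
    // Setting both configs to deliveryTimeoutMs satisfies it for any value.
    props.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString)
    props.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString)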
 
@@ -1426,20 +1431,18 @@ object TestUtils extends Logging {
     asString(record.value)
   }
 
-  def producerRecordWithExpectedTransactionStatus(topic: String, key: Array[Byte], value: Array[Byte],
-                                                  willBeCommitted: Boolean) : ProducerRecord[Array[Byte], Array[Byte]] = {
+  def producerRecordWithExpectedTransactionStatus(topic: String, partition: Integer, key: Array[Byte], value: Array[Byte], willBeCommitted: Boolean): ProducerRecord[Array[Byte], Array[Byte]] = {
     val header = new Header {override def key() = transactionStatusKey
       override def value() = if (willBeCommitted)
         committedValue
       else
         abortedValue
     }
-    new ProducerRecord[Array[Byte], Array[Byte]](topic, null, key, value, Collections.singleton(header))
+    new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, key, value, Collections.singleton(header))
   }
 
-  def producerRecordWithExpectedTransactionStatus(topic: String, key: String, value: String,
-                                                  willBeCommitted: Boolean) : ProducerRecord[Array[Byte], Array[Byte]] = {
-    producerRecordWithExpectedTransactionStatus(topic, asBytes(key), asBytes(value), willBeCommitted)
+  def producerRecordWithExpectedTransactionStatus(topic: String, partition: Integer, key: String, value: String, willBeCommitted: Boolean): ProducerRecord[Array[Byte], Array[Byte]] = {
+    producerRecordWithExpectedTransactionStatus(topic, partition, asBytes(key), asBytes(value), willBeCommitted)
   }
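
On the consume side these records are verified by helpers such as
assertCommittedAndGetValue (used in testBumpTransactionalEpoch above), which read the
status header back. A rough sketch of such a check, assuming the transactionStatusKey
and committedValue constants defined in this object (assertRecordWasCommitted is a
hypothetical name; the real helper may differ in its details):

    import java.util.Arrays
    import org.apache.kafka.clients.consumer.ConsumerRecord
    import org.junit.Assert.assertTrue

    def assertRecordWasCommitted(record: ConsumerRecord[Array[Byte], Array[Byte]]): Unit = {
      // Look up the header written by producerRecordWithExpectedTransactionStatus;
      // fails with NoSuchElementException if the record carries no such header.
      val status = record.headers.headers(transactionStatusKey).iterator.next.value
      assertTrue("record was expected to be committed", Arrays.equals(committedValue, status))
    }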
 
   // Collect the current positions for all partition in the consumers current assignment.

