kafka-commits mailing list archives

From j...@apache.org
Subject [2/2] kafka git commit: KAFKA-5283; Handle producer epoch/sequence overflow
Date Fri, 02 Jun 2017 06:41:41 GMT
KAFKA-5283; Handle producer epoch/sequence overflow

- Producer sequence numbers wrap around to 0 after Int.MaxValue
- Generate a new producerId if bumping the producer epoch would overflow (both behaviors are sketched below)
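
A minimal Scala sketch of the two rules, using illustrative names (the committed logic is DefaultRecordBatch.incrementSequence and TransactionMetadata.isProducerEpochExhausted in the diff below):

    object OverflowSketch {
      // Sequence numbers wrap back to 0 once they pass Int.MaxValue.
      def incrementSequence(baseSequence: Int, increment: Int): Int =
        if (baseSequence > Int.MaxValue - increment)
          increment - (Int.MaxValue - baseSequence) - 1 // e.g. (Int.MaxValue, 1) => 0
        else
          baseSequence + increment

      // Short.MaxValue is reserved so the coordinator can always bump once more to
      // fence a zombie; the last epoch handed to a client is Short.MaxValue - 1.
      // Once that is reached, InitProducerId allocates a fresh producerId with epoch 0.
      def isProducerEpochExhausted(producerEpoch: Short): Boolean =
        producerEpoch >= Short.MaxValue - 1
    }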

Author: Jason Gustafson <jason@confluent.io>

Reviewers: Ismael Juma <ismael@juma.me.uk>, Apurva Mehta <apurva@confluent.io>, Guozhang Wang <wangguoz@gmail.com>

Closes #3183 from hachikuji/KAFKA-5283


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/1c882ee5
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/1c882ee5
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/1c882ee5

Branch: refs/heads/trunk
Commit: 1c882ee5fb4ef2d256c914bd69239d58d9706108
Parents: 0c3e466
Author: Jason Gustafson <jason@confluent.io>
Authored: Thu Jun 1 23:37:31 2017 -0700
Committer: Jason Gustafson <jason@confluent.io>
Committed: Thu Jun 1 23:37:36 2017 -0700

----------------------------------------------------------------------
 .../kafka/common/record/DefaultRecord.java      |   4 +-
 .../kafka/common/record/DefaultRecordBatch.java |  10 +-
 .../common/record/DefaultRecordBatchTest.java   |  36 ++++
 .../transaction/TransactionCoordinator.scala    |  70 +++----
 .../TransactionMarkerChannelManager.scala       |   4 +-
 ...nsactionMarkerRequestCompletionHandler.scala |   4 +-
 .../transaction/TransactionMetadata.scala       | 131 ++++++++-----
 .../transaction/TransactionStateManager.scala   |  78 ++++----
 .../scala/kafka/log/ProducerStateManager.scala  |  18 +-
 .../kafka/api/TransactionsTest.scala            |  10 +-
 .../TransactionCoordinatorTest.scala            | 146 ++++++++++----
 .../TransactionMarkerChannelManagerTest.scala   |   4 +-
 ...tionMarkerRequestCompletionHandlerTest.scala |   8 +-
 .../transaction/TransactionMetadataTest.scala   | 188 +++++++++++++++++++
 .../TransactionStateManagerTest.scala           |  66 +++----
 .../kafka/log/ProducerStateManagerTest.scala    |  27 +++
 16 files changed, 605 insertions(+), 199 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
index 5972b42..8910b30 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
@@ -323,7 +323,9 @@ public class DefaultRecord implements Record {
 
         int offsetDelta = ByteUtils.readVarint(buffer);
         long offset = baseOffset + offsetDelta;
-        int sequence = baseSequence >= 0 ? baseSequence + offsetDelta : RecordBatch.NO_SEQUENCE;
+        int sequence = baseSequence >= 0 ?
+                DefaultRecordBatch.incrementSequence(baseSequence, offsetDelta) :
+                RecordBatch.NO_SEQUENCE;
 
         ByteBuffer key = null;
         int keySize = ByteUtils.readVarint(buffer);

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
index 7a0e530..c05cab8 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
@@ -185,7 +185,9 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe
         int baseSequence = baseSequence();
         if (baseSequence == RecordBatch.NO_SEQUENCE)
             return RecordBatch.NO_SEQUENCE;
-        return baseSequence() + lastOffsetDelta();
+
+        int delta = lastOffsetDelta();
+        return incrementSequence(baseSequence, delta);
     }
 
     @Override
@@ -462,6 +464,12 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe
         return RECORD_BATCH_OVERHEAD + DefaultRecord.recordSizeUpperBound(key, value, headers);
     }
 
+    static int incrementSequence(int baseSequence, int increment) {
+        if (baseSequence > Integer.MAX_VALUE - increment)
+            return increment - (Integer.MAX_VALUE - baseSequence) - 1;
+        return baseSequence + increment;
+    }
+
     private abstract class RecordIterator implements CloseableIterator<Record> {
         private final Long logAppendTime;
         private final long baseOffset;

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
index 3db1159..726b619 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
@@ -90,6 +90,35 @@ public class DefaultRecordBatchTest {
     }
 
     @Test
+    public void buildDefaultRecordBatchWithSequenceWrapAround() {
+        long pid = 23423L;
+        short epoch = 145;
+        int baseSequence = Integer.MAX_VALUE - 1;
+        ByteBuffer buffer = ByteBuffer.allocate(2048);
+
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+                TimestampType.CREATE_TIME, 1234567L, RecordBatch.NO_TIMESTAMP, pid, epoch, baseSequence);
+        builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes());
+        builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes());
+        builder.appendWithOffset(1234569, 3L, "c".getBytes(), "v".getBytes());
+
+        MemoryRecords records = builder.build();
+        List<MutableRecordBatch> batches = TestUtils.toList(records.batches());
+        assertEquals(1, batches.size());
+        RecordBatch batch = batches.get(0);
+
+        assertEquals(pid, batch.producerId());
+        assertEquals(epoch, batch.producerEpoch());
+        assertEquals(baseSequence, batch.baseSequence());
+        assertEquals(0, batch.lastSequence());
+        List<Record> allRecords = TestUtils.toList(batch);
+        assertEquals(3, allRecords.size());
+        assertEquals(Integer.MAX_VALUE - 1, allRecords.get(0).sequence());
+        assertEquals(Integer.MAX_VALUE, allRecords.get(1).sequence());
+        assertEquals(0, allRecords.get(2).sequence());
+    }
+
+    @Test
     public void testSizeInBytes() {
         Header[] headers = new Header[] {
             new RecordHeader("foo", "value".getBytes()),
@@ -265,4 +294,11 @@ public class DefaultRecordBatchTest {
         }
     }
 
+    @Test
+    public void testIncrementSequence() {
+        assertEquals(10, DefaultRecordBatch.incrementSequence(5, 5));
+        assertEquals(0, DefaultRecordBatch.incrementSequence(Integer.MAX_VALUE, 1));
+        assertEquals(4, DefaultRecordBatch.incrementSequence(Integer.MAX_VALUE - 5, 10));
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 5c39635..040ab38 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -110,18 +110,22 @@ class TransactionCoordinator(brokerId: Int,
       // check transactionTimeoutMs is not larger than the broker configured maximum allowed value
       responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT))
     } else {
-      val producerId = producerIdManager.generateProducerId()
-      val now = time.milliseconds()
-      val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
-        producerId = producerId,
-        producerEpoch = 0,
-        txnTimeoutMs = transactionTimeoutMs,
-        state = Empty,
-        topicPartitions = collection.mutable.Set.empty[TopicPartition],
-        txnLastUpdateTimestamp = now)
-
-      // only try to get a new producerId and update the cache if the transactional id is unknown
-      val result: Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = txnManager.getAndMaybeAddTransactionState(transactionalId, Some(createdMetadata)) match {
+      val coordinatorEpochAndMetadata = txnManager.getTransactionState(transactionalId) match {
+        case Right(None) =>
+          val producerId = producerIdManager.generateProducerId()
+          val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
+            producerId = producerId,
+            producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
+            txnTimeoutMs = transactionTimeoutMs,
+            state = Empty,
+            topicPartitions = collection.mutable.Set.empty[TopicPartition],
+            txnLastUpdateTimestamp = time.milliseconds())
+          txnManager.putTransactionStateIfNotExists(transactionalId, createdMetadata)
+
+        case other => other
+      }
+
+      val result: Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = coordinatorEpochAndMetadata match {
         case Left(err) =>
           Left(initTransactionError(err))
 
@@ -129,15 +133,8 @@ class TransactionCoordinator(brokerId: Int,
           val coordinatorEpoch = existingEpochAndMetadata.coordinatorEpoch
           val txnMetadata = existingEpochAndMetadata.transactionMetadata
 
-          // there might be a concurrent thread that has just updated the mapping
-          // with the transactional id at the same time (hence reference equality will fail);
-          // in this case we will treat it as the metadata has existed already
           txnMetadata synchronized {
-            if (!txnMetadata.eq(createdMetadata)) {
-              initProducerIdWithExistingMetadata(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata)
-            } else {
-              Right(coordinatorEpoch, txnMetadata.prepareNewProducerId(time.milliseconds()))
-            }
+            prepareInitProduceIdTransit(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata)
           }
 
         case Right(None) =>
@@ -182,10 +179,10 @@ class TransactionCoordinator(brokerId: Int,
     }
   }
 
-  private def initProducerIdWithExistingMetadata(transactionalId: String,
-                                                 transactionTimeoutMs: Int,
-                                                 coordinatorEpoch: Int,
-                                                 txnMetadata: TransactionMetadata): Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = {
+  private def prepareInitProduceIdTransit(transactionalId: String,
+                                          transactionTimeoutMs: Int,
+                                          coordinatorEpoch: Int,
+                                          txnMetadata: TransactionMetadata): Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = {
     if (txnMetadata.pendingTransitionInProgress) {
       // return a retriable exception to let the client backoff and retry
       Left(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
@@ -193,15 +190,25 @@ class TransactionCoordinator(brokerId: Int,
       // caller should have synchronized on txnMetadata already
       txnMetadata.state match {
         case PrepareAbort | PrepareCommit =>
-          // reply to client and let client backoff and retry
+          // reply to client and let it backoff and retry
           Left(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
 
         case CompleteAbort | CompleteCommit | Empty =>
-          // try to append and then update
-          Right(coordinatorEpoch, txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, time.milliseconds()))
+          val transitMetadata = if (txnMetadata.isProducerEpochExhausted) {
+            val newProducerId = producerIdManager.generateProducerId()
+            txnMetadata.prepareProducerIdRotation(newProducerId, transactionTimeoutMs, time.milliseconds())
+          } else {
+            txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, time.milliseconds())
+          }
+
+          Right(coordinatorEpoch, transitMetadata)
 
         case Ongoing =>
-          // indicate to abort the current ongoing txn first
+          // indicate to abort the current ongoing txn first. Note that this epoch is never returned to the
+          // user. We will abort the ongoing transaction and return CONCURRENT_TRANSACTIONS to the client.
+          // This forces the client to retry, which will ensure that the epoch is bumped a second time. In
+          // particular, if fencing the current producer exhausts the available epochs for the current producerId,
+          // then when the client retries, we will generate a new producerId.
           Right(coordinatorEpoch, txnMetadata.prepareFenceProducerEpoch())
         case Dead =>
           throw new IllegalStateException(s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
@@ -220,7 +227,7 @@ class TransactionCoordinator(brokerId: Int,
     } else {
       // try to update the transaction metadata and append the updated metadata to txn log;
       // if there is no such metadata treat it as invalid producerId mapping error.
-      val result: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getAndMaybeAddTransactionState(transactionalId) match {
+      val result: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
         case Left(err) =>
           Left(err)
 
@@ -286,7 +293,7 @@ class TransactionCoordinator(brokerId: Int,
     if (transactionalId == null || transactionalId.isEmpty)
       responseCallback(Errors.INVALID_REQUEST)
     else {
-      val preAppendResult: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getAndMaybeAddTransactionState(transactionalId) match {
+      val preAppendResult: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
         case Left(err) =>
           Left(err)
 
@@ -296,7 +303,6 @@ class TransactionCoordinator(brokerId: Int,
         case Right(Some(epochAndTxnMetadata)) =>
           val txnMetadata = epochAndTxnMetadata.transactionMetadata
           val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
-          val now = time.milliseconds()
 
           txnMetadata synchronized {
             if (txnMetadata.producerId != producerId)
@@ -349,7 +355,7 @@ class TransactionCoordinator(brokerId: Int,
         case Right((coordinatorEpoch, newMetadata)) =>
           def sendTxnMarkersCallback(error: Errors): Unit = {
             if (error == Errors.NONE) {
-              val preSendResult: Either[Errors, (TransactionMetadata, TxnTransitMetadata)] = txnManager.getAndMaybeAddTransactionState(transactionalId) match {
+              val preSendResult: Either[Errors, (TransactionMetadata, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
                 case Left(err) =>
                   Left(err)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
index 344863f..c6cead6 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
@@ -215,7 +215,7 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
         case Errors.NONE =>
           trace(s"Completed sending transaction markers for $transactionalId as $txnResult")
 
-          txnStateManager.getAndMaybeAddTransactionState(transactionalId) match {
+          txnStateManager.getTransactionState(transactionalId) match {
             case Left(Errors.NOT_COORDINATOR) =>
               info(s"I am no longer the coordinator for $transactionalId with coordinator epoch $coordinatorEpoch; cancel appending $newMetadata to transaction log")
 
@@ -291,7 +291,7 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
           }
 
         case None =>
-          txnStateManager.getAndMaybeAddTransactionState(transactionalId) match {
+          txnStateManager.getTransactionState(transactionalId) match {
             case Left(error) =>
               info(s"Encountered $error trying to fetch transaction metadata for $transactionalId with coordinator epoch $coordinatorEpoch; cancel sending markers to its partition leaders")
               txnMarkerPurgatory.cancelForKey(transactionalId)

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
index da40001..68edc65 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
@@ -42,7 +42,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int,
         val transactionalId = txnIdAndMarker.txnId
         val txnMarker = txnIdAndMarker.txnMarkerEntry
 
-        txnStateManager.getAndMaybeAddTransactionState(transactionalId) match {
+        txnStateManager.getTransactionState(transactionalId) match {
 
           case Left(Errors.NOT_COORDINATOR) =>
             info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers")
@@ -93,7 +93,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int,
         if (errors == null)
           throw new IllegalStateException(s"WriteTxnMarkerResponse does not contain expected error map for producer id ${txnMarker.producerId}")
 
-        txnStateManager.getAndMaybeAddTransactionState(transactionalId) match {
+        txnStateManager.getTransactionState(transactionalId) match {
           case Left(Errors.NOT_COORDINATOR) =>
             info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers")
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index dbf0ec5..a92e6be 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -18,6 +18,7 @@ package kafka.coordinator.transaction
 
 import kafka.utils.{Logging, nonthreadsafe}
 import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.record.RecordBatch
 
 import scala.collection.{immutable, mutable}
 
@@ -142,7 +143,7 @@ private[transaction] case class TxnTransitMetadata(producerId: Long,
   */
 @nonthreadsafe
 private[transaction] class TransactionMetadata(val transactionalId: String,
-                                               val producerId: Long,
+                                               var producerId: Long,
                                                var producerEpoch: Short,
                                                var txnTimeoutMs: Int,
                                                var state: TransactionState,
@@ -174,6 +175,9 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
   }
 
   def prepareFenceProducerEpoch(): TxnTransitMetadata = {
+    if (producerEpoch == Short.MaxValue)
+      throw new IllegalStateException(s"Cannot fence producer with epoch equal to Short.MaxValue since this would overflow")
+
     // bump up the epoch to let the txn markers be able to override the current producer epoch
     producerEpoch = (producerEpoch + 1).toShort
 
@@ -181,45 +185,62 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
   }
 
-  def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
-                                    updateTimestamp: Long): TxnTransitMetadata = {
+  def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int, updateTimestamp: Long): TxnTransitMetadata = {
+    if (isProducerEpochExhausted)
+      throw new IllegalStateException(s"Cannot allocate any more producer epochs for producerId $producerId")
 
-    prepareTransitionTo(Empty, (producerEpoch + 1).toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition],
-      -1, updateTimestamp)
+    val nextEpoch = if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH) 0 else producerEpoch + 1
+    prepareTransitionTo(Empty, producerId, nextEpoch.toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1,
+      updateTimestamp)
   }
 
-  def prepareNewProducerId(updateTimestamp: Long): TxnTransitMetadata = {
-    prepareTransitionTo(Empty, producerEpoch, txnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+  def prepareProducerIdRotation(newProducerId: Long, newTxnTimeoutMs: Int, updateTimestamp: Long): TxnTransitMetadata = {
+    if (hasPendingTransaction)
+      throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending")
+    prepareTransitionTo(Empty, newProducerId, 0, newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
   }
 
-  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition],
-                           updateTimestamp: Long): TxnTransitMetadata = {
-
-    if (state == Empty || state == CompleteCommit || state == CompleteAbort) {
-      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet,
-        updateTimestamp, updateTimestamp)
-    } else {
-      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet,
-        txnStartTimestamp, updateTimestamp)
+  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long): TxnTransitMetadata = {
+    val newTxnStartTimestamp = state match {
+      case Empty | CompleteAbort | CompleteCommit => updateTimestamp
+      case _ => txnStartTimestamp
     }
+
+    prepareTransitionTo(Ongoing, producerId, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet,
+      newTxnStartTimestamp, updateTimestamp)
   }
 
-  def prepareAbortOrCommit(newState: TransactionState,
-                           updateTimestamp: Long): TxnTransitMetadata = {
-    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
+  def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = {
+    prepareTransitionTo(newState, producerId, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp,
+      updateTimestamp)
   }
 
   def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
     val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort
-    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp, updateTimestamp)
+    prepareTransitionTo(newState, producerId, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp,
+      updateTimestamp)
   }
 
+  def prepareDead(): TxnTransitMetadata = {
+    prepareTransitionTo(Dead, producerId, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp,
+      txnLastUpdateTimestamp)
+  }
 
-  def prepareDead : TxnTransitMetadata = {
-    prepareTransitionTo(Dead, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp, txnLastUpdateTimestamp)
+  /**
+   * Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an
+   * epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer.
+   */
+  def isProducerEpochExhausted: Boolean = producerEpoch >= Short.MaxValue - 1
+
+  private def hasPendingTransaction: Boolean = {
+    state match {
+      case Ongoing | PrepareAbort | PrepareCommit => true
+      case _ => false
+    }
   }
 
   private def prepareTransitionTo(newState: TransactionState,
+                                  newProducerId: Long,
                                   newEpoch: Short,
                                   newTxnTimeoutMs: Int,
                                   newTopicPartitions: immutable.Set[TopicPartition],
@@ -229,9 +250,15 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
       throw new IllegalStateException(s"Preparing transaction state transition to $newState " +
         s"while it already a pending state ${pendingState.get}")
 
+    if (newProducerId < 0)
+      throw new IllegalArgumentException(s"Illegal new producer id $newProducerId")
+
+    if (newEpoch < 0)
+      throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch")
+
     // check that the new state transition is valid and update the pending state if necessary
     if (TransactionMetadata.validPreviousStates(newState).contains(state)) {
-      val transitMetadata = TxnTransitMetadata(producerId, newEpoch, newTxnTimeoutMs, newState,
+      val transitMetadata = TxnTransitMetadata(newProducerId, newEpoch, newTxnTimeoutMs, newState,
         newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
       debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata")
       pendingState = Some(newState)
@@ -246,74 +273,73 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     // metadata transition is valid only if all the following conditions are met:
     //
     // 1. the new state is already indicated in the pending state.
-    // 2. the producerId is the same (i.e. this field should never be changed)
-    // 3. the epoch should be either the same value or old value + 1.
-    // 4. the last update time is no smaller than the old value.
+    // 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId.
+    // 3. the last update time is no smaller than the old value.
     // 4. the old partitions set is a subset of the new partitions set.
     //
-    // plus, we should only try to update the metadata after the corresponding log entry has been successfully written and replicated (see TransactionStateManager#appendTransactionToLog)
+    // plus, we should only try to update the metadata after the corresponding log entry has been successfully
+    // written and replicated (see TransactionStateManager#appendTransactionToLog)
     //
     // if valid, transition is done via overwriting the whole object to ensure synchronization
 
-    val toState = pendingState.getOrElse(throw new IllegalStateException("Completing transaction state transition while it does not have a pending state"))
+    val toState = pendingState.getOrElse(throw new IllegalStateException(s"TransactionalId $transactionalId " +
+      "completing transaction state transition while it does not have a pending state"))
 
-    if (toState != transitMetadata.txnState ||
-      producerId != transitMetadata.producerId ||
-      txnLastUpdateTimestamp > transitMetadata.txnLastUpdateTimestamp) {
-
-      throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata state")
+    if (toState != transitMetadata.txnState || txnLastUpdateTimestamp > transitMetadata.txnLastUpdateTimestamp) {
+      throwStateTransitionFailure(toState)
     } else {
       toState match {
         case Empty => // from initPid
-          if (producerEpoch > transitMetadata.producerEpoch ||
-            producerEpoch < transitMetadata.producerEpoch - 1 ||
+          if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) ||
             transitMetadata.topicPartitions.nonEmpty ||
             transitMetadata.txnStartTimestamp != -1) {
 
-            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+            throwStateTransitionFailure(toState)
           } else {
             txnTimeoutMs = transitMetadata.txnTimeoutMs
             producerEpoch = transitMetadata.producerEpoch
+            producerId = transitMetadata.producerId
           }
 
         case Ongoing => // from addPartitions
-          if (producerEpoch != transitMetadata.producerEpoch ||
+          if (!validProducerEpoch(transitMetadata) ||
             !topicPartitions.subsetOf(transitMetadata.topicPartitions) ||
             txnTimeoutMs != transitMetadata.txnTimeoutMs ||
             txnStartTimestamp > transitMetadata.txnStartTimestamp) {
 
-            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+            throwStateTransitionFailure(toState)
           } else {
             txnStartTimestamp = transitMetadata.txnStartTimestamp
             addPartitions(transitMetadata.topicPartitions)
           }
 
         case PrepareAbort | PrepareCommit => // from endTxn
-          if (producerEpoch != transitMetadata.producerEpoch ||
+          if (!validProducerEpoch(transitMetadata) ||
             !topicPartitions.toSet.equals(transitMetadata.topicPartitions) ||
             txnTimeoutMs != transitMetadata.txnTimeoutMs ||
             txnStartTimestamp != transitMetadata.txnStartTimestamp) {
 
-            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+            throwStateTransitionFailure(toState)
           }
 
         case CompleteAbort | CompleteCommit => // from write markers
-          if (producerEpoch != transitMetadata.producerEpoch ||
+          info(s"transit start ${transitMetadata.txnStartTimestamp}")
+          if (!validProducerEpoch(transitMetadata) ||
             txnTimeoutMs != transitMetadata.txnTimeoutMs ||
             transitMetadata.txnStartTimestamp == -1) {
 
-            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+            throwStateTransitionFailure(toState)
           } else {
             txnStartTimestamp = transitMetadata.txnStartTimestamp
             topicPartitions.clear()
           }
+
         case Dead =>
           // The transactionalId was being expired. The completion of the operation should result in removal of the
           // the metadata from the cache, so we should never realistically transition to the dead state.
-          throw new IllegalStateException(s"TransactionalId : $transactionalId is trying to complete a transition to " +
+          throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " +
             s"$toState. This means that the transactionalId was being expired, and the only acceptable completion of " +
             s"this operation is to remove the transaction metadata from the cache, not to persist the $toState in the log.")
-
       }
 
       debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata")
@@ -323,6 +349,23 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     }
   }
 
+  private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = {
+    val transitEpoch = transitMetadata.producerEpoch
+    val transitProducerId = transitMetadata.producerId
+    transitEpoch == producerEpoch && transitProducerId == producerId
+  }
+
+  private def validProducerEpochBump(transitMetadata: TxnTransitMetadata): Boolean = {
+    val transitEpoch = transitMetadata.producerEpoch
+    val transitProducerId = transitMetadata.producerId
+    transitEpoch == producerEpoch + 1 || (transitEpoch == 0 && transitProducerId != producerId)
+  }
+
+  private def throwStateTransitionFailure(toState: TransactionState): Unit = {
+    throw new IllegalStateException(s"TransactionalId $transactionalId failed transition to state $toState " +
+      "due to unexpected metadata")
+  }
+
   def pendingTransitionInProgress: Boolean = pendingState.isDefined
 
   override def toString = {
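
As a worked example of the epoch bookkeeping above (values only, not committed code): since isProducerEpochExhausted treats Short.MaxValue - 1 (32766) as the last epoch a client may hold, prepareFenceProducerEpoch can always bump to 32767 to fence a zombie, and the next InitProducerId rotates to a new producerId with epoch 0. A sketch of the validProducerEpochBump condition used (alongside an exact-match check) when completing the Empty transition:

    object EpochBumpSketch {
      // A completed InitProducerId transition is accepted either as a +1 bump on the
      // same producerId, or as epoch 0 on a freshly rotated (different) producerId.
      def validEpochBump(currentProducerId: Long, currentEpoch: Short,
                         newProducerId: Long, newEpoch: Short): Boolean =
        newEpoch == currentEpoch + 1 ||
          (newEpoch == 0 && newProducerId != currentProducerId)
    }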

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 05edefb..da3ba48 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -130,7 +130,7 @@ class TransactionStateManager(brokerId: Int,
               txnMetadata.txnLastUpdateTimestamp <= now - config.transactionalIdExpirationMs
             }.map { case (transactionalId, txnMetadata) =>
               val txnMetadataTransition = txnMetadata synchronized {
-                txnMetadata.prepareDead
+                txnMetadata.prepareDead()
               }
               TransactionalIdCoordinatorEpochAndMetadata(transactionalId, entry.coordinatorEpoch, txnMetadataTransition)
             }
@@ -197,6 +197,11 @@ class TransactionStateManager(brokerId: Int,
     }, delay = config.removeExpiredTransactionalIdsIntervalMs, period = config.removeExpiredTransactionalIdsIntervalMs)
   }
 
+  def getTransactionState(transactionalId: String) = getAndMaybeAddTransactionState(transactionalId, None)
+
+  def putTransactionStateIfNotExists(transactionalId: String, txnMetadata: TransactionMetadata) =
+    getAndMaybeAddTransactionState(transactionalId, Some(txnMetadata))
+
   /**
    * Get the transaction metadata associated with the given transactional id, or an error if
    * the coordinator does not own the transaction partition or is still loading it; if not found
@@ -204,40 +209,41 @@ class TransactionStateManager(brokerId: Int,
    *
    * This function is covered by the state read lock
    */
-  def getAndMaybeAddTransactionState(transactionalId: String,
-                                     createdTxnMetadata: Option[TransactionMetadata] = None): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]]
-  = inReadLock(stateLock) {
-    val partitionId = partitionFor(transactionalId)
-
-    if (loadingPartitions.exists(_.txnPartitionId == partitionId))
-      return Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)
-
-    if (leavingPartitions.exists(_.txnPartitionId == partitionId))
-      Right(Errors.NOT_COORDINATOR)
-
-    transactionMetadataCache.get(partitionId) match {
-      case Some(cacheEntry) =>
-        cacheEntry.metadataPerTransactionalId.get(transactionalId) match {
-          case null =>
-            createdTxnMetadata match {
-              case None =>
-                Right(None)
-
-              case Some(txnMetadata) =>
-                val currentTxnMetadata = cacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, txnMetadata)
-                if (currentTxnMetadata != null) {
-                  Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, currentTxnMetadata)))
-                } else {
-                  Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, txnMetadata)))
-                }
-            }
+  private def getAndMaybeAddTransactionState(transactionalId: String,
+                                             createdTxnMetadata: Option[TransactionMetadata]): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] = {
+    inReadLock(stateLock) {
+      val partitionId = partitionFor(transactionalId)
+
+      if (loadingPartitions.exists(_.txnPartitionId == partitionId))
+        return Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)
+
+      if (leavingPartitions.exists(_.txnPartitionId == partitionId))
+        Right(Errors.NOT_COORDINATOR)
+
+      transactionMetadataCache.get(partitionId) match {
+        case Some(cacheEntry) =>
+          cacheEntry.metadataPerTransactionalId.get(transactionalId) match {
+            case null =>
+              createdTxnMetadata match {
+                case None =>
+                  Right(None)
+
+                case Some(txnMetadata) =>
+                  val currentTxnMetadata = cacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, txnMetadata)
+                  if (currentTxnMetadata != null) {
+                    Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, currentTxnMetadata)))
+                  } else {
+                    Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, txnMetadata)))
+                  }
+              }
 
-          case currentTxnMetadata =>
-            Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, currentTxnMetadata)))
-        }
+            case currentTxnMetadata =>
+              Right(Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, currentTxnMetadata)))
+          }
 
-      case None =>
-        Left(Errors.NOT_COORDINATOR)
+        case None =>
+          Left(Errors.NOT_COORDINATOR)
+      }
     }
   }
 
@@ -506,7 +512,7 @@ class TransactionStateManager(brokerId: Int,
       if (responseError == Errors.NONE) {
         // now try to update the cache: we need to update the status in-place instead of
         // overwriting the whole object to ensure synchronization
-        getAndMaybeAddTransactionState(transactionalId) match {
+        getTransactionState(transactionalId) match {
 
           case Left(err) =>
             responseCallback(err)
@@ -536,7 +542,7 @@ class TransactionStateManager(brokerId: Int,
         }
       } else {
         // Reset the pending state when returning an error, since there is no active transaction for the transactional id at this point.
-        getAndMaybeAddTransactionState(transactionalId) match {
+        getTransactionState(transactionalId) match {
           case Right(Some(epochAndTxnMetadata)) =>
             val metadata = epochAndTxnMetadata.transactionMetadata
             metadata synchronized {
@@ -567,7 +573,7 @@ class TransactionStateManager(brokerId: Int,
       // returns and before appendRecords() is called, since otherwise entries with a high coordinator epoch could have
       // been appended to the log in between these two events, and therefore appendRecords() would append entries with
       // an old coordinator epoch that can still be successfully replicated on followers and make the log in a bad state.
-      getAndMaybeAddTransactionState(transactionalId) match {
+      getTransactionState(transactionalId) match {
         case Left(err) =>
           responseCallback(err)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/main/scala/kafka/log/ProducerStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
index 5ec91ce..d6e704a 100644
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala
@@ -91,11 +91,11 @@ private[log] class ProducerAppendInfo(val producerId: Long,
   private val transactions = ListBuffer.empty[TxnMetadata]
 
   private def validateAppend(producerEpoch: Short, firstSeq: Int, lastSeq: Int) = {
-    if (this.producerEpoch > producerEpoch) {
+    if (isFenced(producerEpoch)) {
       throw new ProducerFencedException(s"Producer's epoch is no longer valid. There is probably another producer " +
         s"with a newer epoch. $producerEpoch (request epoch), ${this.producerEpoch} (server epoch)")
     } else if (validateSequenceNumbers) {
-      if (this.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || this.producerEpoch < producerEpoch) {
+      if (producerEpoch != this.producerEpoch) {
         if (firstSeq != 0)
           throw new OutOfOrderSequenceException(s"Invalid sequence number for new epoch: $producerEpoch " +
             s"(request epoch), $firstSeq (seq. number)")
@@ -107,13 +107,21 @@ private[log] class ProducerAppendInfo(val producerId: Long,
         throw new DuplicateSequenceNumberException(s"Duplicate sequence number for producerId $producerId: (incomingBatch.firstSeq, " +
           s"incomingBatch.lastSeq): ($firstSeq, $lastSeq), (lastEntry.firstSeq, lastEntry.lastSeq): " +
           s"(${this.firstSeq}, ${this.lastSeq}).")
-      } else if (firstSeq != this.lastSeq + 1L) {
+      } else if (!inSequence(firstSeq, lastSeq)) {
         throw new OutOfOrderSequenceException(s"Out of order sequence number for producerId $producerId: $firstSeq " +
           s"(incoming seq. number), ${this.lastSeq} (current end sequence number)")
       }
     }
   }
 
+  private def inSequence(firstSeq: Int, lastSeq: Int): Boolean = {
+    firstSeq == this.lastSeq + 1L || (firstSeq == 0 && this.lastSeq == Int.MaxValue)
+  }
+
+  private def isFenced(producerEpoch: Short): Boolean = {
+    producerEpoch < this.producerEpoch
+  }
+
   def append(batch: RecordBatch): Option[CompletedTxn] = {
     if (batch.isControlBatch) {
       val record = batch.iterator.next()
@@ -158,14 +166,14 @@ private[log] class ProducerAppendInfo(val producerId: Long,
                          producerEpoch: Short,
                          offset: Long,
                          timestamp: Long): CompletedTxn = {
-    if (this.producerEpoch > producerEpoch)
+    if (isFenced(producerEpoch))
       throw new ProducerFencedException(s"Invalid producer epoch: $producerEpoch (zombie): ${this.producerEpoch} (current)")
 
     if (this.coordinatorEpoch > endTxnMarker.coordinatorEpoch)
       throw new TransactionCoordinatorFencedException(s"Invalid coordinator epoch: ${endTxnMarker.coordinatorEpoch} " +
         s"(zombie), $coordinatorEpoch (current)")
 
-    if (producerEpoch > this.producerEpoch) {
+    if (producerEpoch != this.producerEpoch) {
       // it is possible that this control record is the first record seen from a new epoch (the producer
       // may fail before sending to the partition or the request itself could fail for some reason). In this
       // case, we bump the epoch and reset the sequence numbers
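
On the broker side, the log's sequence validation now also accepts a wrap back to 0. A sketch of the ordering rule (illustrative signature; the committed checks are inSequence and the epoch comparison above):

    object BrokerSequenceSketch {
      // A batch is in order if it continues from the last persisted sequence, or if the
      // previous batch ended exactly at Int.MaxValue and the new batch starts at 0.
      def inSequence(lastSeq: Int, firstSeq: Int): Boolean =
        firstSeq.toLong == lastSeq + 1L ||
          (firstSeq == 0 && lastSeq == Int.MaxValue)

      // A batch from a different (non-fenced) producer epoch must start at sequence 0.
      def validNewEpochStart(firstSeq: Int): Boolean = firstSeq == 0
    }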

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index 205dc6e..0a082ed 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -18,6 +18,7 @@
 package kafka.api
 
 import java.util.Properties
+import java.util.concurrent.TimeUnit
 
 import kafka.integration.KafkaServerTestHarness
 import kafka.server.KafkaConfig
@@ -26,13 +27,12 @@ import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaC
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.ProducerFencedException
 import org.apache.kafka.common.protocol.SecurityProtocol
-import org.junit.{After, Before, Ignore, Test}
+import org.junit.{After, Before, Test}
 import org.junit.Assert._
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionException
-import scala.util.Random
 
 class TransactionsTest extends KafkaServerTestHarness {
   val numServers = 3
@@ -321,8 +321,10 @@ class TransactionsTest extends KafkaServerTestHarness {
 
       producer2.initTransactions()  // ok, will abort the open transaction.
       producer2.beginTransaction()
-      producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true)).get()
-      producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true)).get()
+      producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "2", "4", willBeCommitted = true))
+        .get(20, TimeUnit.SECONDS)
+      producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "4", willBeCommitted = true))
+        .get(20, TimeUnit.SECONDS)
 
       try {
         producer1.beginTransaction()

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 4d953eb..e67ed08 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -35,7 +35,7 @@ class TransactionCoordinatorTest {
   val pidManager: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager])
   val transactionManager: TransactionStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
   val transactionMarkerChannelManager: TransactionMarkerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager])
-  val capturedTxn: Capture[Option[TransactionMetadata]] = EasyMock.newCapture()
+  val capturedTxn: Capture[TransactionMetadata] = EasyMock.newCapture()
   val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture()
   val brokerId = 0
   val coordinatorEpoch = 0
@@ -47,7 +47,7 @@ class TransactionCoordinatorTest {
   private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0))
   private val scheduler = new MockScheduler(time)
 
-  val coordinator: TransactionCoordinator = new TransactionCoordinator(brokerId,
+  val coordinator = new TransactionCoordinator(brokerId,
     scheduler,
     pidManager,
     transactionManager,
@@ -101,11 +101,15 @@ class TransactionCoordinatorTest {
   def shouldInitPidWithEpochZeroForNewTransactionalId(): Unit = {
     initPidGenericMocks(transactionalId)
 
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
+      .andReturn(Right(None))
+      .once()
+
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
       .andAnswer(new IAnswer[Either[Errors, Option[CoordinatorEpochAndTxnMetadata]]] {
         override def answer(): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] = {
-          if (capturedTxn.hasCaptured && capturedTxn.getValue.isDefined)
-            Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue.get)))
+          if (capturedTxn.hasCaptured)
+            Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)))
           else
             Right(None)
         }
@@ -130,11 +134,40 @@ class TransactionCoordinatorTest {
   }
 
   @Test
+  def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = {
+    initPidGenericMocks(transactionalId)
+
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, (Short.MaxValue - 1).toShort,
+      txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds())
+
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.anyObject().asInstanceOf[TxnTransitMetadata],
+      EasyMock.capture(capturedErrorsCallback)
+    )).andAnswer(new IAnswer[Unit] {
+      override def answer(): Unit = {
+        capturedErrorsCallback.getValue.apply(Errors.NONE)
+      }
+    })
+
+    EasyMock.replay(pidManager, transactionManager)
+
+    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, initProducerIdMockCallback)
+    assertNotEquals(producerId, result.producerId)
+    assertEquals(0, result.producerEpoch)
+    assertEquals(Errors.NONE, result.error)
+  }
+
+  @Test
   def shouldRespondWithNotCoordinatorOnInitPidWhenNotCoordinator(): Unit = {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
       .anyTimes()
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.NOT_COORDINATOR))
     EasyMock.replay(transactionManager)
 
@@ -147,7 +180,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
       .anyTimes()
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
     EasyMock.replay(transactionManager)
 
@@ -157,7 +190,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithInvalidPidMappingOnAddPartitionsToTransactionWhenTransactionalIdNotPresent(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(None))
     EasyMock.replay(transactionManager)
 
@@ -179,7 +212,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithNotCoordinatorOnAddPartitionsWhenNotCoordinator(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.NOT_COORDINATOR))
     EasyMock.replay(transactionManager)
 
@@ -189,7 +222,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithCoordinatorLoadInProgressOnAddPartitionsWhenCoordintorLoading(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
     EasyMock.replay(transactionManager)
@@ -209,7 +242,7 @@ class TransactionCoordinatorTest {
   }
 
   def validateConcurrentTransactions(state: TransactionState): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, 0, state, mutable.Set.empty, 0, 0)))))
 
     EasyMock.replay(transactionManager)
@@ -220,7 +253,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithInvalidTnxProduceEpochOnAddPartitionsWhenEpochsAreDifferent(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0)))))
 
     EasyMock.replay(transactionManager)
@@ -253,7 +286,7 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, previousState,
       mutable.Set.empty, time.milliseconds(), time.milliseconds())
 
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
@@ -272,7 +305,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, 0, Empty, partitions, 0, 0)))))
 
     EasyMock.replay(transactionManager)
@@ -285,7 +318,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(None))
     EasyMock.replay(transactionManager)
 
@@ -296,7 +329,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
@@ -307,7 +340,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
@@ -318,7 +351,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
@@ -330,7 +363,7 @@ class TransactionCoordinatorTest {
   @Test
   def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
@@ -342,7 +375,7 @@ class TransactionCoordinatorTest {
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
@@ -354,7 +387,7 @@ class TransactionCoordinatorTest {
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
@@ -365,7 +398,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
@@ -376,7 +409,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
@@ -385,7 +418,6 @@ class TransactionCoordinatorTest {
     EasyMock.verify(transactionManager)
   }
 
-
   @Test
   def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = {
     mockPrepare(PrepareCommit)
@@ -415,7 +447,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.eq(None)))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.NOT_COORDINATOR))
     EasyMock.replay(transactionManager)
 
@@ -425,7 +457,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.NOT_COORDINATOR))
     EasyMock.replay(transactionManager)
 
@@ -435,7 +467,7 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = {
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
     EasyMock.replay(transactionManager)
@@ -477,7 +509,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.anyObject[TransactionMetadata]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+      .anyTimes()
+
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .anyTimes()
 
@@ -503,6 +539,50 @@ class TransactionCoordinatorTest {
   }
 
   @Test
+  def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = {
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, (Short.MaxValue - 1).toShort,
+      txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
+      .andReturn(true)
+
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.anyObject[TransactionMetadata]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+      .anyTimes()
+
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+      .anyTimes()
+
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(TxnTransitMetadata(
+        producerId = producerId,
+        producerEpoch = Short.MaxValue,
+        txnTimeoutMs = txnTimeoutMs,
+        txnState = PrepareAbort,
+        topicPartitions = partitions.toSet,
+        txnStartTimestamp = time.milliseconds(),
+        txnLastUpdateTimestamp = time.milliseconds())),
+      EasyMock.capture(capturedErrorsCallback)))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          capturedErrorsCallback.getValue.apply(Errors.NONE)
+        }
+      })
+
+    EasyMock.replay(transactionManager)
+
+    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, initProducerIdMockCallback)
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
+
+    assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result)
+    EasyMock.verify(transactionManager)
+  }
+
+  @Test
   def shouldRemoveTransactionsForPartitionOnEmigration(): Unit = {
     EasyMock.expect(transactionManager.removeTransactionsForTxnTopicPartition(0, coordinatorEpoch))
     EasyMock.expect(transactionMarkerChannelManager.removeMarkersForTxnTopicPartition(0))
@@ -522,7 +602,7 @@ class TransactionCoordinatorTest {
 
     EasyMock.expect(transactionManager.timedOutTransactions())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .once()
 
@@ -554,7 +634,7 @@ class TransactionCoordinatorTest {
 
     EasyMock.expect(transactionManager.timedOutTransactions())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
@@ -570,7 +650,7 @@ class TransactionCoordinatorTest {
       .andReturn(true).anyTimes()
 
     val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))).anyTimes()
 
     EasyMock.replay(transactionManager)
@@ -589,7 +669,7 @@ class TransactionCoordinatorTest {
       .andReturn(true)
 
     val metadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
 
     val capturedNewMetadata: Capture[TxnTransitMetadata] = EasyMock.newCapture()
@@ -625,7 +705,7 @@ class TransactionCoordinatorTest {
     val transition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, transactionState,
       partitions.toSet, now, now)
 
-    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata))))
       .once()
     EasyMock.expect(transactionManager.appendTransactionToLog(

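Aside for readers skimming the patch: across the TransactionCoordinatorTest changes above, the single mocked call getAndMaybeAddTransactionState is split into a read-only lookup plus an explicit insert-if-absent used on InitProducerId. A minimal sketch of the shapes those stubs imply, using placeholder types rather than the real kafka.coordinator.transaction classes:

import org.apache.kafka.common.protocol.Errors

// Stand-ins so the sketch compiles on its own; the real metadata types live in
// kafka.coordinator.transaction and carry far more state.
final class TxnMetadataStub
final case class EpochAndMetadata(coordinatorEpoch: Int, metadata: TxnMetadataStub)

trait TxnStateManagerSketch {
  // Read-only lookup: Left(error) when this broker is not the coordinator or is still
  // loading, Right(None) when the transactionalId is unknown, Right(Some(...)) otherwise.
  def getTransactionState(transactionalId: String): Either[Errors, Option[EpochAndMetadata]]

  // Insert-if-absent path used on InitProducerId: returns whichever entry ends up cached,
  // whether it already existed or was just added.
  def putTransactionStateIfNotExists(transactionalId: String,
                                     metadata: TxnMetadataStub): Either[Errors, Option[EpochAndMetadata]]
}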
http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index 4015a4f..b797e38 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -76,10 +76,10 @@ class TransactionMarkerChannelManagerTest {
     EasyMock.expect(txnStateManager.partitionFor(transactionalId2))
       .andReturn(txnTopicPartition2)
       .anyTimes()
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId1), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId1)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))))
       .anyTimes()
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId2), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId2)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2))))
       .anyTimes()
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
index e3e67b7..df2f7df 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
@@ -57,7 +57,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
     EasyMock.expect(txnStateManager.partitionFor(transactionalId))
       .andReturn(txnTopicPartition)
       .anyTimes()
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .anyTimes()
     EasyMock.replay(txnStateManager)
@@ -100,7 +100,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
 
   @Test
   def shouldCompleteDelayedOperationWhenNotCoordinator(): Unit = {
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.NOT_COORDINATOR))
       .anyTimes()
     EasyMock.replay(txnStateManager)
@@ -110,7 +110,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
 
   @Test
   def shouldCompleteDelayedOperationWhenCoordinatorLoading(): Unit = {
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
       .anyTimes()
     EasyMock.replay(txnStateManager)
@@ -120,7 +120,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
 
   @Test
   def shouldCompleteDelayedOperationWhenCoordinatorEpochChanged(): Unit = {
-    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+    EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata))))
       .anyTimes()
     EasyMock.replay(txnStateManager)

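The three delayed-operation tests above only assert on the lookup outcome, so the behavior they pin down can be summarized in a short hedged sketch (helper names here are hypothetical, not the committed handler code):

import org.apache.kafka.common.protocol.Errors

object MarkerCompletionSketch {
  final case class EpochAndMetadata(coordinatorEpoch: Int)

  // Hypothetical stand-in for abandoning the pending WriteTxnMarkers work.
  def completeDelayedOperation(transactionalId: String): Unit =
    println(s"abandoning markers for $transactionalId")

  def onTxnStateLookup(transactionalId: String,
                       requestCoordinatorEpoch: Int,
                       lookup: Either[Errors, Option[EpochAndMetadata]]): Unit = lookup match {
    // Coordinator ownership moved, or the transaction log partition is still loading: drop the marker.
    case Left(Errors.NOT_COORDINATOR) | Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) =>
      completeDelayedOperation(transactionalId)
    // Cached coordinator epoch differs from the one the marker was sent with: drop it too.
    case Right(Some(cached)) if cached.coordinatorEpoch != requestCoordinatorEpoch =>
      completeDelayedOperation(transactionalId)
    case _ =>
      () // otherwise the response is handled normally
  }
}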
http://git-wip-us.apache.org/repos/asf/kafka/blob/1c882ee5/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
new file mode 100644
index 0000000..4f2fe5f
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.coordinator.transaction
+
+import kafka.utils.MockTime
+import org.apache.kafka.common.record.RecordBatch
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+class TransactionMetadataTest {
+
+  val time = new MockTime()
+
+  @Test
+  def testInitializeEpoch(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = RecordBatch.NO_PRODUCER_EPOCH
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Empty,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+
+    val transitMetadata = txnMetadata.prepareIncrementProducerEpoch(30000, time.milliseconds())
+    txnMetadata.completeTransitionTo(transitMetadata)
+    assertEquals(producerId, txnMetadata.producerId)
+    assertEquals(0, txnMetadata.producerEpoch)
+  }
+
+  @Test
+  def testNormalEpochBump(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = 735.toShort
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Empty,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+
+    val transitMetadata = txnMetadata.prepareIncrementProducerEpoch(30000, time.milliseconds())
+    txnMetadata.completeTransitionTo(transitMetadata)
+    assertEquals(producerId, txnMetadata.producerId)
+    assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testBumpEpochNotAllowedIfEpochsExhausted(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = (Short.MaxValue - 1).toShort
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Empty,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    txnMetadata.prepareIncrementProducerEpoch(30000, time.milliseconds())
+  }
+
+  @Test
+  def testFenceProducerAfterEpochsExhausted(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = (Short.MaxValue - 1).toShort
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Ongoing,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    txnMetadata.prepareFenceProducerEpoch()
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
+
+    val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())
+    txnMetadata.completeTransitionTo(transitMetadata)
+    assertEquals(producerId, transitMetadata.producerId)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testFenceProducerNotAllowedIfItWouldOverflow(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = Short.MaxValue
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Ongoing,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+    txnMetadata.prepareFenceProducerEpoch()
+  }
+
+  @Test
+  def testRotateProducerId(): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = (Short.MaxValue - 1).toShort
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = Empty,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+
+    val newProducerId = 9893L
+    val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds())
+    txnMetadata.completeTransitionTo(transitMetadata)
+    assertEquals(newProducerId, txnMetadata.producerId)
+    assertEquals(0, txnMetadata.producerEpoch)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRotateProducerIdInOngoingState(): Unit = {
+    testRotateProducerIdInOngoingState(Ongoing)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRotateProducerIdInPrepareAbortState(): Unit = {
+    testRotateProducerIdInOngoingState(PrepareAbort)
+  }
+
+  @Test(expected = classOf[IllegalStateException])
+  def testRotateProducerIdInPrepareCommitState(): Unit = {
+    testRotateProducerIdInOngoingState(PrepareCommit)
+  }
+
+  private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = {
+    val transactionalId = "txnlId"
+    val producerId = 23423L
+    val producerEpoch = (Short.MaxValue - 1).toShort
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      producerEpoch = producerEpoch,
+      txnTimeoutMs = 30000,
+      state = state,
+      topicPartitions = mutable.Set.empty,
+      txnLastUpdateTimestamp = time.milliseconds())
+    val newProducerId = 9893L
+    txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds())
+  }
+
+
+}
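The new TransactionMetadataTest above fixes the producer epoch lifecycle at the boundaries. As a self-contained model of the arithmetic those assertions pin down (not the committed TransactionMetadata code, just the rules the tests encode):

import org.apache.kafka.common.record.RecordBatch

object EpochLifecycleSketch {

  // Exhausted once another bump would land on Short.MaxValue, which is reserved for fencing.
  def isProducerEpochExhausted(epoch: Short): Boolean =
    epoch >= Short.MaxValue - 1

  // Normal bump: an uninitialized epoch starts at 0, otherwise increment by one
  // (testInitializeEpoch, testNormalEpochBump). Illegal once exhausted
  // (testBumpEpochNotAllowedIfEpochsExhausted).
  def incrementEpoch(epoch: Short): Short = {
    if (isProducerEpochExhausted(epoch))
      throw new IllegalStateException(s"Cannot bump exhausted producer epoch $epoch")
    if (epoch == RecordBatch.NO_PRODUCER_EPOCH) 0.toShort else (epoch + 1).toShort
  }

  // Fencing may spend the final epoch value exactly once (testFenceProducerAfterEpochsExhausted);
  // at Short.MaxValue there is nothing left to fence with (testFenceProducerNotAllowedIfItWouldOverflow).
  def fenceEpoch(epoch: Short): Short = {
    if (epoch == Short.MaxValue)
      throw new IllegalStateException("Producer epoch would overflow when fencing")
    Short.MaxValue
  }

  // Rotating to a freshly allocated producerId resets the epoch to 0 (testRotateProducerId);
  // the tests also require that the transaction not be Ongoing, PrepareAbort, or PrepareCommit.
  def rotateProducerId(newProducerId: Long): (Long, Short) = (newProducerId, 0.toShort)
}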

