kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [2/3] kafka git commit: KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization
Date Fri, 12 May 2017 22:01:06 GMT
http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index a81e47b..a76617e 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -19,9 +19,9 @@ package kafka.coordinator.transaction
 import kafka.utils.nonthreadsafe
 import org.apache.kafka.common.TopicPartition
 
-import scala.collection.mutable
+import scala.collection.{immutable, mutable}
 
-private[coordinator] sealed trait TransactionState { def byte: Byte }
+private[transaction] sealed trait TransactionState { def byte: Byte }
 
 /**
  * Transaction has not existed yet
@@ -29,7 +29,7 @@ private[coordinator] sealed trait TransactionState { def byte: Byte }
  * transition: received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[coordinator] case object Empty extends TransactionState { val byte: Byte = 0 }
+private[transaction] case object Empty extends TransactionState { val byte: Byte = 0 }
 
 /**
  * Transaction has started and ongoing
@@ -39,37 +39,37 @@ private[coordinator] case object Empty extends TransactionState { val byte: Byte
  *             received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[coordinator] case object Ongoing extends TransactionState { val byte: Byte = 1 }
+private[transaction] case object Ongoing extends TransactionState { val byte: Byte = 1 }
 
 /**
  * Group is preparing to commit
  *
  * transition: received acks from all partitions => CompleteCommit
  */
-private[coordinator] case object PrepareCommit extends TransactionState { val byte: Byte = 2}
+private[transaction] case object PrepareCommit extends TransactionState { val byte: Byte = 2}
 
 /**
  * Group is preparing to abort
  *
  * transition: received acks from all partitions => CompleteAbort
  */
-private[coordinator] case object PrepareAbort extends TransactionState { val byte: Byte = 3 }
+private[transaction] case object PrepareAbort extends TransactionState { val byte: Byte = 3 }
 
 /**
  * Group has completed commit
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[coordinator] case object CompleteCommit extends TransactionState { val byte: Byte = 4 }
+private[transaction] case object CompleteCommit extends TransactionState { val byte: Byte = 4 }
 
 /**
  * Group has completed abort
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[coordinator] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
+private[transaction] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
 
-private[coordinator] object TransactionMetadata {
+private[transaction] object TransactionMetadata {
   def apply(pid: Long, epoch: Short, txnTimeoutMs: Int, timestamp: Long) = new TransactionMetadata(pid, epoch, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
 
   def apply(pid: Long, epoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = new TransactionMetadata(pid, epoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
@@ -89,7 +89,7 @@ private[coordinator] object TransactionMetadata {
   def isValidTransition(oldState: TransactionState, newState: TransactionState): Boolean = TransactionMetadata.validPreviousStates(newState).contains(oldState)
 
   private val validPreviousStates: Map[TransactionState, Set[TransactionState]] =
-    Map(Empty -> Set(),
+    Map(Empty -> Set(Empty, CompleteCommit, CompleteAbort),
       Ongoing -> Set(Ongoing, Empty, CompleteCommit, CompleteAbort),
       PrepareCommit -> Set(Ongoing),
       PrepareAbort -> Set(Ongoing),
@@ -97,24 +97,33 @@ private[coordinator] object TransactionMetadata {
       CompleteAbort -> Set(PrepareAbort))
 }
 
+// this is an immutable object representing the target transition of the transaction metadata
+private[transaction] case class TransactionMetadataTransition(producerId: Long,
+                                                              producerEpoch: Short,
+                                                              txnTimeoutMs: Int,
+                                                              txnState: TransactionState,
+                                                              topicPartitions: immutable.Set[TopicPartition],
+                                                              txnStartTimestamp: Long,
+                                                              txnLastUpdateTimestamp: Long)
+
 /**
   *
-  * @param pid                   producer id
+  * @param producerId            producer id
   * @param producerEpoch         current epoch of the producer
   * @param txnTimeoutMs          timeout to be used to abort long running transactions
-  * @param state                 the current state of the transaction
-  * @param topicPartitions       set of partitions that are part of this transaction
-  * @param transactionStartTime  time the transaction was started, i.e., when first partition is added
-  * @param lastUpdateTimestamp   updated when any operation updates the TransactionMetadata. To be used for expiration
+  * @param state                 current state of the transaction
+  * @param topicPartitions       current set of partitions that are part of this transaction
+  * @param txnStartTimestamp     time the transaction was started, i.e., when first partition is added
+  * @param txnLastUpdateTimestamp   updated when any operation updates the TransactionMetadata. To be used for expiration
   */
 @nonthreadsafe
-private[coordinator] class TransactionMetadata(val pid: Long,
+private[transaction] class TransactionMetadata(val producerId: Long,
                                                var producerEpoch: Short,
                                                var txnTimeoutMs: Int,
                                                var state: TransactionState,
                                                val topicPartitions: mutable.Set[TopicPartition],
-                                               var transactionStartTime: Long = -1,
-                                               var lastUpdateTimestamp: Long) {
+                                               var txnStartTimestamp: Long = -1,
+                                               var txnLastUpdateTimestamp: Long) {
 
   // pending state is used to indicate the state that this transaction is going to
   // transit to, and for blocking future attempts to transit it again if it is not legal;
@@ -125,50 +134,171 @@ private[coordinator] class TransactionMetadata(val pid: Long,
     topicPartitions ++= partitions
   }
 
-  def prepareTransitionTo(newState: TransactionState): Boolean = {
+  def removePartition(topicPartition: TopicPartition): Unit = {
+    if (pendingState.isDefined || (state != PrepareCommit && state != PrepareAbort))
+      throw new IllegalStateException(s"Transation metadata's current state is $state, and its pending state is $state " +
+        s"while trying to remove partitions whose txn marker has been sent, this is not expected")
+
+    topicPartitions -= topicPartition
+  }
+
+  def prepareNoTransit(): TransactionMetadataTransition =
+    // do not call transitTo as it will set the pending state
+    TransactionMetadataTransition(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+
+  def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
+                                    updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(Empty, (producerEpoch + 1).toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+  }
+
+  def prepareNewPid(updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(Empty, producerEpoch, txnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+  }
+
+  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition],
+                           updateTimestamp: Long): TransactionMetadataTransition = {
+
+    if (state == Empty || state == CompleteCommit || state == CompleteAbort) {
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, updateTimestamp, updateTimestamp)
+    } else {
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, txnStartTimestamp, updateTimestamp)
+    }
+  }
+
+  def prepareAbortOrCommit(newState: TransactionState,
+                           updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
+  }
+
+  def prepareComplete(updateTimestamp: Long): TransactionMetadataTransition = {
+    val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort
+    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
+  }
+
+  // visible for testing only
+  def copy(): TransactionMetadata = {
+    val cloned = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, state,
+      mutable.Set.empty ++ topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+    cloned.pendingState = pendingState
+
+    cloned
+  }
+
+  private def prepareTransitionTo(newState: TransactionState,
+                                  newEpoch: Short,
+                                  newTxnTimeoutMs: Int,
+                                  newTopicPartitions: immutable.Set[TopicPartition],
+                                  newTxnStartTimestamp: Long,
+                                  updateTimestamp: Long): TransactionMetadataTransition = {
     if (pendingState.isDefined)
-      throw new IllegalStateException(s"Preparing transaction state transition to $newState while it already a pending state ${pendingState.get}")
+      throw new IllegalStateException(s"Preparing transaction state transition to $newState " +
+        s"while it already a pending state ${pendingState.get}")
 
     // check that the new state transition is valid and update the pending state if necessary
     if (TransactionMetadata.validPreviousStates(newState).contains(state)) {
       pendingState = Some(newState)
-      true
+
+      TransactionMetadataTransition(producerId, newEpoch, newTxnTimeoutMs, newState, newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
     } else {
-      false
+      throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" +
+        s" $newState is not a valid previous state of the current state $state")
     }
   }
 
-  def completeTransitionTo(newState: TransactionState): Boolean = {
+  def completeTransitionTo(newMetadata: TransactionMetadataTransition): Unit = {
+    // metadata transition is valid only if all the following conditions are met:
+    //
+    // 1. the new state is already indicated in the pending state.
+    // 2. the pid is the same (i.e. this field should never be changed)
+    // 3. the epoch should be either the same value or old value + 1.
+    // 4. the last update time is no smaller than the old value.
+    // 5. the old partitions set is a subset of the new partitions set.
+    //
+    // plus, we should only try to update the metadata after the corresponding log entry has been successfully written and replicated (see TransactionStateManager#appendTransactionToLog)
+    //
+    // if valid, transition is done via overwriting the whole object to ensure synchronization
+
     val toState = pendingState.getOrElse(throw new IllegalStateException("Completing transaction state transition while it does not have a pending state"))
-    if (toState != newState) {
-      false
+
+    if (toState != newMetadata.txnState ||
+      producerId != newMetadata.producerId ||
+      txnLastUpdateTimestamp > newMetadata.txnLastUpdateTimestamp) {
+
+      throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata state")
     } else {
+      val updated = toState match {
+        case Empty => // from initPid
+          if (producerEpoch > newMetadata.producerEpoch ||
+            producerEpoch < newMetadata.producerEpoch - 1 ||
+            newMetadata.topicPartitions.nonEmpty ||
+            newMetadata.txnStartTimestamp != -1) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnTimeoutMs = newMetadata.txnTimeoutMs
+            producerEpoch = newMetadata.producerEpoch
+          }
+
+        case Ongoing => // from addPartitions
+          if (producerEpoch != newMetadata.producerEpoch ||
+            !topicPartitions.subsetOf(newMetadata.topicPartitions) ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            txnStartTimestamp > newMetadata.txnStartTimestamp) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnStartTimestamp = newMetadata.txnStartTimestamp
+            addPartitions(newMetadata.topicPartitions)
+          }
+
+        case PrepareAbort | PrepareCommit => // from endTxn
+          if (producerEpoch != newMetadata.producerEpoch ||
+            !topicPartitions.toSet.equals(newMetadata.topicPartitions) ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            txnStartTimestamp != newMetadata.txnStartTimestamp) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          }
+
+        case CompleteAbort | CompleteCommit => // from write markers
+          if (producerEpoch != newMetadata.producerEpoch ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            newMetadata.txnStartTimestamp == -1) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnStartTimestamp = newMetadata.txnStartTimestamp
+            topicPartitions.clear()
+          }
+      }
+
+      txnLastUpdateTimestamp = newMetadata.txnLastUpdateTimestamp
       pendingState = None
       state = toState
-      true
     }
   }
 
-  def copy(): TransactionMetadata =
-    new TransactionMetadata(pid, producerEpoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition] ++ topicPartitions, transactionStartTime, lastUpdateTimestamp)
+  def pendingTransitionInProgress: Boolean = pendingState.isDefined
 
-  override def toString = s"TransactionMetadata($pendingState, $pid, $producerEpoch, $txnTimeoutMs, $state, $topicPartitions, $transactionStartTime, $lastUpdateTimestamp)"
+  override def toString = s"TransactionMetadata($pendingState, $producerId, $producerEpoch, $txnTimeoutMs, $state, $topicPartitions, $txnStartTimestamp, $txnLastUpdateTimestamp)"
 
   override def equals(that: Any): Boolean = that match {
     case other: TransactionMetadata =>
-      pid == other.pid &&
+      producerId == other.producerId &&
       producerEpoch == other.producerEpoch &&
       txnTimeoutMs == other.txnTimeoutMs &&
       state.equals(other.state) &&
       topicPartitions.equals(other.topicPartitions) &&
-      transactionStartTime == other.transactionStartTime &&
-      lastUpdateTimestamp == other.lastUpdateTimestamp
+      txnStartTimestamp == other.txnStartTimestamp &&
+      txnLastUpdateTimestamp == other.txnLastUpdateTimestamp
     case _ => false
   }
 
-
   override def hashCode(): Int = {
-    val state = Seq(pid, txnTimeoutMs, topicPartitions, lastUpdateTimestamp)
-    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+    val fields = Seq(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions, txnStartTimestamp, txnLastUpdateTimestamp)
+    fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
   }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index f5dc3c0..7a03fc3 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -33,17 +33,19 @@ import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.{FileRecords, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.{Time, Utils}
 
 import scala.collection.mutable
 import scala.collection.JavaConverters._
 
 
-object TransactionManager {
+object TransactionStateManager {
   // default transaction management config values
-  val DefaultTransactionalIdExpirationMs = TimeUnit.DAYS.toMillis(7).toInt
-  val DefaultTransactionsMaxTimeoutMs = TimeUnit.MINUTES.toMillis(15).toInt
-  val DefaultRemoveExpiredTransactionsIntervalMs = TimeUnit.MINUTES.toMillis(1).toInt
+  // TODO: this needs to be replaced by the config values
+  val DefaultTransactionsMaxTimeoutMs: Int = TimeUnit.MINUTES.toMillis(15).toInt
+  val DefaultTransactionalIdExpirationMs: Int = TimeUnit.DAYS.toMillis(7).toInt
+  val DefaultRemoveExpiredTransactionsIntervalMs: Int = TimeUnit.MINUTES.toMillis(1).toInt
 }
 
 /**
@@ -62,7 +64,7 @@ class TransactionStateManager(brokerId: Int,
 
   this.logIdent = "[Transaction Log Manager " + brokerId + "]: "
 
-  type WriteTxnMarkers = WriteTxnMarkerArgs => Unit
+  type SendTxnMarkersCallback = (String, Int, TransactionResult, TransactionMetadata, TransactionMetadataTransition) => Unit
 
   /** shutting down flag */
   private val shuttingDown = new AtomicBoolean(false)
@@ -70,40 +72,72 @@ class TransactionStateManager(brokerId: Int,
   /** lock protecting access to loading and owned partition sets */
   private val stateLock = new ReentrantLock()
 
-  /** partitions of transaction topic that are assigned to this manager, partition lock should be called BEFORE accessing this set */
-  private val ownedPartitions: mutable.Map[Int, Int] = mutable.Map()
-
   /** partitions of transaction topic that are being loaded, partition lock should be called BEFORE accessing this set */
   private val loadingPartitions: mutable.Set[Int] = mutable.Set()
 
-  /** transaction metadata cache indexed by transactional id */
-  private val transactionMetadataCache = new Pool[String, TransactionMetadata]
+  /** transaction metadata cache indexed by assigned transaction topic partition ids */
+  private val transactionMetadataCache: mutable.Map[Int, TxnMetadataCacheEntry] = mutable.Map()
 
   /** number of partitions for the transaction log topic */
   private val transactionTopicPartitionCount = getTransactionTopicPartitionCount
 
+  // this is best-effort expiration and hence does not grab the lock on the metadata when checking its state;
+  // we will get the lock later, when actually trying to transition the transaction metadata to abort.
+  def transactionsToExpire(): Iterable[TransactionalIdAndProducerIdEpoch] = {
+    val now = time.milliseconds()
+    transactionMetadataCache.flatMap { case (_, entry) =>
+        entry.metadataPerTransactionalId.filter { case (txnId, txnMetadata) =>
+          if (isCoordinatorLoadingInProgress(txnId) || txnMetadata.pendingTransitionInProgress) {
+            false
+          } else {
+            txnMetadata.state match {
+              case Ongoing =>
+                txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now
+              case _ => false
+            }
+          }
+        }.map { case (txnId, txnMetadata) =>
+          TransactionalIdAndProducerIdEpoch(txnId, txnMetadata.producerId, txnMetadata.producerEpoch)
+        }
+    }
+  }
+
   def enablePidExpiration() {
-    if (!scheduler.isStarted)
-      scheduler.startup()
     // TODO: add pid expiration logic
   }
 
   /**
    * Get the transaction metadata associated with the given transactional id, or null if not found
    */
-  def getTransactionState(transactionalId: String): Option[TransactionMetadata] = {
-    Option(transactionMetadataCache.get(transactionalId))
+  def getTransactionState(transactionalId: String): Option[CoordinatorEpochAndTxnMetadata] = {
+    val partitionId = partitionFor(transactionalId)
+
+    transactionMetadataCache.get(partitionId).flatMap { cacheEntry =>
+      cacheEntry.metadataPerTransactionalId.get(transactionalId) match {
+        case null => None
+        case txnMetadata => Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, txnMetadata))
+      }
+    }
   }
 
   /**
    * Add a new transaction metadata, or retrieve the metadata if it already exists with the associated transactional id
+   * along with the current coordinator epoch of the transaction topic partition it belongs to
    */
-  def addTransaction(transactionalId: String, txnMetadata: TransactionMetadata): TransactionMetadata = {
-    val currentTxnMetadata = transactionMetadataCache.putIfNotExists(transactionalId, txnMetadata)
-    if (currentTxnMetadata != null) {
-      currentTxnMetadata
-    } else {
-      txnMetadata
+  def addTransaction(transactionalId: String, txnMetadata: TransactionMetadata): CoordinatorEpochAndTxnMetadata = {
+    val partitionId = partitionFor(transactionalId)
+
+    transactionMetadataCache.get(partitionId) match {
+      case Some(txnMetadataCacheEntry) =>
+        val currentTxnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, txnMetadata)
+        if (currentTxnMetadata != null) {
+          CoordinatorEpochAndTxnMetadata(txnMetadataCacheEntry.coordinatorEpoch, currentTxnMetadata)
+        } else {
+          CoordinatorEpochAndTxnMetadata(txnMetadataCacheEntry.coordinatorEpoch, txnMetadata)
+        }
+
+      case None =>
+        throw new IllegalStateException(s"The metadata cache entry for txn partition $partitionId does not exist.")
     }
   }
 
@@ -129,13 +163,13 @@ class TransactionStateManager(brokerId: Int,
 
   def partitionFor(transactionalId: String): Int = Utils.abs(transactionalId.hashCode) % transactionTopicPartitionCount
 
-  def coordinatorEpochFor(transactionId: String): Option[Int] = inLock (stateLock) {
-    ownedPartitions.get(partitionFor(transactionId))
+  def isCoordinatorFor(txnTopicPartitionId: Int): Boolean = inLock(stateLock) {
+    transactionMetadataCache.contains(txnTopicPartitionId)
   }
 
   def isCoordinatorFor(transactionalId: String): Boolean = inLock(stateLock) {
     val partitionId = partitionFor(transactionalId)
-    ownedPartitions.contains(partitionId)
+    transactionMetadataCache.contains(partitionId)
   }
 
   def isCoordinatorLoadingInProgress(transactionalId: String): Boolean = inLock(stateLock) {
@@ -143,19 +177,6 @@ class TransactionStateManager(brokerId: Int,
     loadingPartitions.contains(partitionId)
   }
 
-
-  def transactionsToExpire(): Iterable[TransactionalIdAndMetadata] = {
-    val now = time.milliseconds()
-    transactionMetadataCache.filter { case (_, metadata) =>
-      metadata.state match {
-        case Ongoing =>
-          metadata.transactionStartTime + metadata.txnTimeoutMs < now
-        case _ => false
-      }
-    }.map {case (id, metadata) =>
-      TransactionalIdAndMetadata(id, metadata)
-    }
-  }
   /**
    * Gets the partition count of the transaction log topic from ZooKeeper.
    * If the topic does not exist, the default partition count is returned.
@@ -164,162 +185,159 @@ class TransactionStateManager(brokerId: Int,
     zkUtils.getTopicPartitionCount(Topic.TransactionStateTopicName).getOrElse(config.transactionLogNumPartitions)
   }
 
-  private def loadTransactionMetadata(topicPartition: TopicPartition, writeTxnMarkers: WriteTxnMarkers) {
-    def highWaterMark = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L)
+  private def loadTransactionMetadata(topicPartition: TopicPartition, coordinatorEpoch: Int): Pool[String, TransactionMetadata] =  {
+    def logEndOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L)
 
     val startMs = time.milliseconds()
+    val loadedTransactions = new Pool[String, TransactionMetadata]
+
     replicaManager.getLog(topicPartition) match {
       case None =>
         warn(s"Attempted to load offsets and group metadata from $topicPartition, but found no log")
 
       case Some(log) =>
         lazy val buffer = ByteBuffer.allocate(config.transactionLogLoadBufferSize)
-        val loadedTransactions = mutable.Map.empty[String, TransactionMetadata]
-        val removedTransactionalIds = mutable.Set.empty[String]
 
         // loop breaks if leader changes at any time during the load, since getHighWatermark is -1
         var currOffset = log.logStartOffset
-        while (currOffset < highWaterMark
-                && loadingPartitions.contains(topicPartition.partition())
-                && !shuttingDown.get()) {
-          buffer.clear()
-          val fetchDataInfo = log.read(currOffset, config.transactionLogLoadBufferSize, maxOffset = None,
-            minOneMessage = true, isolationLevel = IsolationLevel.READ_UNCOMMITTED)
-          val memRecords = fetchDataInfo.records match {
-            case records: MemoryRecords => records
-            case fileRecords: FileRecords =>
-              buffer.clear()
-              val bufferRead = fileRecords.readInto(buffer, 0)
-              MemoryRecords.readableRecords(bufferRead)
-          }
-
-          memRecords.batches.asScala.foreach { batch =>
-            for (record <- batch.asScala) {
-              require(record.hasKey, "Transaction state log's key should not be null")
-              TransactionLog.readMessageKey(record.key) match {
-
-                case txnKey: TxnKey =>
-                  // load transaction metadata along with transaction state
-                  val transactionalId: String = txnKey.transactionalId
-                  if (!record.hasValue) {
-                    loadedTransactions.remove(transactionalId)
-                    removedTransactionalIds.add(transactionalId)
-                  } else {
-                    val txnMetadata = TransactionLog.readMessageValue(record.value)
-                    loadedTransactions.put(transactionalId, txnMetadata)
-                    removedTransactionalIds.remove(transactionalId)
-                  }
-
-                case unknownKey =>
-                  // TODO: Metrics
-                  throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata")
-              }
 
-              currOffset = batch.nextOffset
+        try {
+          while (currOffset < logEndOffset
+            && loadingPartitions.contains(topicPartition.partition())
+            && !shuttingDown.get()) {
+            val fetchDataInfo = log.read(currOffset, config.transactionLogLoadBufferSize, maxOffset = None,
+              minOneMessage = true, isolationLevel = IsolationLevel.READ_UNCOMMITTED)
+            val memRecords = fetchDataInfo.records match {
+              case records: MemoryRecords => records
+              case fileRecords: FileRecords =>
+                buffer.clear()
+                val bufferRead = fileRecords.readInto(buffer, 0)
+                MemoryRecords.readableRecords(bufferRead)
             }
-          }
 
-          loadedTransactions.foreach {
-            case (transactionalId, txnMetadata) =>
-              val currentTxnMetadata = addTransaction(transactionalId, txnMetadata)
-              if (!txnMetadata.eq(currentTxnMetadata)) {
-                // treat this as a fatal failure as this should never happen
-                fatal(s"Attempt to load $transactionalId's metadata $txnMetadata failed " +
-                  s"because there is already a different cached transaction metadata $currentTxnMetadata.")
+            memRecords.batches.asScala.foreach { batch =>
+              for (record <- batch.asScala) {
+                require(record.hasKey, "Transaction state log's key should not be null")
+                TransactionLog.readMessageKey(record.key) match {
+
+                  case txnKey: TxnKey =>
+                    // load transaction metadata along with transaction state
+                    val transactionalId: String = txnKey.transactionalId
+                    if (!record.hasValue) {
+                      loadedTransactions.remove(transactionalId)
+                    } else {
+                      val txnMetadata = TransactionLog.readMessageValue(record.value)
+                      loadedTransactions.put(transactionalId, txnMetadata)
+                    }
+
+                  case unknownKey =>
+                    // TODO: Metrics
+                    throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata")
+                }
 
-                throw new KafkaException("Loading transaction topic partition failed.")
-              }
-              // if state is PrepareCommit or PrepareAbort we need to complete the transaction
-              if (currentTxnMetadata.state == PrepareCommit || currentTxnMetadata.state == PrepareAbort) {
-                writeTxnMarkers(WriteTxnMarkerArgs(transactionalId,
-                  txnMetadata.pid,
-                  txnMetadata.producerEpoch,
-                  txnMetadata.state,
-                  txnMetadata,
-                  coordinatorEpochFor(transactionalId).get
-                ))
+                currOffset = batch.nextOffset
               }
-          }
-
-          removedTransactionalIds.foreach { transactionalId =>
-            if (transactionMetadataCache.contains(transactionalId)) {
-              // the cache already contains a transaction which should be removed,
-              // treat this as a fatal failure as this should never happen
-              fatal(s"Unexpected to see $transactionalId's metadata while " +
-                s"loading partition $topicPartition since its latest state is a tombstone")
-
-              throw new KafkaException("Loading transaction topic partition failed.")
             }
-          }
 
-          info(s"Finished loading ${loadedTransactions.size} transaction metadata from $topicPartition in ${time.milliseconds() - startMs} milliseconds")
+            info(s"Finished loading ${loadedTransactions.size} transaction metadata from $topicPartition in ${time.milliseconds() - startMs} milliseconds")
+          }
+        } catch {
+          case t: Throwable => error(s"Error loading transactions from transaction log $topicPartition", t)
         }
     }
+
+    loadedTransactions
+  }
+
+  /**
+    * Add a transaction topic partition into the cache
+    */
+  def addLoadedTransactionsToCache(txnTopicPartition: Int, coordinatorEpoch: Int, metadataPerTransactionalId: Pool[String, TransactionMetadata]): Unit = {
+    val txnMetadataCacheEntry = TxnMetadataCacheEntry(coordinatorEpoch, metadataPerTransactionalId)
+    val currentTxnMetadataCacheEntry = transactionMetadataCache.put(txnTopicPartition, txnMetadataCacheEntry)
+
+    if (currentTxnMetadataCacheEntry.isDefined) {
+      val coordinatorEpoch = currentTxnMetadataCacheEntry.get.coordinatorEpoch
+      val metadataPerTxnId = currentTxnMetadataCacheEntry.get.metadataPerTransactionalId
+      info(s"The metadata cache for txn partition $txnTopicPartition has already exist with epoch $coordinatorEpoch " +
+        s"and ${metadataPerTxnId.size} entries while trying to add to it; " +
+        s"it is likely that another process for loading from the transaction log has just executed earlier before")
+
+      throw new IllegalStateException(s"The metadata cache entry for txn partition $txnTopicPartition has already exist while trying to add to it.")
+    }
   }
 
   /**
    * When this broker becomes a leader for a transaction log partition, load this partition and
    * populate the transaction metadata cache with the transactional ids.
    */
-  def loadTransactionsForPartition(partition: Int, coordinatorEpoch: Int, writeTxnMarkers: WriteTxnMarkers) {
+  def loadTransactionsForTxnTopicPartition(partitionId: Int, coordinatorEpoch: Int, sendTxnMarkers: SendTxnMarkersCallback) {
     validateTransactionTopicPartitionCountIsStable()
 
-    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partition)
+    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partitionId)
 
     inLock(stateLock) {
-      ownedPartitions.put(partition, coordinatorEpoch)
-      loadingPartitions.add(partition)
+      loadingPartitions.add(partitionId)
     }
 
     def loadTransactions() {
       info(s"Loading transaction metadata from $topicPartition")
-      try {
-        loadTransactionMetadata(topicPartition, writeTxnMarkers)
-      } catch {
-        case t: Throwable => error(s"Error loading transactions from transaction log $topicPartition", t)
-      } finally {
-        inLock(stateLock) {
-          loadingPartitions.remove(partition)
-        }
+      val loadedTransactions = loadTransactionMetadata(topicPartition, coordinatorEpoch)
+
+      loadedTransactions.foreach {
+        case (transactionalId, txnMetadata) =>
+          val result = txnMetadata synchronized {
+            // if state is PrepareCommit or PrepareAbort we need to complete the transaction
+            txnMetadata.state match {
+              case PrepareAbort =>
+                Some(TransactionResult.ABORT, txnMetadata.prepareComplete(time.milliseconds()))
+              case PrepareCommit =>
+                Some(TransactionResult.COMMIT, txnMetadata.prepareComplete(time.milliseconds()))
+              case _ =>
+                // nothing need to be done
+                None
+            }
+          }
+
+          result.foreach { case (command, newMetadata) =>
+            sendTxnMarkers(transactionalId, coordinatorEpoch, command, txnMetadata, newMetadata)
+          }
+      }
+
+      inLock(stateLock) {
+        addLoadedTransactionsToCache(topicPartition.partition, coordinatorEpoch, loadedTransactions)
+        loadingPartitions.remove(partitionId)
       }
     }
 
-    scheduler.schedule(topicPartition.toString, loadTransactions _)
+    scheduler.schedule(s"load-txns-for-partition-$topicPartition", loadTransactions _)
   }
 
   /**
    * When this broker becomes a follower for a transaction log partition, clear out the cache for corresponding transactional ids
    * that belong to that partition.
    */
-  def removeTransactionsForPartition(partition: Int) {
+  def removeTransactionsForTxnTopicPartition(partitionId: Int) {
     validateTransactionTopicPartitionCountIsStable()
 
-    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partition)
-
-    inLock(stateLock) {
-      ownedPartitions.remove(partition)
-      loadingPartitions.remove(partition)
-    }
+    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partitionId)
 
     def removeTransactions() {
-      var numTxnsRemoved = 0
-
       inLock(stateLock) {
-        for (transactionalId <- transactionMetadataCache.keys) {
-          if (partitionFor(transactionalId) == partition) {
-            // we do not need to worry about whether the transactional id has any ongoing transaction or not since
-            // the new leader will handle it
-            transactionMetadataCache.remove(transactionalId)
-            numTxnsRemoved += 1
-          }
+        transactionMetadataCache.remove(partitionId) match {
+          case Some(txnMetadataCacheEntry) =>
+            info(s"Removed ${txnMetadataCacheEntry.metadataPerTransactionalId.size} cached transaction metadata for $topicPartition on follower transition")
+
+          case None =>
+            info(s"Trying to remove cached transaction metadata for $topicPartition on follower transition but there is no entries remaining; " +
+              s"it is likely that another process for removing the cached entries has just executed earlier before")
         }
 
-        if (numTxnsRemoved > 0)
-          info(s"Removed $numTxnsRemoved cached transaction metadata for $topicPartition on follower transition")
+        loadingPartitions.remove(partitionId)
       }
     }
 
-    scheduler.schedule(topicPartition.toString, removeTransactions _)
+    scheduler.schedule(s"remove-txns-for-partition-$topicPartition", removeTransactions _)
   }
 
   private def validateTransactionTopicPartitionCountIsStable(): Unit = {
@@ -330,12 +348,13 @@ class TransactionStateManager(brokerId: Int,
 
   // TODO: check broker message format and error if < V2
   def appendTransactionToLog(transactionalId: String,
-                             txnMetadata: TransactionMetadata,
-                             responseCallback: Errors => Unit) {
+                             coordinatorEpoch: Int,
+                             newMetadata: TransactionMetadataTransition,
+                             responseCallback: Errors => Unit): Unit = {
 
     // generate the message for this transaction metadata
     val keyBytes = TransactionLog.keyToBytes(transactionalId)
-    val valueBytes = TransactionLog.valueToBytes(txnMetadata)
+    val valueBytes = TransactionLog.valueToBytes(newMetadata)
     val timestamp = time.milliseconds()
 
     val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, new SimpleRecord(timestamp, keyBytes, valueBytes))
@@ -355,7 +374,7 @@ class TransactionStateManager(brokerId: Int,
       var responseError = if (status.error == Errors.NONE) {
         Errors.NONE
       } else {
-        debug(s"Transaction state update $txnMetadata for $transactionalId failed when appending to log " +
+        debug(s"Transaction state update $newMetadata for $transactionalId failed when appending to log " +
           s"due to ${status.error.exceptionName}")
 
         // transform the log append error code to the corresponding coordinator error code
@@ -365,14 +384,14 @@ class TransactionStateManager(brokerId: Int,
                | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND
                | Errors.REQUEST_TIMED_OUT => // note that for timed out request we return NOT_AVAILABLE error code to let client retry
 
-            debug(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            info(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning ${Errors.COORDINATOR_NOT_AVAILABLE} to the client")
 
             Errors.COORDINATOR_NOT_AVAILABLE
 
           case Errors.NOT_LEADER_FOR_PARTITION =>
 
-            debug(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            info(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning ${Errors.NOT_COORDINATOR} to the client")
 
             Errors.NOT_COORDINATOR
@@ -380,13 +399,13 @@ class TransactionStateManager(brokerId: Int,
           case Errors.MESSAGE_TOO_LARGE
                | Errors.RECORD_LIST_TOO_LARGE =>
 
-            error(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            error(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning UNKNOWN error code to the client")
 
             Errors.UNKNOWN
 
           case other =>
-            error(s"Appending metadata message $txnMetadata for $transactionalId failed due to " +
+            error(s"Appending metadata message $newMetadata for $transactionalId failed due to " +
               s"unexpected error: ${status.error.message}")
 
             other
@@ -394,44 +413,42 @@ class TransactionStateManager(brokerId: Int,
       }
 
       if (responseError == Errors.NONE) {
-        def completeStateTransition(metadata: TransactionMetadata, newState: TransactionState): Boolean = {
-          // there is no transition in this case
-          if (metadata.state == Empty && newState == Empty)
-            true
-          else
-            metadata.completeTransitionTo(txnMetadata.state)
-        }
         // now try to update the cache: we need to update the status in-place instead of
         // overwriting the whole object to ensure synchronization
-          getTransactionState(transactionalId) match {
-            case Some(metadata) =>
-              metadata synchronized {
-                if (metadata.pid == txnMetadata.pid &&
-                  metadata.producerEpoch == txnMetadata.producerEpoch &&
-                  metadata.txnTimeoutMs == txnMetadata.txnTimeoutMs &&
-                  completeStateTransition(metadata, txnMetadata.state)) {
-                  // only topic-partition lists could possibly change (state should have transited in the above condition)
-                  metadata.addPartitions(txnMetadata.topicPartitions.toSet)
-                } else {
-                  throw new IllegalStateException(s"Completing transaction state transition to $txnMetadata while its current state is $metadata.")
-                }
+        getTransactionState(transactionalId) match {
+          case Some(epochAndMetadata) =>
+            val metadata = epochAndMetadata.transactionMetadata
+
+            metadata synchronized {
+              if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) {
+                // the cache may have been changed due to txn topic partition emigration and immigration,
+                // in this case directly return NOT_COORDINATOR to client and let it to re-discover the transaction coordinator
+                info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId failed after the transaction message " +
+                  s"has been appended to the log. The cached coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch}")
+
+                responseError = Errors.NOT_COORDINATOR
+              } else {
+                metadata.completeTransitionTo(newMetadata)
+
+                debug(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId succeeded")
               }
+            }
 
-            case None =>
-              // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
-              // return NOT_COORDINATOR to let the client retry
-              debug(s"Updating $transactionalId's transaction state to $txnMetadata for $transactionalId failed after the transaction message " +
-                s"has been appended to the log. The partition for $transactionalId may have migrated as the metadata is no longer in the cache")
+          case None =>
+            // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
+            // return NOT_COORDINATOR to let the client re-discover the transaction coordinator
+            info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId failed after the transaction message " +
+              s"has been appended to the log. The partition ${partitionFor(transactionalId)} may have migrated as the metadata is no longer in the cache")
 
-              responseError = Errors.NOT_COORDINATOR
-          }
+            responseError = Errors.NOT_COORDINATOR
+        }
       }
 
       responseCallback(responseError)
     }
 
     replicaManager.appendRecords(
-      txnMetadata.txnTimeoutMs.toLong,
+      newMetadata.txnTimeoutMs.toLong,
       TransactionLog.EnforcedRequiredAcks,
       internalTopicsAllowed = true,
       isFromClient = false,
@@ -441,25 +458,25 @@ class TransactionStateManager(brokerId: Int,
 
   def shutdown() {
     shuttingDown.set(true)
-    if (scheduler.isStarted)
-      scheduler.shutdown()
-
-    transactionMetadataCache.clear()
-
-    ownedPartitions.clear()
     loadingPartitions.clear()
+    transactionMetadataCache.clear()
 
     info("Shutdown complete")
   }
 }
 
-private[transaction] case class TransactionConfig(transactionalIdExpirationMs: Int = TransactionManager.DefaultTransactionalIdExpirationMs,
-                                                  transactionMaxTimeoutMs: Int = TransactionManager.DefaultTransactionsMaxTimeoutMs,
+
+private[transaction] case class TxnMetadataCacheEntry(coordinatorEpoch: Int, metadataPerTransactionalId: Pool[String, TransactionMetadata])
+
+private[transaction] case class CoordinatorEpochAndTxnMetadata(coordinatorEpoch: Int, transactionMetadata: TransactionMetadata)
+
+private[transaction] case class TransactionConfig(transactionalIdExpirationMs: Int = TransactionStateManager.DefaultTransactionalIdExpirationMs,
+                                                  transactionMaxTimeoutMs: Int = TransactionStateManager.DefaultTransactionsMaxTimeoutMs,
                                                   transactionLogNumPartitions: Int = TransactionLog.DefaultNumPartitions,
                                                   transactionLogReplicationFactor: Short = TransactionLog.DefaultReplicationFactor,
                                                   transactionLogSegmentBytes: Int = TransactionLog.DefaultSegmentBytes,
                                                   transactionLogLoadBufferSize: Int = TransactionLog.DefaultLoadBufferSize,
                                                   transactionLogMinInsyncReplicas: Int = TransactionLog.DefaultMinInSyncReplicas,
-                                                  removeExpiredTransactionsIntervalMs: Int = TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+                                                  removeExpiredTransactionsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
 
-case class TransactionalIdAndMetadata(transactionalId: String, metadata: TransactionMetadata)
+case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short)

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/DelayedOperation.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala
index c0efc53..6401600 100644
--- a/core/src/main/scala/kafka/server/DelayedOperation.scala
+++ b/core/src/main/scala/kafka/server/DelayedOperation.scala
@@ -118,7 +118,8 @@ object DelayedOperationPurgatory {
 
   def apply[T <: DelayedOperation](purgatoryName: String,
                                    brokerId: Int = 0,
-                                   purgeInterval: Int = 1000): DelayedOperationPurgatory[T] = {
+                                   purgeInterval: Int = 1000,
+                                   reaperEnabled: Boolean = true): DelayedOperationPurgatory[T] = {
     val timer = new SystemTimer(purgatoryName)
     new DelayedOperationPurgatory[T](purgatoryName, timer, brokerId, purgeInterval)
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/KafkaConfig.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 76f6380..690d167 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -23,7 +23,7 @@ import kafka.api.{ApiVersion, KAFKA_0_10_0_IV1}
 import kafka.cluster.EndPoint
 import kafka.consumer.ConsumerConfig
 import kafka.coordinator.group.OffsetConfig
-import kafka.coordinator.transaction.{TransactionLog, TransactionManager}
+import kafka.coordinator.transaction.{TransactionLog, TransactionStateManager}
 import kafka.message.{BrokerCompressionCodec, CompressionCodec, Message, MessageSet}
 import kafka.utils.CoreUtils
 import org.apache.kafka.clients.CommonClientConfigs
@@ -158,14 +158,14 @@ object Defaults {
   val OffsetCommitRequiredAcks = OffsetConfig.DefaultOffsetCommitRequiredAcks
 
   /** ********* Transaction management configuration ***********/
-  val TransactionalIdExpirationMs = TransactionManager.DefaultTransactionalIdExpirationMs
-  val TransactionsMaxTimeoutMs = TransactionManager.DefaultTransactionsMaxTimeoutMs
+  val TransactionalIdExpirationMs = TransactionStateManager.DefaultTransactionalIdExpirationMs
+  val TransactionsMaxTimeoutMs = TransactionStateManager.DefaultTransactionsMaxTimeoutMs
   val TransactionsTopicMinISR = TransactionLog.DefaultMinInSyncReplicas
   val TransactionsLoadBufferSize = TransactionLog.DefaultLoadBufferSize
   val TransactionsTopicReplicationFactor = TransactionLog.DefaultReplicationFactor
   val TransactionsTopicPartitions = TransactionLog.DefaultNumPartitions
   val TransactionsTopicSegmentBytes = TransactionLog.DefaultSegmentBytes
-  val TransactionsExpiredTransactionCleanupIntervalMS = TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs
+  val TransactionsExpiredTransactionCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
 
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/MetadataCache.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/MetadataCache.scala b/core/src/main/scala/kafka/server/MetadataCache.scala
index 1b334ac..2e4c19a 100755
--- a/core/src/main/scala/kafka/server/MetadataCache.scala
+++ b/core/src/main/scala/kafka/server/MetadataCache.scala
@@ -159,6 +159,24 @@ class MetadataCache(brokerId: Int) extends Logging {
     }
   }
 
+  def getPartitionLeaderEndpoint(topic: String, partitionId: Int, listenerName: ListenerName): Option[Node] = {
+    inReadLock(partitionMetadataLock) {
+      cache.get(topic).flatMap(_.get(partitionId)) match {
+        case Some(partitionInfo) =>
+          val leaderId = partitionInfo.leaderIsrAndControllerEpoch.leaderAndIsr.leader
+          try {
+            getAliveEndpoint(leaderId, listenerName)
+          } catch {
+            case e: BrokerEndPointNotAvailableException =>
+              None
+          }
+
+        case None =>
+          None
+      }
+    }
+  }
+
   def getControllerId: Option[Int] = controllerId
 
   // This method returns the deleted TopicPartitions received from UpdateMetadataRequest

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/ReplicaManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 99c1b45..9cd92f7 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -149,11 +149,11 @@ class ReplicaManager(val config: KafkaConfig,
   private val lastIsrPropagationMs = new AtomicLong(System.currentTimeMillis())
 
   val delayedProducePurgatory = DelayedOperationPurgatory[DelayedProduce](
-    purgatoryName = "Produce", localBrokerId, config.producerPurgatoryPurgeIntervalRequests)
+    purgatoryName = "Produce", brokerId = localBrokerId, purgeInterval = config.producerPurgatoryPurgeIntervalRequests)
   val delayedFetchPurgatory = DelayedOperationPurgatory[DelayedFetch](
-    purgatoryName = "Fetch", localBrokerId, config.fetchPurgatoryPurgeIntervalRequests)
+    purgatoryName = "Fetch", brokerId = localBrokerId, purgeInterval = config.fetchPurgatoryPurgeIntervalRequests)
   val delayedDeleteRecordsPurgatory = DelayedOperationPurgatory[DelayedDeleteRecords](
-    purgatoryName = "DeleteRecords", localBrokerId, config.deleteRecordsPurgatoryPurgeIntervalRequests)
+    purgatoryName = "DeleteRecords", brokerId = localBrokerId, purgeInterval = config.deleteRecordsPurgatoryPurgeIntervalRequests)
 
   val leaderCount = newGauge(
     "LeaderCount",

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
index 20d1161..df23952 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
@@ -55,6 +55,12 @@ class TransactionCoordinatorIntegrationTest extends KafkaServerTestHarness {
     val txnId = "txn"
     tc.handleInitPid(txnId, 900000, callback)
 
+    while(initPidResult == null) {
+      Utils.sleep(1)
+    }
+
+    Assert.assertEquals(Errors.NONE, initPidResult.error)
+
     @volatile var addPartitionErrors: Errors = null
     def addPartitionsCallback(errors: Errors): Unit = {
         addPartitionErrors = errors

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index a9f1bca..395bfb9 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -40,6 +40,7 @@ class TransactionCoordinatorTest {
   val capturedTxn: Capture[TransactionMetadata] = EasyMock.newCapture()
   val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture()
   val brokerId = 0
+  val coordinatorEpoch = 0
   private val transactionalId = "known"
   private val pid = 10
   private val epoch:Short = 1
@@ -50,11 +51,11 @@ class TransactionCoordinatorTest {
   private val scheduler = new MockScheduler(time)
 
   val coordinator: TransactionCoordinator = new TransactionCoordinator(brokerId,
+    scheduler,
     pidManager,
     transactionManager,
     transactionMarkerChannelManager,
     txnMarkerPurgatory,
-    scheduler,
     time)
 
   var result: InitPidResult = _
@@ -76,7 +77,6 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(EasyMock.eq(transactionalId)))
       .andReturn(true)
       .anyTimes()
-
     EasyMock.expect(transactionManager.isCoordinatorLoadingInProgress(EasyMock.anyString()))
       .andReturn(false)
       .anyTimes()
@@ -85,7 +85,6 @@ class TransactionCoordinatorTest {
       .anyTimes()
   }
 
-
   @Test
   def shouldAcceptInitPidAndReturnNextPidWhenTransactionalIdIsEmpty(): Unit = {
     mockPidManager()
@@ -111,28 +110,30 @@ class TransactionCoordinatorTest {
   @Test
   def shouldInitPidWithEpochZeroForNewTransactionalId(): Unit = {
     initPidGenericMocks(transactionalId)
-    EasyMock.expect(transactionManager.addTransaction(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
-      .andAnswer(new IAnswer[TransactionMetadata] {
-        override def answer(): TransactionMetadata = {
-          capturedTxn.getValue
+
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
+      .andAnswer(new IAnswer[Option[CoordinatorEpochAndTxnMetadata]] {
+        override def answer(): Option[CoordinatorEpochAndTxnMetadata] = {
+          if (capturedTxn.hasCaptured)
+            Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue))
+          else
+            None
         }
       })
       .once()
-    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
-      .andAnswer(new IAnswer[Option[TransactionMetadata]] {
-        override def answer(): Option[TransactionMetadata] = {
-          if (capturedTxn.hasCaptured) {
-            Some(capturedTxn.getValue)
-          } else {
-            None
-          }
+
+    EasyMock.expect(transactionManager.addTransaction(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+      .andAnswer(new IAnswer[CoordinatorEpochAndTxnMetadata] {
+        override def answer(): CoordinatorEpochAndTxnMetadata = {
+          CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)
         }
       })
       .once()
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
-      EasyMock.capture(capturedTxn),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -143,7 +144,7 @@ class TransactionCoordinatorTest {
     EasyMock.replay(pidManager, transactionManager)
 
     coordinator.handleInitPid(transactionalId, txnTimeoutMs, initPidMockCallback)
-    assertEquals(InitPidResult(0L, 0, Errors.NONE), result)
+    assertEquals(InitPidResult(nextPid - 1, 0, Errors.NONE), result)
   }
 
   @Test
@@ -212,7 +213,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, state, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, state, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -225,7 +226,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -254,20 +255,23 @@ class TransactionCoordinatorTest {
   }
 
   def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = {
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds())
+
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, previousState, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
-      EasyMock.eq(new TransactionMetadata(0, 0, 0, Ongoing, partitions, if (previousState == Ongoing) 0 else time.milliseconds(), time.milliseconds())),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
       EasyMock.capture(capturedErrorsCallback)
     ))
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, pid, epoch, partitions, errorsCallback)
 
     EasyMock.verify(transactionManager)
   }
@@ -277,7 +281,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, Empty, partitions, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, Empty, partitions, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -304,7 +308,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -317,7 +321,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 0, TransactionResult.COMMIT, errorsCallback)
@@ -330,7 +334,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -343,7 +347,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
@@ -356,7 +360,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -369,7 +373,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
@@ -378,15 +382,15 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
+  def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
-    assertEquals(Errors.INVALID_TXN_STATE, error)
+    assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
     EasyMock.verify(transactionManager)
   }
 
@@ -395,7 +399,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -425,29 +429,6 @@ class TransactionCoordinatorTest {
     EasyMock.verify(transactionManager)
   }
 
-
-  @Test
-  def shouldAppendCompleteAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = {
-    mockComplete(PrepareAbort)
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.ABORT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
-  def shouldAppendCompleteCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = {
-    mockComplete(PrepareCommit)
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
   @Test
   def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = {
     coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -506,18 +487,29 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, 0, 0)
+
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
+      .anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(metadata))
-      .once()
-
-    mockComplete(PrepareAbort)
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+      .anyTimes()
 
+    val originalMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, 0, 0)
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())),
+      EasyMock.capture(capturedErrorsCallback)))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          capturedErrorsCallback.getValue.apply(Errors.NONE)
+        }
+      })
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
@@ -529,8 +521,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRemoveTransactionsForPartitionOnEmigration(): Unit = {
-    EasyMock.expect(transactionManager.removeTransactionsForPartition(0))
-    EasyMock.expect(transactionMarkerChannelManager.removeStateForPartition(0))
+    EasyMock.expect(transactionManager.removeTransactionsForTxnTopicPartition(0))
+    EasyMock.expect(transactionMarkerChannelManager.removeMarkersForTxnTopicPartition(0))
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.handleTxnEmigration(0)
@@ -539,114 +531,22 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldRetryOnCommitWhenTxnMarkerRequestFailsWithErrorOtherThanNotCoordinator(): Unit = {
-    val prepareMetadata = mockPrepare(PrepareCommit, runCallback = true)
-
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
-
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NETWORK_EXCEPTION)
-      }
-    }).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    })
-
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(prepareMetadata))
-      .once()
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionMarkerChannelManager)
-  }
-
-  @Test
-  def shouldNotRetryOnCommitWhenTxnMarkerRequestFailsWithNotCoordinator(): Unit = {
-    val prepareMetadata = mockPrepare(PrepareCommit, runCallback = true)
-
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
-
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NOT_COORDINATOR)
-      }
-    })
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionMarkerChannelManager)
-  }
-
-  @Test
-  def shouldNotRetryOnCommitWhenAppendToLogFailsWithNotCoordinator(): Unit = {
-    mockComplete(PrepareCommit, Errors.NOT_COORDINATOR)
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
-  def shouldRetryOnCommitWhenAppendToLogFailsErrorsOtherThanNotCoordinator(): Unit = {
-    mockComplete(PrepareCommit, Errors.ILLEGAL_GENERATION)
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
   def shouldAbortExpiredTransactionsInOngoingState(): Unit = {
-    EasyMock.expect(transactionManager.transactionsToExpire())
-    .andReturn(List(TransactionalIdAndMetadata(transactionalId,
-      new TransactionMetadata(pid, epoch, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds()))))
-
-    // should bump the epoch and append to the log
-    val metadata = new TransactionMetadata(pid, (epoch + 1).toShort, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds())
-    EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(metadata),
-      EasyMock.capture(capturedErrorsCallback)))
-    .andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    }).once()
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
 
+    EasyMock.expect(transactionManager.transactionsToExpire())
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(metadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .once()
 
-    // now should perform the rollback and append the state as PrepareAbort
-    val abortMetadata = metadata.copy()
-    abortMetadata.state = PrepareAbort
-    // need to allow for the time.sleep below
-    abortMetadata.lastUpdateTimestamp = time.milliseconds() + TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs
+    val newMetadata = txnMetadata.copy().prepareAbortOrCommit(PrepareAbort, time.milliseconds() + TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
 
     EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(abortMetadata),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(newMetadata),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {}
@@ -656,147 +556,135 @@ class TransactionCoordinatorTest {
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.startup(false)
-    time.sleep(TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
   }
 
   @Test
   def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
-    val metadata = new TransactionMetadata(pid, epoch, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds())
-    metadata.prepareTransitionTo(PrepareCommit)
+    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
+    metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
 
     EasyMock.expect(transactionManager.transactionsToExpire())
-      .andReturn(List(TransactionalIdAndMetadata(transactionalId,
-        metadata)))
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
     
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-    coordinator.startup(false)
 
-    time.sleep(TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    coordinator.startup(false)
+    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
-
   }
 
   private def validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(state: TransactionState) = {
-    val transactionId = "tid"
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionId))
+    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true).anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true).anyTimes()
 
     val metadata = new TransactionMetadata(0, 0, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
-    EasyMock.expect(transactionManager.getTransactionState(transactionId))
-      .andReturn(Some(metadata)).anyTimes()
+    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))).anyTimes()
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleInitPid(transactionId, 10, initPidMockCallback)
+    coordinator.handleInitPid(transactionalId, 10, initPidMockCallback)
 
     assertEquals(InitPidResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result)
   }
 
   private def validateIncrementEpochAndUpdateMetadata(state: TransactionState) = {
-    val transactionId = "tid"
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionId))
+    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(0, 0, 0, state, mutable.Set.empty[TopicPartition], 0, 0)
-    EasyMock.expect(transactionManager.getTransactionState(transactionId))
-      .andReturn(Some(metadata))
+    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
+    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))
 
+    val capturedNewMetadata: Capture[TransactionMetadataTransition] = EasyMock.newCapture()
     EasyMock.expect(transactionManager.appendTransactionToLog(
-      EasyMock.eq(transactionId),
-      EasyMock.anyObject(classOf[TransactionMetadata]),
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.capture(capturedNewMetadata),
       EasyMock.capture(capturedErrorsCallback)
     )).andAnswer(new IAnswer[Unit] {
       override def answer(): Unit = {
+        metadata.completeTransitionTo(capturedNewMetadata.getValue)
         capturedErrorsCallback.getValue.apply(Errors.NONE)
       }
     })
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleInitPid(transactionId, 10, initPidMockCallback)
+    val newTxnTimeoutMs = 10
+    coordinator.handleInitPid(transactionalId, newTxnTimeoutMs, initPidMockCallback)
 
-    assertEquals(InitPidResult(0, 1, Errors.NONE), result)
-    assertEquals(10, metadata.txnTimeoutMs)
-    assertEquals(time.milliseconds(), metadata.lastUpdateTimestamp)
-    assertEquals(1, metadata.producerEpoch)
-    assertEquals(0, metadata.pid)
+    assertEquals(InitPidResult(pid, (epoch + 1).toShort, Errors.NONE), result)
+    assertEquals(newTxnTimeoutMs, metadata.txnTimeoutMs)
+    assertEquals(time.milliseconds(), metadata.txnLastUpdateTimestamp)
+    assertEquals((epoch + 1).toShort, metadata.producerEpoch)
+    assertEquals(pid, metadata.producerId)
   }
 
-  private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false) = {
-    val originalMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      Ongoing,
-      collection.mutable.Set.empty[TopicPartition],
-      0,
-      time.milliseconds())
-
-    val prepareCommitMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      transactionState,
-      collection.mutable.Set.empty[TopicPartition],
-      0,
-      time.milliseconds())
+  private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = {
+    val now = time.milliseconds()
+    val originalMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, now, now)
 
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
+      .anyTimes()
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(originalMetadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))
       .once()
-
-    EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(prepareCommitMetadata),
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(originalMetadata.copy().prepareAbortOrCommit(transactionState, now)),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
-          if (runCallback) capturedErrorsCallback.getValue.apply(Errors.NONE)
+          if (runCallback)
+            capturedErrorsCallback.getValue.apply(Errors.NONE)
         }
       }).once()
-    prepareCommitMetadata
-  }
-
-  private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE) = {
 
+    new TransactionMetadata(pid, epoch, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds())
+  }
 
-    val prepareMetadata: TransactionMetadata = mockPrepare(transactionState, true)
-    val finalState = if (transactionState == PrepareAbort) CompleteAbort else CompleteCommit
+  private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE): TransactionMetadata = {
+    val now = time.milliseconds()
+    val prepareMetadata = mockPrepare(transactionState, true)
 
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
+    val (finalState, txnResult) = if (transactionState == PrepareAbort)
+      (CompleteAbort, TransactionResult.ABORT)
+    else
+      (CompleteCommit, TransactionResult.COMMIT)
 
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    })
+    val completedMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, finalState,
+      collection.mutable.Set.empty[TopicPartition],
+      prepareMetadata.txnStartTimestamp,
+      prepareMetadata.txnLastUpdateTimestamp)
 
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(prepareMetadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareMetadata)))
       .once()
 
-    val completedMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      finalState,
-      prepareMetadata.topicPartitions,
-      prepareMetadata.transactionStartTime,
-      prepareMetadata.lastUpdateTimestamp)
+    val newMetadata = prepareMetadata.copy().prepareComplete(now)
+    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkersToSend(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(txnResult),
+      EasyMock.eq(prepareMetadata),
+      EasyMock.eq(newMetadata))
+    ).once()
 
-    val firstAnswer = EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(completedMetadata),
+    val firstAnswer = EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(newMetadata),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -804,18 +692,18 @@ class TransactionCoordinatorTest {
         }
       })
 
-     if(appendError != Errors.NONE && appendError != Errors.NOT_COORDINATOR) {
-        firstAnswer.andAnswer(new IAnswer[Unit] {
-          override def answer(): Unit = {
-            capturedErrorsCallback.getValue.apply(Errors.NONE)
-          }
-        })
-     }
-
+    // let it succeed next time
+    if (appendError != Errors.NONE && appendError != Errors.NOT_COORDINATOR) {
+      firstAnswer.andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          capturedErrorsCallback.getValue.apply(Errors.NONE)
+        }
+      })
+    }
 
+    completedMetadata
   }
 
-
   def initPidMockCallback(ret: InitPidResult): Unit = {
     result = ret
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
index cfb4a99..fe750b8 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
@@ -43,7 +43,7 @@ class TransactionLogTest extends JUnitSuite {
     txnMetadata.addPartitions(topicPartitions)
 
     intercept[IllegalStateException] {
-      TransactionLog.valueToBytes(txnMetadata)
+      TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())
     }
   }
 
@@ -71,7 +71,7 @@ class TransactionLogTest extends JUnitSuite {
         txnMetadata.addPartitions(topicPartitions)
 
       val keyBytes = TransactionLog.keyToBytes(transactionalId)
-      val valueBytes = TransactionLog.valueToBytes(txnMetadata)
+      val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())
 
       new SimpleRecord(keyBytes, valueBytes)
     }.toSeq
@@ -87,10 +87,10 @@ class TransactionLogTest extends JUnitSuite {
           val transactionalId = pidKey.transactionalId
           val txnMetadata = TransactionLog.readMessageValue(record.value())
 
-          assertEquals(pidMappings(transactionalId), txnMetadata.pid)
+          assertEquals(pidMappings(transactionalId), txnMetadata.producerId)
           assertEquals(epoch, txnMetadata.producerEpoch)
           assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
-          assertEquals(transactionStates(txnMetadata.pid), txnMetadata.state)
+          assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
 
           if (txnMetadata.state.equals(Empty))
             assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)


Mime
View raw message