kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [1/2] kafka git commit: KAFKA-5280: Protect txn metadata map with read-write lock
Date Tue, 23 May 2017 20:34:47 GMT
Repository: kafka
Updated Branches:
  refs/heads/trunk 495877505 -> 1bf648331


http://git-wip-us.apache.org/repos/asf/kafka/blob/1bf64833/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 3c94916..e225588 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -16,9 +16,7 @@
  */
 package kafka.coordinator.transaction
 
-import kafka.server.DelayedOperationPurgatory
 import kafka.utils.MockScheduler
-import kafka.utils.timer.MockTimer
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.TransactionResult
@@ -37,7 +35,7 @@ class TransactionCoordinatorTest {
   val pidManager: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager])
   val transactionManager: TransactionStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
   val transactionMarkerChannelManager: TransactionMarkerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager])
-  val capturedTxn: Capture[TransactionMetadata] = EasyMock.newCapture()
+  val capturedTxn: Capture[Option[TransactionMetadata]] = EasyMock.newCapture()
   val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture()
   val brokerId = 0
   val coordinatorEpoch = 0
@@ -46,7 +44,6 @@ class TransactionCoordinatorTest {
   private val producerEpoch:Short = 1
   private val txnTimeoutMs = 1
 
-  private val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("test", new MockTimer, reaperEnabled = false)
   private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0))
   private val scheduler = new MockScheduler(time)
 
@@ -55,7 +52,6 @@ class TransactionCoordinatorTest {
     pidManager,
     transactionManager,
     transactionMarkerChannelManager,
-    txnMarkerPurgatory,
     time)
 
   var result: InitProducerIdResult = _
@@ -74,12 +70,6 @@ class TransactionCoordinatorTest {
 
   private def initPidGenericMocks(transactionalId: String): Unit = {
     mockPidManager()
-    EasyMock.expect(transactionManager.isCoordinatorFor(EasyMock.eq(transactionalId)))
-      .andReturn(true)
-      .anyTimes()
-    EasyMock.expect(transactionManager.isCoordinatorLoadingInProgress(EasyMock.anyString()))
-      .andReturn(false)
-      .anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
       .anyTimes()
@@ -111,21 +101,13 @@ class TransactionCoordinatorTest {
   def shouldInitPidWithEpochZeroForNewTransactionalId(): Unit = {
     initPidGenericMocks(transactionalId)
 
-    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
-      .andAnswer(new IAnswer[Option[CoordinatorEpochAndTxnMetadata]] {
-        override def answer(): Option[CoordinatorEpochAndTxnMetadata] = {
-          if (capturedTxn.hasCaptured)
-            Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+      .andAnswer(new IAnswer[Either[Errors, Option[CoordinatorEpochAndTxnMetadata]]] {
+        override def answer(): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] = {
+          if (capturedTxn.hasCaptured && capturedTxn.getValue.isDefined)
+            Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue.get)))
           else
-            None
-        }
-      })
-      .once()
-
-    EasyMock.expect(transactionManager.addTransaction(EasyMock.capture(capturedTxn)))
-      .andAnswer(new IAnswer[CoordinatorEpochAndTxnMetadata] {
-        override def answer(): CoordinatorEpochAndTxnMetadata = {
-          CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)
+            Right(None)
         }
       })
       .once()
@@ -148,20 +130,35 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldRespondWithNotCoordinatorOnInitPidWhenNotCoordinatorForId(): Unit = {
-    mockPidManager()
-    EasyMock.replay(pidManager)
-    coordinator.handleInitProducerId("some-pid", txnTimeoutMs, initProducerIdMockCallback)
+  def shouldRespondWithNotCoordinatorOnInitPidWhenNotCoordinator(): Unit = {
+    EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
+      .andReturn(true)
+      .anyTimes()
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Left(Errors.NOT_COORDINATOR))
+    EasyMock.replay(transactionManager)
+
+    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, initProducerIdMockCallback)
     assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_COORDINATOR), result)
   }
 
   @Test
-  def shouldRespondWithInvalidPidMappingOnAddPartitionsToTransactionWhenTransactionalIdNotPresent(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
+  def shouldRespondWithCoordinatorLoadInProgressOnInitPidWhenCoordintorLoading(): Unit = {
+    EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
+      .anyTimes()
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
+    EasyMock.replay(transactionManager)
+
+    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, initProducerIdMockCallback)
+    assertEquals(InitProducerIdResult(-1, -1, Errors.COORDINATOR_LOAD_IN_PROGRESS), result)
+  }
 
-    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
-      .andReturn(None)
+  @Test
+  def shouldRespondWithInvalidPidMappingOnAddPartitionsToTransactionWhenTransactionalIdNotPresent(): Unit = {
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(None))
     EasyMock.replay(transactionManager)
 
     coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback)
@@ -182,16 +179,18 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithNotCoordinatorOnAddPartitionsWhenNotCoordinator(): Unit = {
-    coordinator.handleAddPartitionsToTransaction("txn", 0L, 1, partitions, errorsCallback)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.NOT_COORDINATOR))
+    EasyMock.replay(transactionManager)
+
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback)
     assertEquals(Errors.NOT_COORDINATOR, error)
   }
 
   @Test
   def shouldRespondWithCoordinatorLoadInProgressOnAddPartitionsWhenCoordintorLoading(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.isCoordinatorLoadingInProgress(transactionalId))
-    .andReturn(true)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
     EasyMock.replay(transactionManager)
 
@@ -210,11 +209,8 @@ class TransactionCoordinatorTest {
   }
 
   def validateConcurrentTransactions(state: TransactionState): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        0, 0, 0, state, mutable.Set.empty, 0, 0))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, 0, state, mutable.Set.empty, 0, 0)))))
 
     EasyMock.replay(transactionManager)
 
@@ -224,11 +220,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithInvalidTnxProduceEpochOnAddPartitionsWhenEpochsAreDifferent(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0)))))
 
     EasyMock.replay(transactionManager)
 
@@ -260,10 +253,8 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, previousState,
       mutable.Set.empty, time.milliseconds(), time.milliseconds())
 
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
@@ -281,11 +272,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0,
-        0, Empty, partitions, 0, 0))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, 0, Empty, partitions, 0, 0)))))
 
     EasyMock.replay(transactionManager)
 
@@ -297,9 +285,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId)).andReturn(None)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(None))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -309,11 +296,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 0,
-        0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -323,11 +307,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback)
@@ -337,11 +318,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
@@ -351,11 +329,9 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
@@ -365,11 +341,9 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
@@ -379,11 +353,9 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
@@ -393,11 +365,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
@@ -407,11 +376,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
-        producerId, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
@@ -449,26 +415,32 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = {
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.eq(None)))
+      .andReturn(Left(Errors.NOT_COORDINATOR))
+    EasyMock.replay(transactionManager)
+
     coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.INVALID_REQUEST, error)
   }
 
   @Test
   def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = {
-    coordinator.handleEndTransaction("id", 0, 0, TransactionResult.COMMIT, errorsCallback)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.NOT_COORDINATOR))
+    EasyMock.replay(transactionManager)
+
+    coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.NOT_COORDINATOR, error)
   }
 
   @Test
   def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(EasyMock.anyString()))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.isCoordinatorLoadingInProgress(EasyMock.anyString()))
-      .andReturn(true)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction("id", 0, 0, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error)
   }
 
@@ -500,20 +472,17 @@ class TransactionCoordinatorTest {
   @Test
   def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
-      partitions, 0, 0)
+      partitions, time.milliseconds(), time.milliseconds())
 
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-      .anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .anyTimes()
 
     val originalMetadata = new TransactionMetadata(transactionalId, producerId, (producerEpoch + 1).toShort,
-      txnTimeoutMs, Ongoing, partitions, 0, 0)
+      txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
@@ -525,7 +494,7 @@ class TransactionCoordinatorTest {
         }
       })
 
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
+    EasyMock.replay(transactionManager)
 
     coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, initProducerIdMockCallback)
 
@@ -552,10 +521,8 @@ class TransactionCoordinatorTest {
 
     EasyMock.expect(transactionManager.transactionsToExpire())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .once()
 
     val expectedTransition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareAbort,
@@ -586,6 +553,8 @@ class TransactionCoordinatorTest {
 
     EasyMock.expect(transactionManager.transactionsToExpire())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
@@ -596,15 +565,12 @@ class TransactionCoordinatorTest {
   }
 
   private def validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(state: TransactionState) = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true).anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true).anyTimes()
 
-    val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, state,
-      mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))).anyTimes()
+    val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))).anyTimes()
 
     EasyMock.replay(transactionManager)
 
@@ -614,15 +580,16 @@ class TransactionCoordinatorTest {
   }
 
   private def validateIncrementEpochAndUpdateMetadata(state: TransactionState) = {
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
+    EasyMock.expect(pidManager.generateProducerId())
+      .andReturn(producerId)
+      .anyTimes()
+
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, state,
-      mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))
+    val metadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
 
     val capturedNewMetadata: Capture[TxnTransitMetadata] = EasyMock.newCapture()
     EasyMock.expect(transactionManager.appendTransactionToLog(
@@ -637,7 +604,7 @@ class TransactionCoordinatorTest {
       }
     })
 
-    EasyMock.replay(transactionManager)
+    EasyMock.replay(pidManager, transactionManager)
 
     val newTxnTimeoutMs = 10
     coordinator.handleInitProducerId(transactionalId, newTxnTimeoutMs, initProducerIdMockCallback)
@@ -657,11 +624,8 @@ class TransactionCoordinatorTest {
     val transition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, transactionState,
       partitions.toSet, now, now)
 
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
-      .andReturn(true)
-      .anyTimes()
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))
+    EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata))))
       .once()
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
@@ -679,62 +643,6 @@ class TransactionCoordinatorTest {
       time.milliseconds(), time.milliseconds())
   }
 
-  private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE): TransactionMetadata = {
-    val now = time.milliseconds()
-    val prepareMetadata = mockPrepare(transactionState, true)
-
-    val (finalState, txnResult) = if (transactionState == PrepareAbort)
-      (CompleteAbort, TransactionResult.ABORT)
-    else
-      (CompleteCommit, TransactionResult.COMMIT)
-
-    val completedMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, finalState,
-      collection.mutable.Set.empty[TopicPartition],
-      prepareMetadata.txnStartTimestamp,
-      prepareMetadata.txnLastUpdateTimestamp)
-
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareMetadata)))
-      .once()
-
-    val newMetadata = TxnTransitMetadata(producerId = producerId,
-      producerEpoch = producerEpoch,
-      txnTimeoutMs = txnTimeoutMs,
-      txnState = finalState,
-      topicPartitions = Set.empty[TopicPartition],
-      txnStartTimestamp = prepareMetadata.txnStartTimestamp,
-      txnLastUpdateTimestamp = now)
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkersToSend(
-      EasyMock.eq(transactionalId),
-      EasyMock.eq(coordinatorEpoch),
-      EasyMock.eq(txnResult),
-      EasyMock.eq(prepareMetadata),
-      EasyMock.eq(newMetadata))
-    ).once()
-
-    val firstAnswer = EasyMock.expect(transactionManager.appendTransactionToLog(
-      EasyMock.eq(transactionalId),
-      EasyMock.eq(coordinatorEpoch),
-      EasyMock.eq(newMetadata),
-      EasyMock.capture(capturedErrorsCallback)))
-      .andAnswer(new IAnswer[Unit] {
-        override def answer(): Unit = {
-          capturedErrorsCallback.getValue.apply(appendError)
-        }
-      })
-
-    // let it succeed next time
-    if (appendError != Errors.NONE && appendError != Errors.NOT_COORDINATOR) {
-      firstAnswer.andAnswer(new IAnswer[Unit] {
-        override def answer(): Unit = {
-          capturedErrorsCallback.getValue.apply(Errors.NONE)
-        }
-      })
-    }
-
-    completedMetadata
-  }
-
   def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = {
     result = ret
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/1bf64833/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index 9835db7..991bfbe 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -41,7 +41,6 @@ class TransactionMarkerChannelManagerTest {
 
   private val transactionalId1 = "txnId1"
   private val transactionalId2 = "txnId2"
-  private val transactionalId3 = "txnId3"
   private val producerId1 = 0.asInstanceOf[Long]
   private val producerId2 = 1.asInstanceOf[Long]
   private val producerId3 = 1.asInstanceOf[Long]
@@ -106,7 +105,7 @@ class TransactionMarkerChannelManagerTest {
       PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
 
-    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, txnMarkerPurgatory.watched)
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
     assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers())
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
@@ -145,7 +144,7 @@ class TransactionMarkerChannelManagerTest {
     channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
     channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, txnMarkerPurgatory.watched)
     assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
@@ -202,14 +201,14 @@ class TransactionMarkerChannelManagerTest {
       PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, txnMarkerPurgatory.watched)
     assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
 
     channelManager.removeMarkersForTxnTopicPartition(txnTopicPartition1)
 
-    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, txnMarkerPurgatory.watched)
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
     assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
     assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))

http://git-wip-us.apache.org/repos/asf/kafka/blob/1bf64833/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
index c1123aa..a255830 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
@@ -57,8 +57,8 @@ class TransactionMarkerRequestCompletionHandlerTest {
     EasyMock.expect(txnStateManager.partitionFor(transactionalId))
       .andReturn(txnTopicPartition)
       .anyTimes()
-    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .anyTimes()
     EasyMock.replay(txnStateManager)
   }
@@ -99,9 +99,19 @@ class TransactionMarkerRequestCompletionHandlerTest {
   }
 
   @Test
-  def shouldCompleteDelayedOperationWhenNoMetadata(): Unit = {
-    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
-      .andReturn(None)
+  def shouldCompleteDelayedOperationWhenNotCoordinator(): Unit = {
+    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.NOT_COORDINATOR))
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
+
+    verifyRemoveDelayedOperationOnError(Errors.NONE)
+  }
+
+  @Test
+  def shouldCompleteDelayedOperationWhenCoordinatorLoading(): Unit = {
+    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
       .anyTimes()
     EasyMock.replay(txnStateManager)
 
@@ -110,8 +120,8 @@ class TransactionMarkerRequestCompletionHandlerTest {
 
   @Test
   def shouldCompleteDelayedOperationWhenCoordinatorEpochChanged(): Unit = {
-    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata)))
+    EasyMock.expect(txnStateManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata))))
       .anyTimes()
     EasyMock.replay(txnStateManager)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/1bf64833/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index bbf2f38..8682026 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -23,7 +23,7 @@ import kafka.server.{FetchDataInfo, LogOffsetMetadata, ReplicaManager}
 import kafka.utils.{MockScheduler, Pool, ZkUtils}
 import kafka.utils.TestUtils.fail
 import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.internals.Topic.{GROUP_METADATA_TOPIC_NAME, TRANSACTION_STATE_TOPIC_NAME}
+import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.IsolationLevel
@@ -98,10 +98,13 @@ class TransactionStateManagerTest {
   def testAddGetPids() {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
 
-    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
-    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnMetadata1))
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
-    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2), transactionManager.addTransaction(txnMetadata2))
+    assertEquals(Right(None), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))),
+      transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata1)))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))),
+      transactionManager.getAndMaybeAddTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))),
+      transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata2)))
   }
 
   @Test
@@ -157,35 +160,51 @@ class TransactionStateManagerTest {
     prepareTxnLog(topicPartition, startOffset, records)
 
     // this partition should not be part of the owned partitions
-    assertFalse(transactionManager.isCoordinatorFor(transactionalId1))
-    assertFalse(transactionManager.isCoordinatorFor(transactionalId2))
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1).fold(
+      err => assertEquals(Errors.NOT_COORDINATOR, err),
+      _ => fail(transactionalId1 + "'s transaction state is already in the cache")
+    )
+    transactionManager.getAndMaybeAddTransactionState(transactionalId2).fold(
+      err => assertEquals(Errors.NOT_COORDINATOR, err),
+      _ => fail(transactionalId2 + "'s transaction state is already in the cache")
+    )
 
     transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _, _) => ())
 
     // let the time advance to trigger the background thread loading
     scheduler.tick()
 
-    val cachedPidMetadata1 = transactionManager.getTransactionState(transactionalId1).getOrElse(fail(transactionalId1 + "'s transaction state was not loaded into the cache"))
-    val cachedPidMetadata2 = transactionManager.getTransactionState(transactionalId2).getOrElse(fail(transactionalId2 + "'s transaction state was not loaded into the cache"))
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1).fold(
+      err => fail(transactionalId1 + "'s transaction state access returns error " + err),
+      entry => entry.getOrElse(fail(transactionalId1 + "'s transaction state was not loaded into the cache"))
+    )
+
+    val cachedPidMetadata1 = transactionManager.getAndMaybeAddTransactionState(transactionalId1).fold(
+      err => fail(transactionalId1 + "'s transaction state access returns error " + err),
+      entry => entry.getOrElse(fail(transactionalId1 + "'s transaction state was not loaded into the cache"))
+    )
+    val cachedPidMetadata2 = transactionManager.getAndMaybeAddTransactionState(transactionalId2).fold(
+      err => fail(transactionalId2 + "'s transaction state access returns error " + err),
+      entry => entry.getOrElse(fail(transactionalId2 + "'s transaction state was not loaded into the cache"))
+    )
 
     // they should be equal to the latest status of the transaction
     assertEquals(txnMetadata1, cachedPidMetadata1.transactionMetadata)
     assertEquals(txnMetadata2, cachedPidMetadata2.transactionMetadata)
 
-    // this partition should now be part of the owned partitions
-    assertTrue(transactionManager.isCoordinatorFor(transactionalId1))
-    assertTrue(transactionManager.isCoordinatorFor(transactionalId2))
-
     transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
 
     // let the time advance to trigger the background thread removing
     scheduler.tick()
 
-    assertFalse(transactionManager.isCoordinatorFor(transactionalId1))
-    assertFalse(transactionManager.isCoordinatorFor(transactionalId2))
-
-    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
-    assertEquals(None, transactionManager.getTransactionState(transactionalId2))
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1).fold(
+      err => assertEquals(Errors.NOT_COORDINATOR, err),
+      _ => fail(transactionalId1 + "'s transaction state is still in the cache")
+    )
+    transactionManager.getAndMaybeAddTransactionState(transactionalId2).fold(
+      err => assertEquals(Errors.NOT_COORDINATOR, err),
+      _ => fail(transactionalId2 + "'s transaction state is still in the cache")
+    )
   }
 
   @Test
@@ -193,7 +212,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnMetadata1)
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata1))
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NONE
@@ -205,9 +224,10 @@ class TransactionStateManagerTest {
     // append the new metadata into log
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch, newMetadata, assertCallback)
 
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     // append to log again with expected failures
+    txnMetadata1.pendingState = None
     val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
 
     // test COORDINATOR_NOT_AVAILABLE cases
@@ -215,37 +235,37 @@ class TransactionStateManagerTest {
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.NOT_COORDINATOR
 
     prepareForTxnMessageAppend(Errors.NOT_LEADER_FOR_PARTITION)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     // test UNKNOWN cases
     expectedError = Errors.UNKNOWN
 
     prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getAndMaybeAddTransactionState(transactionalId1))
   }
 
   @Test
@@ -253,7 +273,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnMetadata1)
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata1))
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NOT_COORDINATOR
@@ -271,7 +291,8 @@ class TransactionStateManagerTest {
   @Test(expected = classOf[IllegalStateException])
   def testAppendTransactionToLogWhilePendingStateChanged() = {
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnMetadata1)
+    transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata1))
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.INVALID_PRODUCER_EPOCH
@@ -287,8 +308,11 @@ class TransactionStateManagerTest {
   }
 
   @Test
-  def shouldReturnNoneIfTransactionIdPartitionNotOwned(): Unit = {
-    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
+  def shouldReturnNotCooridnatorErrorIfTransactionIdPartitionNotOwned(): Unit = {
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1).fold(
+      err => assertEquals(Errors.NOT_COORDINATOR, err),
+      _ => fail(transactionalId1 + "'s transaction state is already in the cache")
+    )
   }
 
   @Test
@@ -297,12 +321,12 @@ class TransactionStateManagerTest {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
     }
 
-    transactionManager.addTransaction(transactionMetadata("ongoing", producerId = 0, state = Ongoing))
-    transactionManager.addTransaction(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000))
-    transactionManager.addTransaction(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit))
-    transactionManager.addTransaction(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort))
-    transactionManager.addTransaction(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit))
-    transactionManager.addTransaction(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort))
+    transactionManager.getAndMaybeAddTransactionState("ongoing", Some(transactionMetadata("ongoing", producerId = 0, state = Ongoing)))
+    transactionManager.getAndMaybeAddTransactionState("not-expiring", Some(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000)))
+    transactionManager.getAndMaybeAddTransactionState("prepare-commit", Some(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit)))
+    transactionManager.getAndMaybeAddTransactionState("prepare-abort", Some(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort)))
+    transactionManager.getAndMaybeAddTransactionState("complete-commit", Some(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit)))
+    transactionManager.getAndMaybeAddTransactionState("complete-abort", Some(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort)))
 
     time.sleep(2000)
     val expiring = transactionManager.transactionsToExpire()
@@ -388,15 +412,16 @@ class TransactionStateManagerTest {
       internalTopicsAllowed = EasyMock.eq(true),
       isFromClient = EasyMock.eq(false),
       EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]],
-      EasyMock.capture(capturedArgument)))
-      .andAnswer(new IAnswer[Unit] {
+      EasyMock.capture(capturedArgument),
+      EasyMock.anyObject())
+    ).andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = capturedArgument.getValue.apply(
           Map(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) ->
             new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP)
           )
         )
       }
-      )
+    )
     EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject()))
       .andStubReturn(Some(RecordBatch.MAGIC_VALUE_V1))
 


Mime
View raw message