kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [1/3] kafka git commit: KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization
Date Fri, 12 May 2017 22:01:05 GMT
Repository: kafka
Updated Branches:
  refs/heads/trunk 7baa58d79 -> 794e6dbd1


http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index 29240a6..d02e072 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -16,17 +16,12 @@
  */
 package kafka.coordinator.transaction
 
-import kafka.api.{LeaderAndIsr, PartitionStateInfo}
-import kafka.common.{BrokerEndPointNotAvailableException, InterBrokerSendThread}
-import kafka.controller.LeaderIsrAndControllerEpoch
 import kafka.server.{DelayedOperationPurgatory, KafkaConfig, MetadataCache}
-import kafka.utils.{MockTime, TestUtils}
 import kafka.utils.timer.MockTimer
+import kafka.utils.TestUtils
 import org.apache.kafka.clients.NetworkClient
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.protocol.{Errors, SecurityProtocol}
 import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
-import org.apache.kafka.common.utils.Utils
+import org.apache.kafka.common.utils.{MockTime, Utils}
 import org.apache.kafka.common.{Node, TopicPartition}
 import org.easymock.EasyMock
 import org.junit.Assert._
@@ -36,241 +31,182 @@ import scala.collection.mutable
 
 class TransactionMarkerChannelManagerTest {
   private val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
-  private val interBrokerSendThread = EasyMock.createNiceMock(classOf[InterBrokerSendThread])
   private val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
-  private val channel = new TransactionMarkerChannel(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT),
-    metadataCache,
-    networkClient,
-    new MockTime())
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
-    new MockTimer,
-    reaperEnabled = false)
-  private val requestGenerator = TransactionMarkerChannelManager.requestGenerator(channel, purgatory)
+  private val txnStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
+
   private val partition1 = new TopicPartition("topic1", 0)
   private val partition2 = new TopicPartition("topic1", 1)
   private val broker1 = new Node(1, "host", 10)
   private val broker2 = new Node(2, "otherhost", 10)
-  private val metadataPartition = 0
+
+  private val transactionalId1 = "txnId1"
+  private val transactionalId2 = "txnId2"
+  private val transactionalId3 = "txnId3"
+  private val producerId1 = 0.asInstanceOf[Long]
+  private val producerId2 = 1.asInstanceOf[Long]
+  private val producerId3 = 1.asInstanceOf[Long]
+  private val producerEpoch = 0.asInstanceOf[Short]
+  private val txnTopicPartition1 = 0
+  private val txnTopicPartition2 = 1
+  private val coordinatorEpoch = 0
+  private val txnTimeoutMs = 0
+  private val txnResult = TransactionResult.COMMIT
+
+  private val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
+    new MockTimer,
+    reaperEnabled = false)
+  private val time = new MockTime
+
   private val channelManager = new TransactionMarkerChannelManager(
     KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")),
     metadataCache,
-    purgatory,
-    interBrokerSendThread,
-    channel)
+    networkClient,
+    txnStateManager,
+    txnMarkerPurgatory,
+    time)
+
+  private val senderThread = channelManager.senderThread
+
+  private def mockCache(): Unit = {
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId1))
+      .andReturn(txnTopicPartition1)
+      .anyTimes()
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId2))
+      .andReturn(txnTopicPartition2)
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
+  }
 
   @Test
   def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = {
-    assertTrue(requestGenerator().isEmpty)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
-  def shouldGenerateRequestPerBroker(): Unit ={
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+  def shouldGenerateRequestPerBroker(): Unit = {
+    mockCache()
 
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker1))
+      .anyTimes()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(2), EasyMock.anyObject())).andReturn(Some(broker2))
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition2.topic),
+      EasyMock.eq(partition2.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker2))
+      .anyTimes()
 
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1, partition2))
+    val txnMetadata = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
 
+    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1))
 
     val expectedBroker1Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
     val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
 
-    val requests: Map[Node, WriteTxnMarkersRequest] = requestGenerator().map{ result =>
-      (result.destination, result.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
+    val requests: Map[Node, WriteTxnMarkersRequest] = senderThread.generateRequests().map { handler =>
+      (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
     }.toMap
 
-    val broker1Request = requests(broker1)
-    val broker2Request = requests(broker2)
-
-    assertEquals(2, requests.size)
-    assertEquals(expectedBroker1Request, broker1Request)
-    assertEquals(expectedBroker2Request, broker2Request)
-
+    assertEquals(Map(broker1 -> expectedBroker1Request, broker2 -> expectedBroker2Request), requests)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
-  def shouldGenerateRequestPerPartitionPerBroker(): Unit ={
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+  def shouldGenerateRequestPerPartitionPerBroker(): Unit = {
+    mockCache()
 
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker1))
+      .anyTimes()
 
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1)).anyTimes()
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-    channel.addRequestToSend(1, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition2))
+    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
+    channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    val expectedPartition1Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
-    val expectedPartition2Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
+    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
 
-    val requests = requestGenerator().map { result =>
-      val markersRequest = result.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()
-      (result.destination, markersRequest)
-    }.toList
-
-    assertEquals(List((broker1, expectedPartition1Request), (broker1, expectedPartition2Request)), requests)
-  }
-
-  @Test
-  def shouldDrainBrokerQueueWhenGeneratingRequests(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
+    val expectedBroker1Request = new WriteTxnMarkersRequest.Builder(
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)),
+        new WriteTxnMarkersRequest.TxnMarkerEntry(producerId2, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    val requests: Map[Node, WriteTxnMarkersRequest] = senderThread.generateRequests().map { handler =>
+      (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
+    }.toMap
 
-    val result = requestGenerator()
-    assertTrue(result.nonEmpty)
-    val result2 = requestGenerator()
-    assertTrue(result2.isEmpty)
+    assertEquals(Map(broker1 -> expectedBroker1Request), requests)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
   def shouldRetryGettingLeaderWhenNotFound(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(None)
-      .andReturn(None)
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+    mockCache()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject())
+    ).andReturn(None)
+     .andReturn(None)
+     .andReturn(Some(broker1))
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldRetryGettingLeaderWhenBrokerEndPointNotAvailableException(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-      .times(2)
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject()))
-      .andThrow(new BrokerEndPointNotAvailableException())
-      .andReturn(Some(broker1))
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    channelManager.addTxnMarkersToBrokerQueue(transactionalId1, producerId1, producerEpoch, TransactionResult.COMMIT, coordinatorEpoch, Set[TopicPartition](partition1))
 
     EasyMock.verify(metadataCache)
   }
 
   @Test
-  def shouldRetryGettingLeaderWhenLeaderDoesntExist(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-      .times(2)
+  def shouldRemoveMarkersForTxnPartitionWhenPartitionEmigrated(): Unit = {
+    mockCache()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject()))
-      .andReturn(None)
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
       .andReturn(Some(broker1))
+      .anyTimes()
 
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
 
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldAddPendingTxnRequest(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0, 0L)
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(2), EasyMock.anyObject())).andReturn(Some(broker2))
-
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
+    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    assertEquals(Some(metadata), channel.pendingTxnMetadata(metadataPartition, 1))
-
-  }
+    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
 
-  @Test
-  def shouldAddRequestToBrokerQueue(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0L)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
-    assertEquals(1, requestGenerator().size)
-  }
-
-  @Test
-  def shouldAddDelayedTxnMarkerToPurgatory(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0L)
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
-    assertEquals(1,purgatory.watched)
-  }
-
-  @Test
-  def shouldStartInterBrokerThreadOnStartup(): Unit = {
-    EasyMock.expect(interBrokerSendThread.start())
-    EasyMock.replay(interBrokerSendThread)
-    channelManager.start()
-    EasyMock.verify(interBrokerSendThread)
-  }
-
-
-  @Test
-  def shouldStopInterBrokerThreadOnShutdown(): Unit = {
-    EasyMock.expect(interBrokerSendThread.shutdown())
-    EasyMock.replay(interBrokerSendThread)
-    channelManager.shutdown()
-    EasyMock.verify(interBrokerSendThread)
-  }
-
-  @Test
-  def shouldClearPurgatoryForPartitionWhenPartitionEmigrated(): Unit = {
-    val metadata1 = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata1, (error:Errors) => {}),Seq(0L))
-    channel.maybeAddPendingRequest(0, metadata1)
-
-    val metadata2 = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata2, (error:Errors) => {}),Seq(1L))
-    channel.maybeAddPendingRequest(0, metadata2)
-
-    val metadata3 = new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata3, (error:Errors) => {}),Seq(2L))
-    channel.maybeAddPendingRequest(1, metadata3)
-
-    channelManager.removeStateForPartition(0)
-
-    assertEquals(1, purgatory.watched)
-    // should not complete as they've been removed
-    purgatory.checkAndComplete(0L)
-    purgatory.checkAndComplete(1L)
-    
-    assertEquals(1, purgatory.watched)
-  }
+    channelManager.removeMarkersForTxnTopicPartition(txnTopicPartition1)
 
-  def completionCallback(errors: Errors): Unit = {
+    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
   }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
deleted file mode 100644
index 89a7606..0000000
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package kafka.coordinator.transaction
-
-import kafka.api.{LeaderAndIsr, PartitionStateInfo}
-import kafka.controller.LeaderIsrAndControllerEpoch
-import kafka.server.{DelayedOperationPurgatory, MetadataCache}
-import kafka.utils.MockTime
-import kafka.utils.timer.MockTimer
-import org.apache.kafka.clients.NetworkClient
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.protocol.{Errors, SecurityProtocol}
-import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
-import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.common.{Node, TopicPartition}
-import org.easymock.EasyMock
-import org.junit.Assert._
-import org.junit.Test
-
-import scala.collection.mutable
-
-class TransactionMarkerChannelTest {
-
-  private val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
-  private val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("name", new MockTimer, reaperEnabled = false)
-  private val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
-  private val channel = new TransactionMarkerChannel(listenerName, metadataCache, networkClient, new MockTime())
-  private val partition1 = new TopicPartition("topic1", 0)
-
-
-  @Test
-  def shouldAddEmptyBrokerQueueWhenAddingNewBroker(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addOrUpdateBroker(new Node(2, "host", 10))
-    assertEquals(0, channel.queueForBroker(1).get.eachMetadataPartition{case(partition:Int, _) => partition}.size)
-    assertEquals(0, channel.queueForBroker(2).get.eachMetadataPartition{case(partition:Int, _) => partition}.size)
-  }
-
-  @Test
-  def shouldUpdateDestinationBrokerNodeWhenUpdatingBroker(): Unit = {
-    val newDestination = new Node(1, "otherhost", 100)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-
-    // getAliveEndpoint returns an updated node
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(newDestination))
-    EasyMock.replay(metadataCache)
-
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    val brokerRequestQueue = channel.queueForBroker(1).get
-    assertEquals(newDestination, brokerRequestQueue.node)
-    assertEquals(1, brokerRequestQueue.totalQueuedRequests())
-  }
-
-
-  @Test
-  def shouldQueueRequestsByBrokerId(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addOrUpdateBroker(new Node(2, "otherhost", 10))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(2, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-
-    assertEquals(2, channel.queueForBroker(1).get.totalQueuedRequests())
-    assertEquals(1, channel.queueForBroker(2).get.totalQueuedRequests())
-  }
-
-  @Test
-  def shouldNotAddPendingTxnIfOneAlreadyExistsForPid(): Unit = {
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    assertFalse(channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)))
-  }
-
-  @Test
-  def shouldAddRequestsToCorrectBrokerQueues(): Unit = {
-    val partition2 = new TopicPartition("topic1", 1)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
-
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(2, listenerName)).andReturn(Some(new Node(2, "otherhost", 10)))
-
-    EasyMock.replay(metadataCache)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1, partition2))
-
-    assertEquals(1, channel.queueForBroker(1).get.totalQueuedRequests())
-    assertEquals(1, channel.queueForBroker(2).get.totalQueuedRequests())
-  }
-  @Test
-  def shouldWakeupNetworkClientWhenRequestsQueued(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-
-    EasyMock.expect(networkClient.wakeup())
-
-    EasyMock.replay(metadataCache, networkClient)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    EasyMock.verify(networkClient)
-  }
-
-  @Test
-  def shouldAddNewBrokerQueueIfDoesntAlreadyExistWhenAddingRequest(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-
-    EasyMock.replay(metadataCache)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    assertEquals(1, channel.queueForBroker(1).get.totalQueuedRequests())
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldGetPendingTxnMetadataByPid(): Unit = {
-    val metadataPartition = 0
-    val transaction = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)
-    channel.maybeAddPendingRequest(metadataPartition, transaction)
-    channel.maybeAddPendingRequest(metadataPartition, new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    assertEquals(Some(transaction), channel.pendingTxnMetadata(metadataPartition, 1))
-  }
-
-  @Test
-  def shouldRemovePendingRequestsForPartitionWhenPartitionEmigrated(): Unit = {
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    val metadata = new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)
-    channel.maybeAddPendingRequest(1, metadata)
-
-    channel.removeStateForPartition(0)
-
-    assertEquals(None, channel.pendingTxnMetadata(0, 0))
-    assertEquals(None, channel.pendingTxnMetadata(0, 1))
-    assertEquals(Some(metadata), channel.pendingTxnMetadata(1, 2))
-  }
-
-  @Test
-  def shouldRemoveBrokerRequestsForPartitionWhenPartitionEmigrated(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 1, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 1, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-
-    channel.removeStateForPartition(1)
-
-
-    val result = channel.queueForBroker(1).get.eachMetadataPartition{case (partition:Int, _) => partition}.toList
-    assertEquals(List(0), result)
-  }
-
-
-
-  def errorCallback(error: Errors): Unit = {}
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
index 096b826..082d441 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
@@ -18,14 +18,12 @@ package kafka.coordinator.transaction
 
 import java.{lang, util}
 
-import kafka.server.DelayedOperationPurgatory
-import kafka.utils.timer.MockTimer
 import org.apache.kafka.clients.ClientResponse
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
 import org.apache.kafka.common.utils.Utils
-import org.easymock.EasyMock
+import org.easymock.{EasyMock, IAnswer}
 import org.junit.Assert._
 import org.junit.Test
 
@@ -33,71 +31,129 @@ import scala.collection.mutable
 
 class TransactionMarkerRequestCompletionHandlerTest {
 
-  private val markerChannel = EasyMock.createNiceMock(classOf[TransactionMarkerChannel])
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name", new MockTimer, reaperEnabled = false)
-  private val topic1 = new TopicPartition("topic1", 0)
-  private val txnMarkers =
+  private val brokerId = 0
+  private val txnTopicPartition = 0
+  private val transactionalId = "txnId1"
+  private val producerId = 0.asInstanceOf[Long]
+  private val producerEpoch = 0.asInstanceOf[Short]
+  private val txnTimeoutMs = 0
+  private val coordinatorEpoch = 0
+  private val txnResult = TransactionResult.COMMIT
+  private val topicPartition = new TopicPartition("topic1", 0)
+  private val txnIdAndMarkers =
     Utils.mkList(
-      new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(topic1)))
+      TxnIdAndMarkerEntry(transactionalId, new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, Utils.mkList(topicPartition))))
 
-  private val handler = new TransactionMarkerRequestCompletionHandler(markerChannel, purgatory, 0, txnMarkers, 0)
+  private val txnMetadata = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L)
+
+  private val markerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager])
+
+  private val txnStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
+
+  private val handler = new TransactionMarkerRequestCompletionHandler(brokerId, txnStateManager, markerChannelManager, txnIdAndMarkers)
+
+  private def mockCache(): Unit = {
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId))
+      .andReturn(txnTopicPartition)
+      .anyTimes()
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
+  }
 
   @Test
   def shouldReEnqueuePartitionsWhenBrokerDisconnected(): Unit = {
-    EasyMock.expect(markerChannel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](topic1)))
-    EasyMock.replay(markerChannel)
+    mockCache()
+
+    EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId,
+      producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)))
+    EasyMock.replay(markerChannelManager)
 
     handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, true, null, null))
 
-    EasyMock.verify(markerChannel)
+    EasyMock.verify(markerChannelManager)
   }
 
   @Test
-  def shouldThrowIllegalStateExceptionIfErrorsNullForPid(): Unit = {
-    val response = new WriteTxnMarkersResponse(new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]())
+  def shouldThrowIllegalStateExceptionIfErrorCodeNotAvailableForPid(): Unit = {
+    mockCache()
+    EasyMock.replay(markerChannelManager)
 
-    EasyMock.replay(markerChannel)
+    val response = new WriteTxnMarkersResponse(new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]())
 
     try {
       handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
       fail("should have thrown illegal argument exception")
     } catch {
-      case ise: IllegalStateException => // ok
+      case _: IllegalStateException => // ok
     }
   }
 
   @Test
-  def shouldRemoveCompletedPartitionsFromMetadataWhenNoErrors(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
+  def shouldCompleteDelayedOperationWhenNoErrors(): Unit = {
+    mockCache()
 
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    EasyMock.expect(markerChannel.pendingTxnMetadata(0, 0))
-      .andReturn(Some(metadata))
-    EasyMock.replay(markerChannel)
+    verifyCompleteDelayedOperationOnError(Errors.NONE)
+  }
 
-    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+  @Test
+  def shouldCompleteDelayedOperationWhenNoMetadata(): Unit = {
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(None)
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
 
-    assertTrue(metadata.topicPartitions.isEmpty)
+    verifyRemoveDelayedOperationOnError(Errors.NONE)
   }
 
   @Test
-  def shouldTryCompleteDelayedTxnOperation(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
+  def shouldCompleteDelayedOperationWhenCoordinatorEpochChanged(): Unit = {
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata)))
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
 
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    var completed = false
+    verifyRemoveDelayedOperationOnError(Errors.NONE)
+  }
 
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata, (errors:Errors) => {
-      completed = true
-    }), Seq(0L))
+  @Test
+  def shouldCompleteDelayedOperationWhenInvalidProducerEpoch(): Unit = {
+    mockCache()
 
-    EasyMock.expect(markerChannel.pendingTxnMetadata(0, 0))
-      .andReturn(Some(metadata))
+    verifyRemoveDelayedOperationOnError(Errors.INVALID_PRODUCER_EPOCH)
+  }
 
-    EasyMock.replay(markerChannel)
+  @Test
+  def shouldCompleteDelayedOperationWheCoordinatorEpochFenced(): Unit = {
+    mockCache()
 
-    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
-    assertTrue(completed)
+    verifyRemoveDelayedOperationOnError(Errors.TRANSACTION_COORDINATOR_FENCED)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenUnknownError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.UNKNOWN)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenCorruptMessageError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.CORRUPT_MESSAGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenMessageTooLargeError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.MESSAGE_TOO_LARGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenRecordListTooLargeError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.RECORD_LIST_TOO_LARGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenInvalidRequiredAcksError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.INVALID_REQUIRED_ACKS)
   }
 
   @Test
@@ -120,40 +176,75 @@ class TransactionMarkerRequestCompletionHandlerTest {
     verifyRetriesPartitionOnError(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
   }
 
-  @Test
-  def shouldThrowIllegalStateExceptionWhenErrorNotHandled(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.UNKNOWN))
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    EasyMock.replay(markerChannel)
+  private def verifyRetriesPartitionOnError(error: Errors) = {
+    mockCache()
+
+    EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId,
+      producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)))
+    EasyMock.replay(markerChannelManager)
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
+    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+
+    assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition))
+    EasyMock.verify(markerChannelManager)
+  }
 
+  private def verifyThrowIllegalStateExceptionOnError(error: Errors) = {
+    mockCache()
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
     try {
       handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
       fail("should have thrown illegal state exception")
     } catch {
-      case ise: IllegalStateException => // ol
+      case _: IllegalStateException => // ok
     }
+  }
+
+  private def verifyCompleteDelayedOperationOnError(error: Errors): Unit = {
 
+    var completed = false
+    EasyMock.expect(markerChannelManager.completeSendMarkersForTxnId(transactionalId))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          completed = true
+        }
+      })
+      .once()
+    EasyMock.replay(markerChannelManager)
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
+    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+
+    assertTrue(txnMetadata.topicPartitions.isEmpty)
+    assertTrue(completed)
   }
 
-  private def verifyRetriesPartitionOnError(errors: Errors) = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.UNKNOWN_TOPIC_OR_PARTITION))
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
+  private def verifyRemoveDelayedOperationOnError(error: Errors): Unit = {
 
-    EasyMock.expect(markerChannel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](topic1)))
-    EasyMock.replay(markerChannel)
+    var removed = false
+    EasyMock.expect(markerChannelManager.removeMarkersForTxnId(transactionalId))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          removed = true
+        }
+      })
+      .once()
+    EasyMock.replay(markerChannelManager)
 
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
     handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
 
-    assertEquals(metadata.topicPartitions, mutable.Set[TopicPartition](topic1))
-    EasyMock.verify(markerChannel)
+    assertTrue(removed)
   }
 
+
   private def createPidErrorMap(errors: Errors) = {
     val pidMap = new java.util.HashMap[lang.Long, util.Map[TopicPartition, Errors]]()
     val errorsMap = new util.HashMap[TopicPartition, Errors]()
-    errorsMap.put(topic1, errors)
-    pidMap.put(0L, errorsMap)
+    errorsMap.put(topicPartition, errors)
+    pidMap.put(producerId, errorsMap)
     pidMap
   }
-
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 2a14898..0250f60 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -22,13 +22,14 @@ import kafka.common.Topic
 import kafka.common.Topic.TransactionStateTopicName
 import kafka.log.Log
 import kafka.server.{FetchDataInfo, LogOffsetMetadata, ReplicaManager}
-import kafka.utils.{MockScheduler, ZkUtils}
+import kafka.utils.{MockScheduler, Pool, ZkUtils}
 import kafka.utils.TestUtils.fail
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.MockTime
 import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.junit.{After, Before, Test}
@@ -44,6 +45,7 @@ class TransactionStateManagerTest {
   val numPartitions = 2
   val transactionTimeoutMs: Int = 1000
   val topicPartition = new TopicPartition(TransactionStateTopicName, partitionId)
+  val coordinatorEpoch = 10
 
   val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]()
 
@@ -95,10 +97,12 @@ class TransactionStateManagerTest {
 
   @Test
   def testAddGetPids() {
+    transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
     assertEquals(None, transactionManager.getTransactionState(txnId1))
-    assertEquals(txnMetadata1, transactionManager.addTransaction(txnId1, txnMetadata1))
-    assertEquals(Some(txnMetadata1), transactionManager.getTransactionState(txnId1))
-    assertEquals(txnMetadata1, transactionManager.addTransaction(txnId1, txnMetadata2))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata2))
   }
 
   @Test
@@ -110,19 +114,19 @@ class TransactionStateManagerTest {
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
       new TopicPartition("topic1", 1)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid1's transaction adds three more partitions
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0),
       new TopicPartition("topic2", 1),
       new TopicPartition("topic2", 2)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid1's transaction is preparing to commit
     txnMetadata1.state = PrepareCommit
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid2's transaction started with three partitions
     txnMetadata2.state = Ongoing
@@ -130,23 +134,23 @@ class TransactionStateManagerTest {
       new TopicPartition("topic3", 1),
       new TopicPartition("topic3", 2)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's transaction is preparing to abort
     txnMetadata2.state = PrepareAbort
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's transaction has aborted
     txnMetadata2.state = CompleteAbort
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's epoch has advanced, with no ongoing transaction yet
     txnMetadata2.state = Empty
     txnMetadata2.topicPartitions.clear()
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     val startOffset = 15L   // it should work for any start offset
     val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords: _*)
@@ -157,7 +161,7 @@ class TransactionStateManagerTest {
     assertFalse(transactionManager.isCoordinatorFor(txnId1))
     assertFalse(transactionManager.isCoordinatorFor(txnId2))
 
-    transactionManager.loadTransactionsForPartition(partitionId, 0, _ => ())
+    transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _, _) => ())
 
     // let the time advance to trigger the background thread loading
     scheduler.tick()
@@ -166,14 +170,14 @@ class TransactionStateManagerTest {
     val cachedPidMetadata2 = transactionManager.getTransactionState(txnId2).getOrElse(fail(txnId2 + "'s transaction state was not loaded into the cache"))
 
     // they should be equal to the latest status of the transaction
-    assertEquals(txnMetadata1, cachedPidMetadata1)
-    assertEquals(txnMetadata2, cachedPidMetadata2)
+    assertEquals(txnMetadata1, cachedPidMetadata1.transactionMetadata)
+    assertEquals(txnMetadata2, cachedPidMetadata2.transactionMetadata)
 
     // this partition should now be part of the owned partitions
     assertTrue(transactionManager.isCoordinatorFor(txnId1))
     assertTrue(transactionManager.isCoordinatorFor(txnId2))
 
-    transactionManager.removeTransactionsForPartition(partitionId)
+    transactionManager.removeTransactionsForTxnTopicPartition(partitionId)
 
     // let the time advance to trigger the background thread removing
     scheduler.tick()
@@ -187,6 +191,8 @@ class TransactionStateManagerTest {
 
   @Test
   def testAppendTransactionToLog() {
+    transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
     // first insert the initial transaction metadata
     transactionManager.addTransaction(txnId1, txnMetadata1)
 
@@ -194,78 +200,73 @@ class TransactionStateManagerTest {
     expectedError = Errors.NONE
 
     // update the metadata to ongoing with two partitions
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch, newMetadata, assertCallback)
 
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // append to log again with expected failures
-    val failedMetadata = newMetadata.copy()
-    failedMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)))
+    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
 
     // test COORDINATOR_NOT_AVAILABLE cases
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.NOT_COORDINATOR
 
     prepareForTxnMessageAppend(Errors.NOT_LEADER_FOR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.UNKNOWN
 
     prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
   }
 
-  @Test(expected = classOf[IllegalStateException])
+  @Test
   def testAppendTransactionToLogWhileProducerFenced() = {
+    transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
+
     // first insert the initial transaction metadata
     transactionManager.addTransaction(txnId1, txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
-    expectedError = Errors.INVALID_PRODUCER_EPOCH
+    expectedError = Errors.NOT_COORDINATOR
 
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test(expected = classOf[IllegalStateException])
@@ -276,38 +277,29 @@ class TransactionStateManagerTest {
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.INVALID_PRODUCER_EPOCH
 
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.pendingState = None
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
-  }
-
-  @Test
-  def shouldReturnEpochForTransactionId(): Unit = {
-    val coordinatorEpoch = 10
-    EasyMock.expect(replicaManager.getLog(EasyMock.anyObject(classOf[TopicPartition]))).andReturn(None)
-    EasyMock.replay(replicaManager)
-    transactionManager.loadTransactionsForPartition(partitionId, coordinatorEpoch, _ => ())
-    val epoch = transactionManager.coordinatorEpochFor(txnId1).get
-    assertEquals(coordinatorEpoch, epoch)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test
   def shouldReturnNoneIfTransactionIdPartitionNotOwned(): Unit = {
-    assertEquals(None, transactionManager.coordinatorEpochFor(txnId1))
+    assertEquals(None, transactionManager.getTransactionState(txnId1))
   }
 
   @Test
   def shouldOnlyConsiderTransactionsInTheOngoingStateForExpiry(): Unit = {
+    for (partitionId <- 0 until numPartitions) {
+      transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
+    }
+
     txnMetadata1.state = Ongoing
-    txnMetadata1.transactionStartTime = time.milliseconds()
+    txnMetadata1.txnStartTimestamp = time.milliseconds()
     transactionManager.addTransaction(txnId1, txnMetadata1)
     transactionManager.addTransaction(txnId2, txnMetadata2)
 
@@ -333,7 +325,7 @@ class TransactionStateManagerTest {
 
     time.sleep(2000)
     val expiring = transactionManager.transactionsToExpire()
-    assertEquals(List(TransactionalIdAndMetadata(txnId1, txnMetadata1)), expiring)
+    assertEquals(List(TransactionalIdAndProducerIdEpoch(txnId1, txnMetadata1.producerId, txnMetadata1.producerEpoch)), expiring)
   }
 
   @Test
@@ -351,17 +343,25 @@ class TransactionStateManagerTest {
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
       new TopicPartition("topic1", 1)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
     val startOffset = 0L
     val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords: _*)
 
     prepareTxnLog(topicPartition, 0, records)
 
-    var receivedArgs: WriteTxnMarkerArgs = null
-    transactionManager.loadTransactionsForPartition(partitionId, 0, markerArgs => receivedArgs = markerArgs)
+    var txnId: String = null
+    def rememberTxnMarkers(transactionalId: String,
+                           coordinatorEpoch: Int,
+                           command: TransactionResult,
+                           metadata: TransactionMetadata,
+                           newMetadata: TransactionMetadataTransition): Unit = {
+      txnId = transactionalId
+    }
+
+    transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, rememberTxnMarkers)
     scheduler.tick()
 
-    assertEquals(txnId1, receivedArgs.transactionalId)
+    assertEquals(txnId1, txnId)
   }
 
   private def assertCallback(error: Errors): Unit = {
@@ -414,5 +414,4 @@ class TransactionStateManagerTest {
 
     EasyMock.replay(replicaManager)
   }
-
 }


Mime
View raw message