kafka-commits mailing list archives

From: j...@apache.org
Subject: [kafka] branch trunk updated: KAFKA-6096: Add multi-threaded tests for group coordinator, txn manager (#4122)
Date: Tue, 09 Jan 2018 00:15:38 GMT
This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 6396d01  KAFKA-6096: Add multi-threaded tests for group coordinator, txn manager (#4122)
6396d01 is described below

commit 6396d01957ea355c21d658c3614190458229fa5b
Author: Rajini Sivaram <rajinisivaram@googlemail.com>
AuthorDate: Tue Jan 9 00:15:35 2018 +0000

    KAFKA-6096: Add multi-threaded tests for group coordinator, txn manager (#4122)
    
    Reviewers: Jason Gustafson <jason@confluent.io>
---
 .../AbstractCoordinatorConcurrencyTest.scala       | 226 ++++++++++++
 .../group/GroupCoordinatorConcurrencyTest.scala    | 310 ++++++++++++++++
 .../TransactionCoordinatorConcurrencyTest.scala    | 388 +++++++++++++++++++++
 .../scala/unit/kafka/utils/timer/MockTimer.scala   |  31 +-
 4 files changed, 946 insertions(+), 9 deletions(-)

diff --git a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
new file mode 100644
index 0000000..0ecc3f5
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -0,0 +1,226 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.coordinator
+
+import java.util.{ Collections, Random }
+import java.util.concurrent.{ ConcurrentHashMap, Executors }
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.locks.Lock
+
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.log.Log
+import kafka.server._
+import kafka.utils._
+import kafka.utils.timer.MockTimer
+import kafka.zk.KafkaZkClient
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.record.{ MemoryRecords, RecordBatch, RecordsProcessingStats }
+import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.easymock.EasyMock
+import org.junit.{ After, Before }
+
+import scala.collection._
+import scala.collection.JavaConverters._
+
+abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember] {
+
+  val nThreads = 5
+
+  val time = new MockTime
+  val timer = new MockTimer
+  val executor = Executors.newFixedThreadPool(nThreads)
+  val scheduler = new MockScheduler(time)
+  var replicaManager: TestReplicaManager = _
+  var zkClient: KafkaZkClient = _
+  val serverProps = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "")
+  val random = new Random
+
+  @Before
+  def setUp() {
+
+    replicaManager = EasyMock.partialMockBuilder(classOf[TestReplicaManager]).createMock()
+    replicaManager.createDelayedProducePurgatory(timer)
+
+    zkClient = EasyMock.createNiceMock(classOf[KafkaZkClient])
+  }
+
+  @After
+  def tearDown() {
+    EasyMock.reset(replicaManager)
+    if (executor != null)
+      executor.shutdownNow()
+  }
+
+  /**
+    * Verify that concurrent operations run in the normal sequence produce the expected results.
+    */
+  def verifyConcurrentOperations(createMembers: String => Set[M], operations: Seq[Operation]) {
+    OrderedOperationSequence(createMembers("verifyConcurrentOperations"), operations).run()
+  }
+
+  /**
+    * Verify that arbitrary operations run in some random sequence don't leave the coordinator
+    * in a bad state. Operations in the normal sequence should continue to work as expected.
+    */
+  def verifyConcurrentRandomSequences(createMembers: String => Set[M], operations: Seq[Operation]) {
+    EasyMock.reset(replicaManager)
+    for (i <- 0 to 10) {
+      // Run some random operations
+      RandomOperationSequence(createMembers(s"random$i"), operations).run()
+
+      // Check that proper sequences still work correctly
+      OrderedOperationSequence(createMembers(s"ordered$i"), operations).run()
+    }
+  }
+
+  def verifyConcurrentActions(actions: Set[Action]) {
+    val futures = actions.map(executor.submit)
+    futures.map(_.get)
+    enableCompletion()
+    actions.foreach(_.await())
+  }
+
+  def enableCompletion(): Unit = {
+    replicaManager.tryCompleteDelayedRequests()
+    scheduler.tick()
+  }
+
+  abstract class OperationSequence(members: Set[M], operations: Seq[Operation]) {
+    def actionSequence: Seq[Set[Action]]
+    def run(): Unit = {
+      actionSequence.foreach(verifyConcurrentActions)
+    }
+  }
+
+  case class OrderedOperationSequence(members: Set[M], operations: Seq[Operation])
+    extends OperationSequence(members, operations) {
+    override def actionSequence: Seq[Set[Action]] = {
+      operations.map { op =>
+        members.map(op.actionWithVerify)
+      }
+    }
+  }
+
+  case class RandomOperationSequence(members: Set[M], operations: Seq[Operation])
+    extends OperationSequence(members, operations) {
+    val opCount = operations.length
+    def actionSequence: Seq[Set[Action]] = {
+      (0 to opCount).map { _ =>
+        members.map { member =>
+          val op = operations(random.nextInt(opCount))
+          op.actionNoVerify(member) // Don't wait or verify since these operations may block
+        }
+      }
+    }
+  }
+
+  abstract class Operation {
+    def run(member: M): Unit
+    def awaitAndVerify(member: M): Unit
+    def actionWithVerify(member: M): Action = {
+      new Action() {
+        def run(): Unit = Operation.this.run(member)
+        def await(): Unit = awaitAndVerify(member)
+      }
+    }
+    def actionNoVerify(member: M): Action = {
+      new Action() {
+        def run(): Unit = Operation.this.run(member)
+        def await(): Unit = timer.advanceClock(100) // Don't wait since operation may block
+      }
+    }
+  }
+}
+
+object AbstractCoordinatorConcurrencyTest {
+
+  trait Action extends Runnable {
+    def await(): Unit
+  }
+
+  trait CoordinatorMember {
+  }
+
+  class TestReplicaManager extends ReplicaManager(
+    null, null, null, null, null, null, null, null, null, null, null, null, null, null, None) {
+
+    var producePurgatory: DelayedOperationPurgatory[DelayedProduce] = _
+    var watchKeys: mutable.Set[TopicPartitionOperationKey] = _
+    def createDelayedProducePurgatory(timer: MockTimer): Unit = {
+      producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer, 1, reaperEnabled = false)
+      watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey, java.lang.Boolean]()).asScala
+    }
+    def tryCompleteDelayedRequests(): Unit = {
+      watchKeys.map(producePurgatory.checkAndComplete)
+    }
+
+    override def appendRecords(timeout: Long,
+                               requiredAcks: Short,
+                               internalTopicsAllowed: Boolean,
+                               isFromClient: Boolean,
+                               entriesPerPartition: Map[TopicPartition, MemoryRecords],
+                               responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+                               delayedProduceLock: Option[Lock] = None,
+                               processingStatsCallback: Map[TopicPartition, RecordsProcessingStats] => Unit = _ => ()) {
+
+      if (entriesPerPartition.isEmpty)
+        return
+      val produceMetadata = ProduceMetadata(1, entriesPerPartition.map {
+        case (tp, _) =>
+          (tp, ProducePartitionStatus(0L, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L)))
+      })
+      val delayedProduce = new DelayedProduce(5, produceMetadata, this, responseCallback, delayedProduceLock) {
+        // Complete produce requests after a few attempts to trigger delayed produce from different threads
+        val completeAttempts = new AtomicInteger
+        override def tryComplete(): Boolean = {
+          if (completeAttempts.incrementAndGet() >= 3)
+            forceComplete()
+          else
+            false
+        }
+        override def onComplete() {
+          responseCallback(entriesPerPartition.map {
+            case (tp, _) =>
+              (tp, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))
+          })
+        }
+      }
+      val producerRequestKeys = entriesPerPartition.keys.map(new TopicPartitionOperationKey(_)).toSeq
+      watchKeys ++= producerRequestKeys
+      producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+      tryCompleteDelayedRequests()
+    }
+    override def getMagic(topicPartition: TopicPartition): Option[Byte] = {
+      Some(RecordBatch.MAGIC_VALUE_V2)
+    }
+    @volatile var logs: mutable.Map[TopicPartition, (Log, Long)] = _
+    def getOrCreateLogs(): mutable.Map[TopicPartition, (Log, Long)] = {
+      if (logs == null)
+        logs = mutable.Map[TopicPartition, (Log, Long)]()
+      logs
+    }
+    def updateLog(topicPartition: TopicPartition, log: Log, endOffset: Long): Unit = {
+      getOrCreateLogs().put(topicPartition, (log, endOffset))
+    }
+    override def getLog(topicPartition: TopicPartition): Option[Log] =
+      getOrCreateLogs().get(topicPartition).map(l => l._1)
+    override def getLogEndOffset(topicPartition: TopicPartition): Option[Long] =
+      getOrCreateLogs().get(topicPartition).map(l => l._2)
+  }
+}
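For readers skimming the diff, a minimal sketch of how a concrete suite plugs into the harness above: extend AbstractCoordinatorConcurrencyTest, define one Operation per request type, and hand a member factory plus the operation list to verifyConcurrentOperations. Only the harness types come from the file above; everything named NoOp* is illustrative and not part of this commit.

import java.util.concurrent.atomic.AtomicInteger

import kafka.coordinator.AbstractCoordinatorConcurrencyTest
import kafka.coordinator.AbstractCoordinatorConcurrencyTest.CoordinatorMember
import org.junit.Assert.assertTrue
import org.junit.Test

// Hypothetical member type used for illustration only.
class NoOpMember(val id: String) extends CoordinatorMember {
  val invocations = new AtomicInteger()
}

class NoOpCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[NoOpMember] {

  // One Operation per request type: run() issues the call, awaitAndVerify() checks its
  // outcome after all concurrent actions have been submitted.
  class NoOpOperation extends Operation {
    override def run(member: NoOpMember): Unit = member.invocations.incrementAndGet()
    override def awaitAndVerify(member: NoOpMember): Unit =
      assertTrue(member.invocations.get > 0)
  }

  def createMembers(prefix: String): Set[NoOpMember] =
    (0 until nThreads).map(i => new NoOpMember(s"$prefix$i")).toSet

  @Test
  def testConcurrentNoOps(): Unit =
    verifyConcurrentOperations(createMembers, Seq(new NoOpOperation))
}

The two coordinator tests added below follow exactly this shape, with real group and transaction requests in place of the no-op.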
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
new file mode 100644
index 0000000..44e1356
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
@@ -0,0 +1,310 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.coordinator.group
+
+import java.util.concurrent.{ ConcurrentHashMap, TimeUnit }
+
+import kafka.common.OffsetAndMetadata
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.coordinator.group.GroupCoordinatorConcurrencyTest._
+import kafka.server.{ DelayedOperationPurgatory, KafkaConfig }
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.internals.Topic
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.{ JoinGroupRequest, TransactionResult }
+import org.easymock.EasyMock
+import org.junit.Assert._
+import org.junit.{ After, Before, Test }
+
+import scala.collection._
+import scala.concurrent.duration.Duration
+import scala.concurrent.{ Await, Future, Promise, TimeoutException }
+
+class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[GroupMember] {
+
+  private val protocolType = "consumer"
+  private val metadata = Array[Byte]()
+  private val protocols = List(("range", metadata))
+
+  private val nGroups = nThreads * 10
+  private val nMembersPerGroup = nThreads * 5
+  private val numPartitions = 2
+
+  private val allOperations = Seq(
+      new JoinGroupOperation,
+      new SyncGroupOperation,
+      new CommitOffsetsOperation,
+      new HeartbeatOperation,
+      new LeaveGroupOperation
+    )
+  private val allOperationsWithTxn = Seq(
+    new JoinGroupOperation,
+    new SyncGroupOperation,
+    new CommitTxnOffsetsOperation,
+    new CompleteTxnOperation,
+    new HeartbeatOperation,
+    new LeaveGroupOperation
+  )
+
+  var groupCoordinator: GroupCoordinator = _
+
+  @Before
+  override def setUp() {
+    super.setUp()
+
+    EasyMock.expect(zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME))
+      .andReturn(Some(numPartitions))
+      .anyTimes()
+    EasyMock.replay(zkClient)
+
+    serverProps.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, ConsumerMinSessionTimeout.toString)
+    serverProps.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, ConsumerMaxSessionTimeout.toString)
+    serverProps.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, GroupInitialRebalanceDelay.toString)
+
+    val config = KafkaConfig.fromProps(serverProps)
+
+    val heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", timer, config.brokerId, reaperEnabled = false)
+    val joinPurgatory = new DelayedOperationPurgatory[DelayedJoin]("Rebalance", timer, config.brokerId, reaperEnabled = false)
+
+    groupCoordinator = GroupCoordinator(config, zkClient, replicaManager, heartbeatPurgatory, joinPurgatory, timer.time)
+    groupCoordinator.startup(false)
+  }
+
+  @After
+  override def tearDown() {
+    try {
+      if (groupCoordinator != null)
+        groupCoordinator.shutdown()
+    } finally {
+      super.tearDown()
+    }
+  }
+
+  def createGroupMembers(groupPrefix: String): Set[GroupMember] = {
+    (0 until nGroups).flatMap { i =>
+      new Group(s"$groupPrefix$i", nMembersPerGroup, groupCoordinator, replicaManager).members
+    }.toSet
+  }
+
+  @Test
+  def testConcurrentGoodPathSequence() {
+    verifyConcurrentOperations(createGroupMembers, allOperations)
+  }
+
+  @Test
+  def testConcurrentTxnGoodPathSequence() {
+    verifyConcurrentOperations(createGroupMembers, allOperationsWithTxn)
+  }
+
+  @Test
+  def testConcurrentRandomSequence() {
+    verifyConcurrentRandomSequences(createGroupMembers, allOperationsWithTxn)
+  }
+
+
+  abstract class GroupOperation[R, C] extends Operation {
+    val responseFutures = new ConcurrentHashMap[GroupMember, Future[R]]()
+
+    def setUpCallback(member: GroupMember): C = {
+      val responsePromise = Promise[R]
+      val responseFuture = responsePromise.future
+      responseFutures.put(member, responseFuture)
+      responseCallback(responsePromise)
+    }
+    def responseCallback(responsePromise: Promise[R]): C
+
+    override def run(member: GroupMember): Unit = {
+      val responseCallback = setUpCallback(member)
+      runWithCallback(member, responseCallback)
+    }
+
+    def runWithCallback(member: GroupMember, responseCallback: C): Unit
+
+    def await(member: GroupMember, timeoutMs: Long): R = {
+      var retries = (timeoutMs + 10) / 10
+      val responseFuture = responseFutures.get(member)
+      while (retries > 0) {
+        timer.advanceClock(10)
+        try {
+          return Await.result(responseFuture, Duration(10, TimeUnit.MILLISECONDS))
+        } catch {
+          case _: TimeoutException =>
+        }
+        retries -= 1
+      }
+      throw new TimeoutException(s"Operation did not complete within $timeoutMs millis")
+    }
+  }
+
+
+  class JoinGroupOperation extends GroupOperation[JoinGroupResult, JoinGroupCallback] {
+    override def responseCallback(responsePromise: Promise[JoinGroupResult]): JoinGroupCallback = {
+      val callback: JoinGroupCallback = responsePromise.success(_)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: JoinGroupCallback): Unit = {
+      groupCoordinator.handleJoinGroup(member.groupId, member.memberId, "clientId", "clientHost",
+       DefaultRebalanceTimeout, DefaultSessionTimeout,
+       protocolType, protocols, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val joinGroupResult = await(member, DefaultRebalanceTimeout)
+       assertEquals(Errors.NONE, joinGroupResult.error)
+       member.memberId = joinGroupResult.memberId
+       member.generationId = joinGroupResult.generationId
+    }
+  }
+
+  class SyncGroupOperation extends GroupOperation[SyncGroupCallbackParams, SyncGroupCallback] {
+    override def responseCallback(responsePromise: Promise[SyncGroupCallbackParams]): SyncGroupCallback = {
+      val callback: SyncGroupCallback = (assignment, error) =>
+        responsePromise.success((assignment, error))
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: SyncGroupCallback): Unit = {
+      if (member.leader) {
+        groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
+            member.group.assignment, responseCallback)
+      } else {
+         groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
+             Map.empty[String, Array[Byte]], responseCallback)
+      }
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val result = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, result._2)
+    }
+  }
+
+  class HeartbeatOperation extends GroupOperation[HeartbeatCallbackParams, HeartbeatCallback] {
+    override def responseCallback(responsePromise: Promise[HeartbeatCallbackParams]): HeartbeatCallback = {
+      val callback: HeartbeatCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback): Unit = {
+      groupCoordinator.handleHeartbeat( member.groupId, member.memberId,  member.generationId, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val error = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, error)
+    }
+  }
+  class CommitOffsetsOperation extends GroupOperation[CommitOffsetCallbackParams, CommitOffsetCallback] {
+    override def responseCallback(responsePromise: Promise[CommitOffsetCallbackParams]): CommitOffsetCallback = {
+      val callback: CommitOffsetCallback = offsets => responsePromise.success(offsets)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = {
+      val tp = new TopicPartition("topic", 0)
+      val offsets = immutable.Map(tp -> OffsetAndMetadata(1))
+      groupCoordinator.handleCommitOffsets(member.groupId, member.memberId, member.generationId,
+          offsets, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val offsets = await(member, 500)
+       offsets.foreach { case (_, error) => assertEquals(Errors.NONE, error) }
+    }
+  }
+
+  class CommitTxnOffsetsOperation extends CommitOffsetsOperation {
+    override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = {
+      val tp = new TopicPartition("topic", 0)
+      val offsets = immutable.Map(tp -> OffsetAndMetadata(1))
+      val producerId = 1000L
+      val producerEpoch : Short = 2
+      groupCoordinator.handleTxnCommitOffsets(member.group.groupId,
+          producerId, producerEpoch, offsets, responseCallback)
+    }
+  }
+
+  class CompleteTxnOperation extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] {
+    override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback = {
+      val callback: CompleteTxnCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback): Unit = {
+      val producerId = 1000L
+      val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _))
+      groupCoordinator.handleTxnCompletion(producerId, offsetsPartitions, transactionResult(member.group.groupId))
+      responseCallback(Errors.NONE)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+      val error = await(member, 500)
+      assertEquals(Errors.NONE, error)
+    }
+    // Test both commit and abort. Group ids used in the test have the format <prefix><index>
+    // Use the last digit of the index to decide between commit and abort.
+    private def transactionResult(groupId: String): TransactionResult = {
+      val lastDigit = groupId(groupId.length - 1).toInt
+      if (lastDigit % 2 == 0) TransactionResult.COMMIT else TransactionResult.ABORT
+    }
+  }
+
+  class LeaveGroupOperation extends GroupOperation[LeaveGroupCallbackParams, LeaveGroupCallback] {
+    override def responseCallback(responsePromise: Promise[LeaveGroupCallbackParams]): LeaveGroupCallback = {
+      val callback: LeaveGroupCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: LeaveGroupCallback): Unit = {
+      groupCoordinator.handleLeaveGroup(member.group.groupId, member.memberId, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val error = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, error)
+    }
+  }
+}
+
+object GroupCoordinatorConcurrencyTest {
+
+
+  type JoinGroupCallback = JoinGroupResult => Unit
+  type SyncGroupCallbackParams = (Array[Byte], Errors)
+  type SyncGroupCallback = (Array[Byte], Errors) => Unit
+  type HeartbeatCallbackParams = Errors
+  type HeartbeatCallback = Errors => Unit
+  type CommitOffsetCallbackParams = Map[TopicPartition, Errors]
+  type CommitOffsetCallback = Map[TopicPartition, Errors] => Unit
+  type LeaveGroupCallbackParams = Errors
+  type LeaveGroupCallback = Errors => Unit
+  type CompleteTxnCallbackParams = Errors
+  type CompleteTxnCallback = Errors => Unit
+
+  private val ConsumerMinSessionTimeout = 10
+  private val ConsumerMaxSessionTimeout = 120 * 1000
+  private val DefaultRebalanceTimeout = 60 * 1000
+  private val DefaultSessionTimeout = 60 * 1000
+  private val GroupInitialRebalanceDelay = 50
+
+  class Group(val groupId: String, nMembers: Int,
+      groupCoordinator: GroupCoordinator, replicaManager: TestReplicaManager) {
+    val groupPartitionId = groupCoordinator.partitionFor(groupId)
+    groupCoordinator.groupManager.addPartitionOwnership(groupPartitionId)
+    val members = (0 until nMembers).map { i =>
+      new GroupMember(this, groupPartitionId, i == 0)
+    }
+    def assignment = members.map { m => (m.memberId, Array[Byte]()) }.toMap
+  }
+
+  class GroupMember(val group: Group, val groupPartitionId: Int, val leader: Boolean) extends CoordinatorMember {
+    @volatile var memberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID
+    @volatile var generationId: Int = -1
+    def groupId: String = group.groupId
+  }
+}
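A note on the idiom the GroupOperation subclasses above share: the group coordinator API is callback based, so each operation backs its callback with a Promise and the test awaits the resulting Future while advancing the mock timer between attempts (see GroupOperation.await). Condensed into a standalone helper, it looks roughly like the sketch below; the helper and its enclosing object are illustrative and not part of the commit, while the handleHeartbeat call mirrors HeartbeatOperation above.

import kafka.coordinator.group.GroupCoordinator
import kafka.coordinator.group.GroupCoordinatorConcurrencyTest.GroupMember
import org.apache.kafka.common.protocol.Errors

import scala.concurrent.{ Future, Promise }

object CallbackToFutureSketch {
  // Adapt the callback-based heartbeat API into a Future the test can poll,
  // calling timer.advanceClock between attempts so delayed operations can fire.
  def heartbeatResult(coordinator: GroupCoordinator, member: GroupMember): Future[Errors] = {
    val promise = Promise[Errors]()
    coordinator.handleHeartbeat(member.groupId, member.memberId, member.generationId,
      error => promise.success(error))
    promise.future
  }
}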
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
new file mode 100644
index 0000000..046741a
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -0,0 +1,388 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.coordinator.transaction
+
+import java.nio.ByteBuffer
+
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
+import kafka.log.Log
+import kafka.server.{ DelayedOperationPurgatory, FetchDataInfo, KafkaConfig, LogOffsetMetadata, MetadataCache }
+import kafka.utils.timer.MockTimer
+import kafka.utils.{ Pool, TestUtils}
+
+import org.apache.kafka.clients.{ ClientResponse, NetworkClient }
+import org.apache.kafka.common.{ Node, TopicPartition }
+import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
+import org.apache.kafka.common.protocol.{ ApiKeys, Errors }
+import org.apache.kafka.common.record.{ CompressionType, FileRecords, MemoryRecords, SimpleRecord }
+import org.apache.kafka.common.requests._
+import org.apache.kafka.common.utils.{ LogContext, MockTime }
+
+import org.easymock.EasyMock
+import org.junit.Assert._
+import org.junit.{ After, Before, Test }
+
+import scala.collection.Map
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+
+class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[Transaction] {
+  private val nTransactions = nThreads * 10
+  private val coordinatorEpoch = 10
+  private val numPartitions = nThreads * 5
+
+  private val txnConfig = TransactionConfig()
+  private var transactionCoordinator: TransactionCoordinator = _
+  private var txnStateManager: TransactionStateManager = _
+  private var txnMarkerChannelManager: TransactionMarkerChannelManager = _
+
+  private val allOperations = Seq(
+      new InitProducerIdOperation,
+      new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))),
+      new EndTxnOperation)
+
+  private val allTransactions = mutable.Set[Transaction]()
+  private val txnRecordsByPartition: Map[Int, mutable.ArrayBuffer[SimpleRecord]] =
+    (0 until numPartitions).map { i => (i, mutable.ArrayBuffer[SimpleRecord]()) }.toMap
+
+  @Before
+  override def setUp() {
+    super.setUp()
+
+    EasyMock.expect(zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME))
+      .andReturn(Some(numPartitions))
+      .anyTimes()
+    EasyMock.replay(zkClient)
+
+    txnStateManager = new TransactionStateManager(0, zkClient, scheduler, replicaManager, txnConfig, time)
+    for (i <- 0 until numPartitions)
+      txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
+    val producerId = 11
+    val pidManager: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager])
+    EasyMock.expect(pidManager.generateProducerId())
+      .andReturn(producerId)
+      .anyTimes()
+    val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
+      new MockTimer,
+      reaperEnabled = false)
+    val brokerNode = new Node(0, "host", 10)
+    val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.anyString(),
+      EasyMock.anyInt(),
+      EasyMock.anyObject())
+    ).andReturn(Some(brokerNode)).anyTimes()
+    val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
+    txnMarkerChannelManager = new TransactionMarkerChannelManager(
+      KafkaConfig.fromProps(serverProps),
+      metadataCache,
+      networkClient,
+      txnStateManager,
+      txnMarkerPurgatory,
+      time) {
+        override def shutdown(): Unit = {
+          txnMarkerPurgatory.shutdown()
+        }
+    }
+
+    transactionCoordinator = new TransactionCoordinator(brokerId = 0,
+      txnConfig,
+      scheduler,
+      pidManager,
+      txnStateManager,
+      txnMarkerChannelManager,
+      time,
+      new LogContext)
+    EasyMock.replay(pidManager)
+    EasyMock.replay(metadataCache)
+    EasyMock.replay(networkClient)
+  }
+
+  @After
+  override def tearDown() {
+    try {
+      EasyMock.reset(zkClient, replicaManager)
+      transactionCoordinator.shutdown()
+    } finally {
+      super.tearDown()
+    }
+  }
+
+  @Test
+  def testConcurrentGoodPathSequence(): Unit = {
+    verifyConcurrentOperations(createTransactions, allOperations)
+  }
+
+  @Test
+  def testConcurrentRandomSequences(): Unit = {
+    verifyConcurrentRandomSequences(createTransactions, allOperations)
+  }
+
+  /**
+    * Concurrently load one set of transaction state topic partitions and unload another
+    * set of partitions. This tests partition leader changes of transaction state topic
+    * that are handled by different threads concurrently. Verifies that the metadata of
+    * unloaded partitions are removed from the transaction manager and that the transactions
+    * from the newly loaded partitions are loaded correctly.
+    */
+  @Test
+  def testConcurrentLoadUnloadPartitions(): Unit = {
+    val partitionsToLoad = (0 until numPartitions / 2).toSet
+    val partitionsToUnload = (numPartitions / 2 until numPartitions).toSet
+    verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload))
+  }
+
+  /**
+    * Concurrently load one set of transaction state topic partitions, unload a second set
+    * of partitions and expire transactions on a third set of partitions. This tests partition
+    * leader changes of transaction state topic that are handled by different threads concurrently
+    * while expiry is performed on another thread. Verifies the state of transactions on all the partitions.
+    */
+  @Test
+  def testConcurrentTransactionExpiration(): Unit = {
+    val partitionsToLoad = (0 until numPartitions / 3).toSet
+    val partitionsToUnload = (numPartitions / 3 until numPartitions * 2 / 3).toSet
+    val partitionsWithExpiringTxn = (numPartitions * 2 / 3 until numPartitions).toSet
+    val expiringTransactions = allTransactions.filter { txn =>
+      partitionsWithExpiringTxn.contains(txnStateManager.partitionFor(txn.transactionalId))
+    }.toSet
+    val expireAction = new ExpireTransactionsAction(expiringTransactions)
+    verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload) + expireAction)
+  }
+
+  override def enableCompletion(): Unit = {
+    super.enableCompletion()
+
+    def createResponse(request: WriteTxnMarkersRequest): WriteTxnMarkersResponse  = {
+      val pidErrorMap = request.markers.asScala.map { marker =>
+        (marker.producerId.asInstanceOf[java.lang.Long], marker.partitions.asScala.map { tp => (tp, Errors.NONE) }.toMap.asJava)
+      }.toMap.asJava
+      new WriteTxnMarkersResponse(pidErrorMap)
+    }
+    synchronized {
+      txnMarkerChannelManager.generateRequests().foreach { requestAndHandler =>
+        val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()
+        val response = createResponse(request)
+        requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1),
+          null, null, 0, 0, false, null, response))
+      }
+    }
+  }
+
+  /**
+    * Concurrently load `partitionsToLoad` and unload `partitionsToUnload`. Before the concurrent operations
+    * are run `partitionsToLoad` must be unloaded first since all partitions were loaded during setUp.
+    */
+  private def loadUnloadActions(partitionsToLoad: Set[Int], partitionsToUnload: Set[Int]): Set[Action] = {
+    val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentLoadUnloadPartitions$i-")).toSet
+    transactions.foreach(txn => prepareTransaction(txn))
+    val unload = partitionsToLoad.map(new UnloadTxnPartitionAction(_))
+    unload.foreach(_.run())
+    unload.foreach(_.await())
+    partitionsToLoad.map(new LoadTxnPartitionAction(_)) ++ partitionsToUnload.map(new UnloadTxnPartitionAction(_))
+  }
+
+  private def createTransactions(txnPrefix: String): Set[Transaction] = {
+    val transactions = (0 until nTransactions).map { i => new Transaction(s"$txnPrefix$i", i, time) }
+    allTransactions ++= transactions
+    transactions.toSet
+  }
+
+  private def verifyTransaction(txn: Transaction, expectedState: TransactionState): Unit = {
+    val (metadata, success) = TestUtils.computeUntilTrue({
+      enableCompletion()
+      transactionMetadata(txn)
+    })(metadata => metadata.nonEmpty && metadata.forall(m => m.state == expectedState && m.pendingState.isEmpty))
+    assertTrue(s"Invalid metadata state $metadata", success)
+  }
+
+  private def transactionMetadata(txn: Transaction): Option[TransactionMetadata] = {
+    txnStateManager.getTransactionState(txn.transactionalId) match {
+      case Left(error) =>
+        if (error == Errors.NOT_COORDINATOR)
+          None
+        else
+          throw new AssertionError(s"Unexpected transaction error $error for $txn")
+      case Right(Some(metadata)) =>
+        Some(metadata.transactionMetadata)
+      case Right(None) =>
+        None
+    }
+  }
+
+  private def prepareTransaction(txn: Transaction): Unit = {
+    val partitionId = txnStateManager.partitionFor(txn.transactionalId)
+    val txnRecords = txnRecordsByPartition(partitionId)
+    val initPidOp = new InitProducerIdOperation()
+    val addPartitionsOp = new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0)))
+      initPidOp.run(txn)
+      initPidOp.awaitAndVerify(txn)
+      addPartitionsOp.run(txn)
+      addPartitionsOp.awaitAndVerify(txn)
+
+      val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
+      txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()))
+
+      txnMetadata.state = PrepareCommit
+      txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()))
+
+      prepareTxnLog(partitionId)
+  }
+
+  private def prepareTxnLog(partitionId: Int): Unit = {
+
+    val logMock =  EasyMock.mock(classOf[Log])
+    val fileRecordsMock = EasyMock.mock(classOf[FileRecords])
+
+    val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)
+    val startOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(20L)
+    val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecordsByPartition(partitionId): _*)
+    val endOffset = startOffset + records.records.asScala.size
+
+    EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset)
+    EasyMock.expect(logMock.read(EasyMock.eq(startOffset), EasyMock.anyInt(), EasyMock.eq(None),
+      EasyMock.eq(true), EasyMock.eq(IsolationLevel.READ_UNCOMMITTED)))
+      .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
+    EasyMock.expect(fileRecordsMock.readInto(EasyMock.anyObject(classOf[ByteBuffer]), EasyMock.anyInt()))
+      .andReturn(records.buffer)
+
+    EasyMock.replay(logMock, fileRecordsMock)
+    synchronized {
+      replicaManager.updateLog(topicPartition, logMock, endOffset)
+    }
+  }
+
+  abstract class TxnOperation[R] extends Operation {
+    @volatile var result: Option[R] = None
+    def resultCallback(r: R): Unit = this.result = Some(r)
+  }
+
+  class InitProducerIdOperation extends TxnOperation[InitProducerIdResult] {
+    override def run(txn: Transaction): Unit = {
+      transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, resultCallback)
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed"))
+      assertEquals(Errors.NONE, initPidResult.error)
+      verifyTransaction(txn, Empty)
+    }
+  }
+
+  class AddPartitionsToTxnOperation(partitions: Set[TopicPartition]) extends TxnOperation[Errors] {
+    override def run(txn: Transaction): Unit = {
+      transactionMetadata(txn).foreach { txnMetadata =>
+        transactionCoordinator.handleAddPartitionsToTransaction(txn.transactionalId,
+            txnMetadata.producerId,
+            txnMetadata.producerEpoch,
+            partitions,
+            resultCallback)
+      }
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val error = result.getOrElse(throw new IllegalStateException("AddPartitionsToTransaction has not completed"))
+      assertEquals(Errors.NONE, error)
+      verifyTransaction(txn, Ongoing)
+    }
+  }
+
+  class EndTxnOperation extends TxnOperation[Errors] {
+    override def run(txn: Transaction): Unit = {
+      transactionMetadata(txn).foreach { txnMetadata =>
+        transactionCoordinator.handleEndTransaction(txn.transactionalId,
+          txnMetadata.producerId,
+          txnMetadata.producerEpoch,
+          transactionResult(txn),
+          resultCallback)
+      }
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val error = result.getOrElse(throw new IllegalStateException("EndTransaction has not completed"))
+      if (!txn.ended) {
+        txn.ended = true
+        assertEquals(Errors.NONE, error)
+        val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) CompleteCommit else CompleteAbort
+        verifyTransaction(txn, expectedState)
+      } else
+        assertEquals(Errors.INVALID_TXN_STATE, error)
+    }
+    // Test both commit and abort. Transactional ids used in the test have the format <prefix><index>
+    // Use the last digit of the index to decide between commit and abort.
+    private def transactionResult(txn: Transaction): TransactionResult = {
+      val txnId = txn.transactionalId
+      val lastDigit = txnId(txnId.length - 1).toInt
+      if (lastDigit % 2 == 0) TransactionResult.COMMIT else TransactionResult.ABORT
+    }
+  }
+
+  class LoadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
+    override def run(): Unit = {
+      transactionCoordinator.handleTxnImmigration(txnTopicPartitionId, coordinatorEpoch)
+    }
+    override def await(): Unit = {
+      allTransactions.foreach { txn =>
+        if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) {
+          verifyTransaction(txn, CompleteCommit)
+        }
+      }
+    }
+  }
+
+  class UnloadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
+    val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]()
+    override def run(): Unit = {
+      transactionCoordinator.handleTxnEmigration(txnTopicPartitionId, coordinatorEpoch)
+    }
+    override def await(): Unit = {
+      allTransactions.foreach { txn =>
+        if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId)
+          assertTrue("Transaction metadata not removed", transactionMetadata(txn).isEmpty)
+      }
+    }
+  }
+
+  class ExpireTransactionsAction(transactions: Set[Transaction]) extends Action {
+    override def run(): Unit = {
+      transactions.foreach { txn =>
+        transactionMetadata(txn).foreach { txnMetadata =>
+          txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
+        }
+      }
+      txnStateManager.enableTransactionalIdExpiration()
+      time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1)
+    }
+
+    override def await(): Unit = {
+      val (_, success) = TestUtils.computeUntilTrue({
+        replicaManager.tryCompleteDelayedRequests()
+        transactions.forall(txn => transactionMetadata(txn).isEmpty)
+      })(identity)
+      assertTrue("Transaction not expired", success)
+    }
+  }
+}
+
+object TransactionCoordinatorConcurrencyTest {
+
+  class Transaction(val transactionalId: String, producerId: Long, time: MockTime) extends CoordinatorMember {
+    val txnMessageKeyBytes: Array[Byte] = TransactionLog.keyToBytes(transactionalId)
+    @volatile var ended = false
+    override def toString: String = transactionalId
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
index e4ac4fa..17ee578 100644
--- a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
+++ b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
@@ -28,8 +28,11 @@ class MockTimer extends Timer {
   def add(timerTask: TimerTask) {
     if (timerTask.delayMs <= 0)
       timerTask.run()
-    else
-      taskQueue.enqueue(new TimerTaskEntry(timerTask, timerTask.delayMs + time.milliseconds))
+    else {
+      taskQueue synchronized {
+        taskQueue.enqueue(new TimerTaskEntry(timerTask, timerTask.delayMs + time.milliseconds))
+      }
+    }
   }
 
   def advanceClock(timeoutMs: Long): Boolean = {
@@ -38,15 +41,25 @@ class MockTimer extends Timer {
     var executed = false
     val now = time.milliseconds
 
-    while (taskQueue.nonEmpty && now > taskQueue.head.expirationMs) {
-      val taskEntry = taskQueue.dequeue()
-      if (!taskEntry.cancelled) {
-        val task = taskEntry.timerTask
-        task.run()
-        executed = true
+    var hasMore = true
+    while (hasMore) {
+      hasMore = false
+      val head = taskQueue synchronized {
+        if (taskQueue.nonEmpty && now > taskQueue.head.expirationMs) {
+          val entry = Some(taskQueue.dequeue())
+          hasMore = taskQueue.nonEmpty
+          entry
+        } else
+          None
+      }
+      head.foreach { taskEntry =>
+        if (!taskEntry.cancelled) {
+          val task = taskEntry.timerTask
+          task.run()
+          executed = true
+        }
       }
     }
-
     executed
   }
 

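The MockTimer change above makes add() and advanceClock() safe to call from different threads by locking taskQueue, which is what the new tests rely on when worker threads schedule delayed operations while the test thread drives the clock (for example via Operation.actionNoVerify). A rough illustration of that usage, assuming only what this diff shows about TimerTask (a Runnable carrying a delayMs value); the task body, thread count and object name are made up for the sketch.

import java.util.concurrent.{ Executors, TimeUnit }
import java.util.concurrent.atomic.AtomicInteger

import kafka.utils.timer.{ MockTimer, TimerTask }

object MockTimerConcurrencySketch extends App {
  val timer = new MockTimer
  val completed = new AtomicInteger()
  val pool = Executors.newFixedThreadPool(4)

  // Several threads enqueue delayed tasks; MockTimer.add now locks taskQueue for each enqueue.
  (1 to 4).foreach { _ =>
    pool.submit(new Runnable {
      override def run(): Unit = timer.add(new TimerTask {
        override val delayMs: Long = 50
        override def run(): Unit = completed.incrementAndGet() // stands in for a delayed operation
      })
    })
  }
  pool.shutdown()
  pool.awaitTermination(5, TimeUnit.SECONDS)

  // advanceClock dequeues expired entries under the same taskQueue lock and runs them here.
  timer.advanceClock(100)
}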
-- 
To stop receiving notification emails like this one, please contact commits@kafka.apache.org.
