kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jun...@apache.org
Subject [kafka] branch trunk updated: KAFKA-8334 Make sure the thread which tries to complete delayed reque… (#8657)
Date Wed, 09 Sep 2020 21:44:28 GMT
This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c2273ad  KAFKA-8334 Make sure the thread which tries to complete delayed reque…
(#8657)
c2273ad is described below

commit c2273adc25b2bab0a3ac95bf7844fedf2860b40b
Author: Chia-Ping Tsai <chia7712@gmail.com>
AuthorDate: Thu Sep 10 05:42:37 2020 +0800

    KAFKA-8334 Make sure the thread which tries to complete delayed reque… (#8657)
    
    The main changes of this PR are shown below.
    
    1. replace tryLock by lock for DelayedOperation#maybeTryComplete
    2. complete the delayed requests without holding group lock
    
    Reviewers: Ismael Juma <ismael@juma.me.uk>, Jun Rao <junrao@gmail.com>
---
 core/src/main/scala/kafka/cluster/Partition.scala  |  22 +---
 .../kafka/coordinator/group/DelayedJoin.scala      |  11 +-
 .../coordinator/group/GroupMetadataManager.scala   |   2 +-
 core/src/main/scala/kafka/log/Log.scala            |  13 ++-
 core/src/main/scala/kafka/server/ActionQueue.scala |  56 ++++++++++
 .../main/scala/kafka/server/DelayedOperation.scala | 118 +++++++++------------
 core/src/main/scala/kafka/server/KafkaApis.scala   |   4 +
 .../main/scala/kafka/server/ReplicaManager.scala   |  31 ++++++
 .../AbstractCoordinatorConcurrencyTest.scala       |  10 +-
 .../group/GroupCoordinatorConcurrencyTest.scala    |  65 ++++++++----
 .../TransactionCoordinatorConcurrencyTest.scala    |   5 +-
 .../unit/kafka/server/DelayedOperationTest.scala   |  96 +++++++----------
 12 files changed, 257 insertions(+), 176 deletions(-)

diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index fb0576e..fc0852f 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -103,19 +103,7 @@ class DelayedOperations(topicPartition: TopicPartition,
     fetch.checkAndComplete(TopicPartitionOperationKey(topicPartition))
   }
 
-  def checkAndCompleteProduce(): Unit = {
-    produce.checkAndComplete(TopicPartitionOperationKey(topicPartition))
-  }
-
-  def checkAndCompleteDeleteRecords(): Unit = {
-    deleteRecords.checkAndComplete(TopicPartitionOperationKey(topicPartition))
-  }
-
   def numDelayedDelete: Int = deleteRecords.numDelayed
-
-  def numDelayedFetch: Int = fetch.numDelayed
-
-  def numDelayedProduce: Int = produce.numDelayed
 }
 
 object Partition extends KafkaMetricsGroup {
@@ -1010,15 +998,7 @@ class Partition(val topicPartition: TopicPartition,
       }
     }
 
-    // some delayed operations may be unblocked after HW changed
-    if (leaderHWIncremented)
-      tryCompleteDelayedRequests()
-    else {
-      // probably unblock some follower fetch requests since log end offset has been updated
-      delayedOperations.checkAndCompleteFetch()
-    }
-
-    info
+    info.copy(leaderHwChange = if (leaderHWIncremented) LeaderHwChange.Increased else LeaderHwChange.Same)
   }
 
   def readRecords(fetchOffset: Long,
diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala
index dad2b1e..92e8835 100644
--- a/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala
+++ b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala
@@ -36,8 +36,15 @@ private[group] class DelayedJoin(coordinator: GroupCoordinator,
                                  rebalanceTimeout: Long) extends DelayedOperation(rebalanceTimeout,
Some(group.lock)) {
 
   override def tryComplete(): Boolean = coordinator.tryCompleteJoin(group, forceComplete
_)
-  override def onExpiration() = coordinator.onExpireJoin()
-  override def onComplete() = coordinator.onCompleteJoin(group)
+  override def onExpiration(): Unit = {
+    coordinator.onExpireJoin()
+    // try to complete delayed actions introduced by coordinator.onCompleteJoin
+    tryToCompleteDelayedAction()
+  }
+  override def onComplete(): Unit = coordinator.onCompleteJoin(group)
+
+  // TODO: remove this ugly chain after we move the action queue to handler thread
+  private def tryToCompleteDelayedAction(): Unit = coordinator.groupManager.replicaManager.tryCompleteActions()
 }
 
 /**
diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
index 1fcdd91..9d58b2d 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
@@ -56,7 +56,7 @@ import scala.jdk.CollectionConverters._
 class GroupMetadataManager(brokerId: Int,
                            interBrokerProtocolVersion: ApiVersion,
                            config: OffsetConfig,
-                           replicaManager: ReplicaManager,
+                           val replicaManager: ReplicaManager,
                            zkClient: KafkaZkClient,
                            time: Time,
                            metrics: Metrics) extends Logging with KafkaMetricsGroup {
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index fa78666..8953156 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -68,6 +68,13 @@ object LogAppendInfo {
       offsetsMonotonic = false, -1L, recordErrors, errorMessage)
 }
 
+sealed trait LeaderHwChange
+object LeaderHwChange {
+  case object Increased extends LeaderHwChange
+  case object Same extends LeaderHwChange
+  case object None extends LeaderHwChange
+}
+
 /**
  * Struct to hold various quantities we compute about each message set before appending to
the log
  *
@@ -85,6 +92,9 @@ object LogAppendInfo {
  * @param validBytes The number of valid bytes
  * @param offsetsMonotonic Are the offsets in this message set monotonically increasing
  * @param lastOffsetOfFirstBatch The last offset of the first batch
+ * @param leaderHwChange Increased if the high watermark needs to be increased after appending
record.
+ *                       Same if high watermark is not changed. None is the default value
and it means append failed
+ *
  */
 case class LogAppendInfo(var firstOffset: Option[Long],
                          var lastOffset: Long,
@@ -100,7 +110,8 @@ case class LogAppendInfo(var firstOffset: Option[Long],
                          offsetsMonotonic: Boolean,
                          lastOffsetOfFirstBatch: Long,
                          recordErrors: Seq[RecordError] = List(),
-                         errorMessage: String = null) {
+                         errorMessage: String = null,
+                         leaderHwChange: LeaderHwChange = LeaderHwChange.None) {
   /**
    * Get the first offset if it exists, else get the last offset of the first batch
    * For magic versions 2 and newer, this method will return first offset. For magic versions
diff --git a/core/src/main/scala/kafka/server/ActionQueue.scala b/core/src/main/scala/kafka/server/ActionQueue.scala
new file mode 100644
index 0000000..1b6b832
--- /dev/null
+++ b/core/src/main/scala/kafka/server/ActionQueue.scala
@@ -0,0 +1,56 @@
+/**
+  * Licensed to the Apache Software Foundation (ASF) under one or more
+  * contributor license agreements.  See the NOTICE file distributed with
+  * this work for additional information regarding copyright ownership.
+  * The ASF licenses this file to You under the Apache License, Version 2.0
+  * (the "License"); you may not use this file except in compliance with
+  * the License.  You may obtain a copy of the License at
+  *
+  *    http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+package kafka.server
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import kafka.utils.Logging
+
+/**
+ * This queue is used to collect actions which need to be executed later. One use case is
that ReplicaManager#appendRecords
+ * produces record changes so we need to check and complete delayed requests. In order to
avoid conflicting locking,
+ * we add those actions to this queue and then complete them at the end of KafkaApis.handle()
or DelayedJoin.onExpiration.
+ */
+class ActionQueue extends Logging {
+  private val queue = new ConcurrentLinkedQueue[() => Unit]()
+
+  /**
+   * add action to this queue.
+   * @param action action
+   */
+  def add(action: () => Unit): Unit = queue.add(action)
+
+  /**
+   * try to complete all delayed actions
+   */
+  def tryCompleteActions(): Unit = {
+    val maxToComplete = queue.size()
+    var count = 0
+    var done = false
+    while (!done && count < maxToComplete) {
+      try {
+        val action = queue.poll()
+        if (action == null) done = true
+        else action()
+      } catch {
+        case e: Throwable =>
+          error("failed to complete delayed actions", e)
+      } finally count += 1
+    }
+  }
+}
diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala
index 2756b4f..09fd337 100644
--- a/core/src/main/scala/kafka/server/DelayedOperation.scala
+++ b/core/src/main/scala/kafka/server/DelayedOperation.scala
@@ -41,13 +41,15 @@ import scala.collection.mutable.ListBuffer
  * forceComplete().
  *
  * A subclass of DelayedOperation needs to provide an implementation of both onComplete()
and tryComplete().
+ *
+ * Note that if you add a future delayed operation that calls ReplicaManager.appendRecords()
in onComplete()
+ * like DelayedJoin, you must be aware that this operation's onExpiration() needs to call
actionQueue.tryCompleteActions().
  */
 abstract class DelayedOperation(override val delayMs: Long,
                                 lockOpt: Option[Lock] = None)
   extends TimerTask with Logging {
 
   private val completed = new AtomicBoolean(false)
-  private val tryCompletePending = new AtomicBoolean(false)
   // Visible for testing
   private[server] val lock: Lock = lockOpt.getOrElse(new ReentrantLock)
 
@@ -100,42 +102,24 @@ abstract class DelayedOperation(override val delayMs: Long,
   def tryComplete(): Boolean
 
   /**
-   * Thread-safe variant of tryComplete() that attempts completion only if the lock can be
acquired
-   * without blocking.
-   *
-   * If threadA acquires the lock and performs the check for completion before completion
criteria is met
-   * and threadB satisfies the completion criteria, but fails to acquire the lock because
threadA has not
-   * yet released the lock, we need to ensure that completion is attempted again without
blocking threadA
-   * or threadB. `tryCompletePending` is set by threadB when it fails to acquire the lock
and at least one
-   * of threadA or threadB will attempt completion of the operation if this flag is set.
This ensures that
-   * every invocation of `maybeTryComplete` is followed by at least one invocation of `tryComplete`
until
-   * the operation is actually completed.
+   * Thread-safe variant of tryComplete() that calls an extra function if the first tryComplete returns
false
+   * @param f else function to be executed after first tryComplete returns false
+   * @return result of tryComplete
    */
-  private[server] def maybeTryComplete(): Boolean = {
-    var retry = false
-    var done = false
-    do {
-      if (lock.tryLock()) {
-        try {
-          tryCompletePending.set(false)
-          done = tryComplete()
-        } finally {
-          lock.unlock()
-        }
-        // While we were holding the lock, another thread may have invoked `maybeTryComplete`
and set
-        // `tryCompletePending`. In this case we should retry.
-        retry = tryCompletePending.get()
-      } else {
-        // Another thread is holding the lock. If `tryCompletePending` is already set and
this thread failed to
-        // acquire the lock, then the thread that is holding the lock is guaranteed to see
the flag and retry.
-        // Otherwise, we should set the flag and retry on this thread since the thread holding
the lock may have
-        // released the lock and returned by the time the flag is set.
-        retry = !tryCompletePending.getAndSet(true)
-      }
-    } while (!isCompleted && retry)
-    done
+  private[server] def safeTryCompleteOrElse(f: => Unit): Boolean = inLock(lock) {
+    if (tryComplete()) true
+    else {
+      f
+      // last completion check
+      tryComplete()
+    }
   }
 
+  /**
+   * Thread-safe variant of tryComplete()
+   */
+  private[server] def safeTryComplete(): Boolean = inLock(lock)(tryComplete())
+
   /*
    * run() method defines a task that is executed on timeout
    */
@@ -219,38 +203,38 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName:
Stri
   def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean = {
     assert(watchKeys.nonEmpty, "The watch key list can't be empty")
 
-    // The cost of tryComplete() is typically proportional to the number of keys. Calling
-    // tryComplete() for each key is going to be expensive if there are many keys. Instead,
-    // we do the check in the following way. Call tryComplete(). If the operation is not
completed,
-    // we just add the operation to all keys. Then we call tryComplete() again. At this time,
if
-    // the operation is still not completed, we are guaranteed that it won't miss any future
triggering
-    // event since the operation is already on the watcher list for all keys. This does mean
that
-    // if the operation is completed (by another thread) between the two tryComplete() calls,
the
-    // operation is unnecessarily added for watch. However, this is a less severe issue since
the
-    // expire reaper will clean it up periodically.
-
-    // At this point the only thread that can attempt this operation is this current thread
-    // Hence it is safe to tryComplete() without a lock
-    var isCompletedByMe = operation.tryComplete()
-    if (isCompletedByMe)
-      return true
-
-    var watchCreated = false
-    for(key <- watchKeys) {
-      // If the operation is already completed, stop adding it to the rest of the watcher
list.
-      if (operation.isCompleted)
-        return false
-      watchForOperation(key, operation)
-
-      if (!watchCreated) {
-        watchCreated = true
-        estimatedTotalOperations.incrementAndGet()
-      }
-    }
-
-    isCompletedByMe = operation.maybeTryComplete()
-    if (isCompletedByMe)
-      return true
+    // The cost of tryComplete() is typically proportional to the number of keys. Calling
tryComplete() for each key is
+    // going to be expensive if there are many keys. Instead, we do the check in the following
way through safeTryCompleteOrElse().
+    // If the operation is not completed, we just add the operation to all keys. Then we
call tryComplete() again. At
+    // this time, if the operation is still not completed, we are guaranteed that it won't
miss any future triggering
+    // event since the operation is already on the watcher list for all keys.
+    //
+    // ==============[story about lock]==============
+    // Through safeTryCompleteOrElse(), we hold the operation's lock while adding the operation
to watch list and doing
+    // the tryComplete() check. This is to avoid a potential deadlock between the callers
to tryCompleteElseWatch() and
+    // checkAndComplete(). For example, the following deadlock can happen if the lock is
only held for the final tryComplete()
+    // 1) thread_a holds readlock of stateLock from TransactionStateManager
+    // 2) thread_a is executing tryCompleteElseWatch()
+    // 3) thread_a adds op to watch list
+    // 4) thread_b requires writelock of stateLock from TransactionStateManager (blocked
by thread_a)
+    // 5) thread_c calls checkAndComplete() and holds lock of op
+    // 6) thread_c is waiting readlock of stateLock to complete op (blocked by thread_b)
+    // 7) thread_a is waiting lock of op to call the final tryComplete() (blocked by thread_c)
+    //
+    // Note that even with the current approach, deadlocks could still be introduced. For
example,
+    // 1) thread_a calls tryCompleteElseWatch() and gets lock of op
+    // 2) thread_a adds op to watch list
+    // 3) thread_a calls op#tryComplete and tries to require lock_b
+    // 4) thread_b holds lock_b and calls checkAndComplete()
+    // 5) thread_b sees op from watch list
+    // 6) thread_b needs lock of op
+    // To avoid the above scenario, we recommend DelayedOperationPurgatory.checkAndComplete()
be called without holding
+    // any exclusive lock. Since DelayedOperationPurgatory.checkAndComplete() completes delayed
operations asynchronously,
+    // holding an exclusive lock to make the call is often unnecessary.
+    if (operation.safeTryCompleteOrElse {
+      watchKeys.foreach(key => watchForOperation(key, operation))
+      if (watchKeys.nonEmpty) estimatedTotalOperations.incrementAndGet()
+    }) return true
 
     // if it cannot be completed by now and hence is watched, add to the expire queue also
     if (!operation.isCompleted) {
@@ -375,7 +359,7 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName:
Stri
         if (curr.isCompleted) {
           // another thread has completed this operation, just remove it
           iter.remove()
-        } else if (curr.maybeTryComplete()) {
+        } else if (curr.safeTryComplete()) {
           iter.remove()
           completed += 1
         }
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 867ff6a..ef2df97 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -186,6 +186,10 @@ class KafkaApis(val requestChannel: RequestChannel,
       case e: FatalExitError => throw e
       case e: Throwable => handleError(request, e)
     } finally {
+      // try to complete delayed action. In order to avoid conflicting locking, the actions
to complete delayed requests
+      // are kept in a queue. We add the logic to check the ReplicaManager queue at the end
of KafkaApis.handle() and the
+      // expiration thread for certain delayed operations (e.g. DelayedJoin)
+      replicaManager.tryCompleteActions()
       // The local completion time may be set while processing the request. Only record it
if it's unset.
       if (request.apiLocalCompleteTimeNanos < 0)
         request.apiLocalCompleteTimeNanos = time.nanoseconds
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 32c9aa4..1ad37ef 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -559,9 +559,20 @@ class ReplicaManager(val config: KafkaConfig,
   }
 
   /**
+   * TODO: move this action queue to handle thread so we can simplify concurrency handling
+   */
+  private val actionQueue = new ActionQueue
+
+  def tryCompleteActions(): Unit = actionQueue.tryCompleteActions()
+
+  /**
    * Append messages to leader replicas of the partition, and wait for them to be replicated
to other replicas;
    * the callback function will be triggered either when timeout or the required acks are
satisfied;
    * if the callback function itself is already synchronized on some object then pass this
object to avoid deadlock.
+   *
+   * Note that all pending delayed check operations are stored in a queue. All callers to
ReplicaManager.appendRecords()
+   * are expected to call ActionQueue.tryCompleteActions for all affected partitions, without
holding any conflicting
+   * locks.
    */
   def appendRecords(timeout: Long,
                     requiredAcks: Short,
@@ -585,6 +596,26 @@ class ReplicaManager(val config: KafkaConfig,
                     result.info.logStartOffset, result.info.recordErrors.asJava, result.info.errorMessage))
// response status
       }
 
+      actionQueue.add {
+        () =>
+          localProduceResults.foreach {
+            case (topicPartition, result) =>
+              val requestKey = TopicPartitionOperationKey(topicPartition)
+              result.info.leaderHwChange match {
+                case LeaderHwChange.Increased =>
+                  // some delayed operations may be unblocked after HW changed
+                  delayedProducePurgatory.checkAndComplete(requestKey)
+                  delayedFetchPurgatory.checkAndComplete(requestKey)
+                  delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
+                case LeaderHwChange.Same =>
+                  // probably unblock some follower fetch requests since log end offset has
been updated
+                  delayedFetchPurgatory.checkAndComplete(requestKey)
+                case LeaderHwChange.None =>
+                  // nothing
+              }
+          }
+      }
+
       recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats
})
 
       if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults))
{
diff --git a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index c7145ce..62ee85d 100644
--- a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -17,8 +17,8 @@
 
 package kafka.coordinator
 
-import java.util.{Collections, Random}
 import java.util.concurrent.{ConcurrentHashMap, Executors}
+import java.util.{Collections, Random}
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.locks.Lock
 
@@ -97,7 +97,7 @@ abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember]
{
   }
 
   def enableCompletion(): Unit = {
-    replicaManager.tryCompleteDelayedRequests()
+    replicaManager.tryCompleteActions()
     scheduler.tick()
   }
 
@@ -166,9 +166,8 @@ object AbstractCoordinatorConcurrencyTest {
       producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer,
1, reaperEnabled = false)
       watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey,
java.lang.Boolean]()).asScala
     }
-    def tryCompleteDelayedRequests(): Unit = {
-      watchKeys.map(producePurgatory.checkAndComplete)
-    }
+
+    override def tryCompleteActions(): Unit = watchKeys.map(producePurgatory.checkAndComplete)
 
     override def appendRecords(timeout: Long,
                                requiredAcks: Short,
@@ -204,7 +203,6 @@ object AbstractCoordinatorConcurrencyTest {
       val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq
       watchKeys ++= producerRequestKeys
       producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
-      tryCompleteDelayedRequests()
     }
     override def getMagic(topicPartition: TopicPartition): Option[Byte] = {
       Some(RecordBatch.MAGIC_VALUE_V2)
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
index 72de4a1..1f54bd5 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
@@ -18,6 +18,7 @@
 package kafka.coordinator.group
 
 import java.util.Properties
+import java.util.concurrent.locks.{Lock, ReentrantLock}
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 
 import kafka.common.OffsetAndMetadata
@@ -60,16 +61,6 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
       new LeaveGroupOperation
   )
 
-  private val allOperationsWithTxn = Seq(
-    new JoinGroupOperation,
-    new SyncGroupOperation,
-    new OffsetFetchOperation,
-    new CommitTxnOffsetsOperation,
-    new CompleteTxnOperation,
-    new HeartbeatOperation,
-    new LeaveGroupOperation
-  )
-
   var heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat] = _
   var joinPurgatory: DelayedOperationPurgatory[DelayedJoin] = _
   var groupCoordinator: GroupCoordinator = _
@@ -119,12 +110,33 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
 
   @Test
   def testConcurrentTxnGoodPathSequence(): Unit = {
-    verifyConcurrentOperations(createGroupMembers, allOperationsWithTxn)
+    verifyConcurrentOperations(createGroupMembers, Seq(
+      new JoinGroupOperation,
+      new SyncGroupOperation,
+      new OffsetFetchOperation,
+      new CommitTxnOffsetsOperation,
+      new CompleteTxnOperation,
+      new HeartbeatOperation,
+      new LeaveGroupOperation
+    ))
   }
 
   @Test
   def testConcurrentRandomSequence(): Unit = {
-    verifyConcurrentRandomSequences(createGroupMembers, allOperationsWithTxn)
+    /**
+     * handleTxnCommitOffsets does not complete delayed requests now so it causes error if
handleTxnCompletion is executed
+     * before completing delayed request. In random mode, we use this global lock to prevent
such an error.
+     */
+    val lock = new ReentrantLock()
+    verifyConcurrentRandomSequences(createGroupMembers, Seq(
+      new JoinGroupOperation,
+      new SyncGroupOperation,
+      new OffsetFetchOperation,
+      new CommitTxnOffsetsOperation(lock = Some(lock)),
+      new CompleteTxnOperation(lock = Some(lock)),
+      new HeartbeatOperation,
+      new LeaveGroupOperation
+    ))
   }
 
   @Test
@@ -198,6 +210,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
       groupCoordinator.handleJoinGroup(member.groupId, member.memberId, None, requireKnownMemberId
= false, "clientId", "clientHost",
        DefaultRebalanceTimeout, DefaultSessionTimeout,
        protocolType, protocols, responseCallback)
+      replicaManager.tryCompleteActions()
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
        val joinGroupResult = await(member, DefaultRebalanceTimeout)
@@ -221,6 +234,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
         groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
           Some(protocolType), Some(protocolName), member.groupInstanceId, Map.empty[String,
Array[Byte]], responseCallback)
       }
+      replicaManager.tryCompleteActions()
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
       val result = await(member, DefaultSessionTimeout)
@@ -238,6 +252,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
     override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback):
Unit = {
       groupCoordinator.handleHeartbeat(member.groupId, member.memberId,
         member.groupInstanceId, member.generationId, responseCallback)
+      replicaManager.tryCompleteActions()
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
        val error = await(member, DefaultSessionTimeout)
@@ -252,6 +267,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
     }
     override def runWithCallback(member: GroupMember, responseCallback: OffsetFetchCallback):
Unit = {
       val (error, partitionData) = groupCoordinator.handleFetchOffsets(member.groupId, requireStable
= true, None)
+      replicaManager.tryCompleteActions()
       responseCallback(error, partitionData)
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
@@ -271,6 +287,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
       val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds()))
       groupCoordinator.handleCommitOffsets(member.groupId, member.memberId,
         member.groupInstanceId, member.generationId, offsets, responseCallback)
+      replicaManager.tryCompleteActions()
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
        val offsets = await(member, 500)
@@ -278,7 +295,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
     }
   }
 
-  class CommitTxnOffsetsOperation extends CommitOffsetsOperation {
+  class CommitTxnOffsetsOperation(lock: Option[Lock] = None) extends CommitOffsetsOperation
{
     override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback):
Unit = {
       val tp = new TopicPartition("topic", 0)
       val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds()))
@@ -293,13 +310,17 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
           offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
         responseCallback(errors)
       }
-      groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch,
-        JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID,
-        offsets, callbackWithTxnCompletion)
+      lock.foreach(_.lock())
+      try {
+        groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch,
+          JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID,
+          offsets, callbackWithTxnCompletion)
+        replicaManager.tryCompleteActions()
+      } finally lock.foreach(_.unlock())
     }
   }
 
-  class CompleteTxnOperation extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback]
{
+  class CompleteTxnOperation(lock: Option[Lock] = None) extends GroupOperation[CompleteTxnCallbackParams,
CompleteTxnCallback] {
     override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback
= {
       val callback: CompleteTxnCallback = error => responsePromise.success(error)
       callback
@@ -307,9 +328,13 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
     override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback):
Unit = {
       val producerId = 1000L
       val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME,
_))
-      groupCoordinator.groupManager.handleTxnCompletion(producerId,
-        offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
-      responseCallback(Errors.NONE)
+      lock.foreach(_.lock())
+      try {
+        groupCoordinator.groupManager.handleTxnCompletion(producerId,
+          offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
+        responseCallback(Errors.NONE)
+      } finally lock.foreach(_.unlock())
+
     }
     override def awaitAndVerify(member: GroupMember): Unit = {
       val error = await(member, 500)
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 1be4969..3788cb1 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -509,6 +509,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
   class InitProducerIdOperation(val producerIdAndEpoch: Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] {
     override def run(txn: Transaction): Unit = {
       transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, producerIdAndEpoch, resultCallback)
+      replicaManager.tryCompleteActions()
     }
     override def awaitAndVerify(txn: Transaction): Unit = {
       val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed"))
@@ -525,6 +526,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
             txnMetadata.producerEpoch,
             partitions,
             resultCallback)
+        replicaManager.tryCompleteActions()
       }
     }
     override def awaitAndVerify(txn: Transaction): Unit = {
@@ -597,12 +599,13 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
         }
       }
       txnStateManager.enableTransactionalIdExpiration()
+      replicaManager.tryCompleteActions()
       time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1)
     }
 
     override def await(): Unit = {
       val (_, success) = TestUtils.computeUntilTrue({
-        replicaManager.tryCompleteDelayedRequests()
+        replicaManager.tryCompleteActions()
         transactions.forall(txn => transactionMetadata(txn).isEmpty)
       })(identity)
       assertTrue("Transaction not expired", success)
diff --git a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala
index c29dfec..8f481f1 100644
--- a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala
+++ b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala
@@ -33,12 +33,12 @@ import scala.jdk.CollectionConverters._
 
 class DelayedOperationTest {
 
-  var purgatory: DelayedOperationPurgatory[MockDelayedOperation] = null
+  var purgatory: DelayedOperationPurgatory[DelayedOperation] = null
   var executorService: ExecutorService = null
 
   @Before
   def setUp(): Unit = {
-    purgatory = DelayedOperationPurgatory[MockDelayedOperation](purgatoryName = "mock")
+    purgatory = DelayedOperationPurgatory[DelayedOperation](purgatoryName = "mock")
   }
 
   @After
@@ -49,6 +49,43 @@ class DelayedOperationTest {
   }
 
   @Test
+  def testLockInTryCompleteElseWatch(): Unit = {
+    val op = new DelayedOperation(100000L) {
+      override def onExpiration(): Unit = {}
+      override def onComplete(): Unit = {}
+      override def tryComplete(): Boolean = {
+        assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread)
+        false
+      }
+      override def safeTryComplete(): Boolean = {
+        fail("tryCompleteElseWatch should not use safeTryComplete")
+        super.safeTryComplete()
+      }
+    }
+    purgatory.tryCompleteElseWatch(op, Seq("key"))
+  }
+
+  @Test
+  def testSafeTryCompleteOrElse(): Unit = {
+    def op(shouldComplete: Boolean) = new DelayedOperation(100000L) {
+      override def onExpiration(): Unit = {}
+      override def onComplete(): Unit = {}
+      override def tryComplete(): Boolean = {
+        assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread)
+        shouldComplete
+      }
+    }
+    var pass = false
+    assertFalse(op(false).safeTryCompleteOrElse {
+      pass = true
+    })
+    assertTrue(pass)
+    assertTrue(op(true).safeTryCompleteOrElse {
+      fail("this method should NOT be executed")
+    })
+  }
+
+  @Test
   def testRequestSatisfaction(): Unit = {
     val r1 = new MockDelayedOperation(100000L)
     val r2 = new MockDelayedOperation(100000L)
@@ -193,44 +230,6 @@ class DelayedOperationTest {
   }
 
   /**
-    * Verify that if there is lock contention between two threads attempting to complete,
-    * completion is performed without any blocking in either thread.
-    */
-  @Test
-  def testTryCompleteLockContention(): Unit = {
-    executorService = Executors.newSingleThreadExecutor()
-    val completionAttemptsRemaining = new AtomicInteger(Int.MaxValue)
-    val tryCompleteSemaphore = new Semaphore(1)
-    val key = "key"
-
-    val op = new MockDelayedOperation(100000L, None, None) {
-      override def tryComplete() = {
-        val shouldComplete = completionAttemptsRemaining.decrementAndGet <= 0
-        tryCompleteSemaphore.acquire()
-        try {
-          if (shouldComplete)
-            forceComplete()
-          else
-            false
-        } finally {
-          tryCompleteSemaphore.release()
-        }
-      }
-    }
-
-    purgatory.tryCompleteElseWatch(op, Seq(key))
-    completionAttemptsRemaining.set(2)
-    tryCompleteSemaphore.acquire()
-    val future = runOnAnotherThread(purgatory.checkAndComplete(key), shouldComplete = false)
-    TestUtils.waitUntilTrue(() => tryCompleteSemaphore.hasQueuedThreads, "Not attempting to complete")
-    purgatory.checkAndComplete(key) // this should not block even though lock is not free
-    assertFalse("Operation should not have completed", op.isCompleted)
-    tryCompleteSemaphore.release()
-    future.get(10, TimeUnit.SECONDS)
-    assertTrue("Operation should have completed", op.isCompleted)
-  }
-
-  /**
     * Test `tryComplete` with multiple threads to verify that there are no timing windows
     * when completion is not performed even if the thread that makes the operation completable
     * may not be able to acquire the operation lock. Since it is difficult to test all scenarios,
@@ -280,23 +279,6 @@ class DelayedOperationTest {
     ops.foreach { op => assertTrue("Operation should have completed", op.isCompleted) }
   }
 
-  @Test
-  def testDelayedOperationLock(): Unit = {
-    verifyDelayedOperationLock(new MockDelayedOperation(100000L), mismatchedLocks = false)
-  }
-
-  @Test
-  def testDelayedOperationLockOverride(): Unit = {
-    def newMockOperation = {
-      val lock = new ReentrantLock
-      new MockDelayedOperation(100000L, Some(lock), Some(lock))
-    }
-    verifyDelayedOperationLock(newMockOperation, mismatchedLocks = false)
-
-    verifyDelayedOperationLock(new MockDelayedOperation(100000L, None, Some(new ReentrantLock)),
-        mismatchedLocks = true)
-  }
-
   def verifyDelayedOperationLock(mockDelayedOperation: => MockDelayedOperation, mismatchedLocks: Boolean): Unit = {
     val key = "key"
     executorService = Executors.newSingleThreadExecutor


Mime
View raw message