kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [kafka] 03/03: KAFKA-9417: New Integration Test for KIP-447 (#8000)
Date Wed, 12 Feb 2020 20:46:13 GMT
This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit a017dcbd3e7d059a0ca4e76777a42912e6b91e65
Author: Boyang Chen <boyang@confluent.io>
AuthorDate: Wed Feb 12 12:34:12 2020 -0800

    KAFKA-9417: New Integration Test for KIP-447 (#8000)
    
    This change has two main components:
    
    1. Extend the existing transactions_test.py to also exercise the new sendTxnOffsets(groupMetadata) API, to make sure we are not introducing any regression or compatibility issue.
      a. We shrink the interval of the broker's transaction-timeout scheduler to 10 seconds so that expiration is triggered sooner rather than later.
    
    2. Create a completely new system test class, group_mode_transactions_test, which is more complicated than the existing system test because it takes rebalances into account and uses multiple partitions instead of one. In more detail:
      a. The message count is done at the partition level instead of globally, as we need to visualize
    the per-partition order throughout the test. To this end, we extend ConsoleConsumer to also print each message's partition, which helps the message copier interpret the per-partition data.
      b. The progress check's timeout now includes the time needed for pending transactional offsets to expire.
      c. Visibility and feature improvements to TransactionalMessageCopier so that it works better in either standalone or group mode.
    
    Reviewers: Matthias J. Sax <matthias@confluent.io>, Guozhang Wang <wangguoz@gmail.com>
---
 checkstyle/suppressions.xml                        |   2 +-
 .../consumer/internals/ConsumerCoordinator.java    |   2 +-
 .../kafka/clients/consumer/internals/Fetcher.java  |  11 +-
 .../producer/internals/TransactionManager.java     |   2 +
 .../kafka/coordinator/group/GroupCoordinator.scala |   4 +-
 .../transaction/TransactionStateManager.scala      |   2 +-
 .../main/scala/kafka/tools/ConsoleConsumer.scala   |  10 +-
 .../examples/ExactlyOnceMessageProcessor.java      |   5 +-
 .../kafka/examples/KafkaConsumerProducerDemo.java  |   2 +-
 .../java/kafka/examples/KafkaExactlyOnceDemo.java  |   2 +-
 .../src/main/java/kafka/examples/Producer.java     |   4 +
 tests/docker/ducker-ak                             |   2 +
 tests/kafkatest/services/console_consumer.py       |  13 +-
 tests/kafkatest/services/kafka/kafka.py            |   2 +-
 .../services/transactional_message_copier.py       |  16 +-
 tests/kafkatest/services/verifiable_producer.py    |  14 +-
 ...ons_test.py => group_mode_transactions_test.py} | 171 +++++++++++++--------
 tests/kafkatest/tests/core/transactions_test.py    |  34 ++--
 .../kafka/tools/TransactionalMessageCopier.java    | 153 ++++++++++++------
 .../org/apache/kafka/tools/VerifiableConsumer.java |  10 +-
 20 files changed, 316 insertions(+), 145 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 9b05f59..7ebe7fc 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -245,7 +245,7 @@
     <suppress checks="BooleanExpressionComplexity"
               files="StreamsResetter.java"/>
     <suppress checks="NPathComplexity"
-              files="(ProducerPerformance|StreamsResetter|Agent).java"/>
+              files="(ProducerPerformance|StreamsResetter|Agent|TransactionalMessageCopier).java"/>
     <suppress checks="ImportControl"
               files="SignalLogger.java"/>
     <suppress checks="IllegalImport"
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 999921f..7a5f627 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -1291,7 +1291,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 // just retry
                 log.info("The following partitions still have unstable offsets " +
                              "which are not cleared on the broker side: {}" +
-                             ", this could be either" +
+                             ", this could be either " +
                              "transactional offsets waiting for completion, or " +
                              "normal offsets waiting for replication after appending to local log", unstableTxnOffsetTopicPartitions);
                 future.raise(new UnstableOffsetCommitException("There are unstable offsets for the requested topic partitions"));
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 7890c9a..f0aaa13 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -1018,15 +1018,16 @@ public class Fetcher<K, V> implements Closeable {
                        error == Errors.REPLICA_NOT_AVAILABLE ||
                        error == Errors.KAFKA_STORAGE_ERROR ||
                        error == Errors.OFFSET_NOT_AVAILABLE ||
-                       error == Errors.LEADER_NOT_AVAILABLE) {
-                log.debug("Attempt to fetch offsets for partition {} failed due to {}, retrying.",
-                        topicPartition, error);
-                partitionsToRetry.add(topicPartition);
-            } else if (error == Errors.FENCED_LEADER_EPOCH ||
+                       error == Errors.LEADER_NOT_AVAILABLE ||
                        error == Errors.UNKNOWN_LEADER_EPOCH) {
                 log.debug("Attempt to fetch offsets for partition {} failed due to {}, retrying.",
                         topicPartition, error);
                 partitionsToRetry.add(topicPartition);
+            } else if (error == Errors.FENCED_LEADER_EPOCH) {
+                log.debug("Attempt to fetch offsets for partition {} failed due to fenced leader epoch, refresh " +
+                              "the metadata and retrying.", topicPartition);
+                metadata.requestUpdate();
+                partitionsToRetry.add(topicPartition);
             } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) {
                 log.warn("Received unknown topic or partition error in ListOffset request for partition {}", topicPartition);
                 partitionsToRetry.add(topicPartition);
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index 0872d62..18ab408 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -66,6 +66,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.Locale;
 import java.util.Map;
 import java.util.OptionalInt;
 import java.util.OptionalLong;
@@ -1329,6 +1330,7 @@ public class TransactionManager {
                         transactionCoordinator = node;
                 }
                 result.done();
+                log.info("Discovered {} coordinator {}", coordinatorType.toString().toLowerCase(Locale.ROOT), node);
             } else if (error == Errors.COORDINATOR_NOT_AVAILABLE) {
                 reenqueue();
             } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) {
diff --git a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
index 90f391a..4843154 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
@@ -240,13 +240,13 @@ class GroupCoordinator(val brokerId: Int,
         } else if (requireKnownMemberId) {
             // If member id required (dynamic membership), register the member in the pending member list
             // and send back a response to call for another join group request with allocated member id.
-          debug(s"Dynamic member with unknown member id rejoins group ${group.groupId} in " +
+          debug(s"Dynamic member with unknown member id joins group ${group.groupId} in " +
               s"${group.currentState} state. Created a new member id $newMemberId and request the member to rejoin with this id.")
           group.addPendingMember(newMemberId)
           addPendingMemberExpiration(group, newMemberId, sessionTimeoutMs)
           responseCallback(JoinGroupResult(newMemberId, Errors.MEMBER_ID_REQUIRED))
         } else {
-          debug(s"Dynamic member with unknown member id rejoins group ${group.groupId} in " +
+          debug(s"Dynamic member with unknown member id joins group ${group.groupId} in " +
             s"${group.currentState} state. Created a new member id $newMemberId for this member and add to the group.")
           addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, newMemberId, groupInstanceId,
             clientId, clientHost, protocolType, protocols, group, responseCallback)
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index da7502c..174e3a5 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -47,7 +47,7 @@ object TransactionStateManager {
   // default transaction management config values
   val DefaultTransactionsMaxTimeoutMs: Int = TimeUnit.MINUTES.toMillis(15).toInt
   val DefaultTransactionalIdExpirationMs: Int = TimeUnit.DAYS.toMillis(7).toInt
-  val DefaultAbortTimedOutTransactionsIntervalMs: Int = TimeUnit.MINUTES.toMillis(1).toInt
+  val DefaultAbortTimedOutTransactionsIntervalMs: Int = TimeUnit.SECONDS.toMillis(10).toInt
   val DefaultRemoveExpiredTransactionalIdsIntervalMs: Int = TimeUnit.HOURS.toMillis(1).toInt
 }
 
diff --git a/core/src/main/scala/kafka/tools/ConsoleConsumer.scala b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala
index 691cde5..92c6c9c 100755
--- a/core/src/main/scala/kafka/tools/ConsoleConsumer.scala
+++ b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala
@@ -455,9 +455,10 @@ object ConsoleConsumer extends Logging {
 }
 
 class DefaultMessageFormatter extends MessageFormatter {
+  var printTimestamp = false
   var printKey = false
   var printValue = true
-  var printTimestamp = false
+  var printPartition = false
   var keySeparator = "\t".getBytes(StandardCharsets.UTF_8)
   var lineSeparator = "\n".getBytes(StandardCharsets.UTF_8)
 
@@ -471,6 +472,8 @@ class DefaultMessageFormatter extends MessageFormatter {
       printKey = props.getProperty("print.key").trim.equalsIgnoreCase("true")
     if (props.containsKey("print.value"))
       printValue = props.getProperty("print.value").trim.equalsIgnoreCase("true")
+    if (props.containsKey("print.partition"))
+      printPartition = props.getProperty("print.partition").trim.equalsIgnoreCase("true")
     if (props.containsKey("key.separator"))
       keySeparator = props.getProperty("key.separator").getBytes(StandardCharsets.UTF_8)
     if (props.containsKey("line.separator"))
@@ -531,6 +534,11 @@ class DefaultMessageFormatter extends MessageFormatter {
 
     if (printValue) {
       write(valueDeserializer, value, topic)
+      writeSeparator(printPartition)
+    }
+
+    if (printPartition) {
+      output.write(s"$partition".getBytes(StandardCharsets.UTF_8))
       output.write(lineSeparator)
     }
   }
diff --git a/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java b/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java
index 53685f3..482e442 100644
--- a/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java
+++ b/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java
@@ -76,8 +76,11 @@ public class ExactlyOnceMessageProcessor extends Thread {
         this.numInstances = numInstances;
         this.instanceIdx = instanceIdx;
         this.transactionalId = "Processor-" + instanceIdx;
+        // If we are using the group mode, it is recommended to have a relatively short txn timeout
+        // in order to clear pending offsets faster.
+        final int transactionTimeoutMs = this.mode.equals("groupMode") ? 10000 : -1;
         // A unique transactional.id must be provided in order to properly use EOS.
-        producer = new Producer(outputTopic, true, transactionalId, true, -1, null).get();
+        producer = new Producer(outputTopic, true, transactionalId, true, -1, transactionTimeoutMs, null).get();
         // Consumer must be in read_committed mode, which means it won't be able to read uncommitted data.
         consumer = new Consumer(inputTopic, consumerGroupId, READ_COMMITTED, -1, null).get();
         this.latch = latch;
diff --git a/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java b/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java
index 21d85c3..8a29402 100644
--- a/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java
+++ b/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java
@@ -25,7 +25,7 @@ public class KafkaConsumerProducerDemo {
     public static void main(String[] args) throws InterruptedException {
         boolean isAsync = args.length == 0 || !args[0].trim().equalsIgnoreCase("sync");
         CountDownLatch latch = new CountDownLatch(2);
-        Producer producerThread = new Producer(KafkaProperties.TOPIC, isAsync, null, false, 10000, latch);
+        Producer producerThread = new Producer(KafkaProperties.TOPIC, isAsync, null, false, 10000, -1, latch);
         producerThread.start();
 
         Consumer consumerThread = new Consumer(KafkaProperties.TOPIC, "DemoConsumer", false, 10000, latch);
diff --git a/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java b/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java
index 288b786..6da159c 100644
--- a/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java
+++ b/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java
@@ -88,7 +88,7 @@ public class KafkaExactlyOnceDemo {
         CountDownLatch prePopulateLatch = new CountDownLatch(1);
 
         /* Stage 2: pre-populate records */
-        Producer producerThread = new Producer(INPUT_TOPIC, false, null, true, numRecords, prePopulateLatch);
+        Producer producerThread = new Producer(INPUT_TOPIC, false, null, true, numRecords, -1, prePopulateLatch);
         producerThread.start();
 
         if (!prePopulateLatch.await(5, TimeUnit.MINUTES)) {
diff --git a/examples/src/main/java/kafka/examples/Producer.java b/examples/src/main/java/kafka/examples/Producer.java
index 3805dd3..da8c00b 100644
--- a/examples/src/main/java/kafka/examples/Producer.java
+++ b/examples/src/main/java/kafka/examples/Producer.java
@@ -40,12 +40,16 @@ public class Producer extends Thread {
                     final String transactionalId,
                     final boolean enableIdempotency,
                     final int numRecords,
+                    final int transactionTimeoutMs,
                     final CountDownLatch latch) {
         Properties props = new Properties();
         props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, KafkaProperties.KAFKA_SERVER_URL + ":" + KafkaProperties.KAFKA_SERVER_PORT);
         props.put(ProducerConfig.CLIENT_ID_CONFIG, "DemoProducer");
         props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class.getName());
         props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());
+        if (transactionTimeoutMs > 0) {
+            props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, transactionTimeoutMs);
+        }
         if (transactionalId != null) {
             props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId);
         }
diff --git a/tests/docker/ducker-ak b/tests/docker/ducker-ak
index a29395e..f923de2 100755
--- a/tests/docker/ducker-ak
+++ b/tests/docker/ducker-ak
@@ -223,6 +223,8 @@ ducker_build() {
     SECONDS=0
 
     must_pushd "${ducker_dir}"
+    # Tip: if you are scratching your head for some dependency problems that are referring to an old code version
+    # (for example java.lang.NoClassDefFoundError), add --no-cache flag to the build shall give you a clean start.
     must_do -v -o docker build --memory="${docker_build_memory_limit}" \
         --build-arg "ducker_creator=${user_name}" --build-arg "jdk_version=${jdk_version}" -t "${image_name}" \
         -f "${ducker_dir}/Dockerfile" ${docker_args} -- .
diff --git a/tests/kafkatest/services/console_consumer.py b/tests/kafkatest/services/console_consumer.py
index 5fd4712..0811bcd 100644
--- a/tests/kafkatest/services/console_consumer.py
+++ b/tests/kafkatest/services/console_consumer.py
@@ -60,7 +60,7 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
     def __init__(self, context, num_nodes, kafka, topic, group_id="test-consumer-group", new_consumer=True,
                  message_validator=None, from_beginning=True, consumer_timeout_ms=None, version=DEV_BRANCH,
                  client_id="console-consumer", print_key=False, jmx_object_names=None, jmx_attributes=None,
-                 enable_systest_events=False, stop_timeout_sec=35, print_timestamp=False,
+                 enable_systest_events=False, stop_timeout_sec=35, print_timestamp=False, print_partition=False,
                  isolation_level="read_uncommitted", jaas_override_variables=None,
                  kafka_opts_override="", client_prop_file_override="", consumer_properties={}):
         """
@@ -76,12 +76,13 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
                                         successively consumed messages exceeds this timeout. Setting this and
                                         waiting for the consumer to stop is a pretty good way to consume all messages
                                         in a topic.
+            print_timestamp             if True, print each message's timestamp as well
             print_key                   if True, print each message's key as well
+            print_partition             if True, print each message's partition as well
             enable_systest_events       if True, console consumer will print additional lifecycle-related information
                                         only available in 0.10.0 and later.
             stop_timeout_sec            After stopping a node, wait up to stop_timeout_sec for the node to stop,
                                         and the corresponding background thread to finish successfully.
-            print_timestamp             if True, print each message's timestamp as well
             isolation_level             How to handle transactional messages.
             jaas_override_variables     A dict of variables to be used in the jaas.conf template file
             kafka_opts_override         Override parameters of the KAFKA_OPTS environment variable
@@ -108,6 +109,7 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
         self.clean_shutdown_nodes = set()
         self.client_id = client_id
         self.print_key = print_key
+        self.print_partition = print_partition
         self.log_level = "TRACE"
         self.stop_timeout_sec = stop_timeout_sec
 
@@ -191,11 +193,14 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
             if node.version > LATEST_0_8_2:
                 cmd += " --timeout-ms %s" % self.consumer_timeout_ms
 
+        if self.print_timestamp:
+            cmd += " --property print.timestamp=true"
+
         if self.print_key:
             cmd += " --property print.key=true"
 
-        if self.print_timestamp:
-            cmd += " --property print.timestamp=true"
+        if self.print_partition:
+            cmd += " --property print.partition=true"
 
         # LoggingMessageFormatter was introduced after 0.9
         if node.version > LATEST_0_9:
diff --git a/tests/kafkatest/services/kafka/kafka.py b/tests/kafkatest/services/kafka/kafka.py
index acb2137..e59431e 100644
--- a/tests/kafkatest/services/kafka/kafka.py
+++ b/tests/kafkatest/services/kafka/kafka.py
@@ -829,7 +829,7 @@ class KafkaService(KafkaPathResolverMixin, JmxMixin, Service):
         """
         Check whether a broker is registered in Zookeeper
         """
-        self.logger.debug("Querying zookeeper to see if broker %s is registered", node)
+        self.logger.debug("Querying zookeeper to see if broker %s is registered", str(node))
         broker_info = self.zk.query("/brokers/ids/%s" % self.idx(node), chroot=self.zk_chroot)
         self.logger.debug("Broker info: %s", broker_info)
         return broker_info is not None
diff --git a/tests/kafkatest/services/transactional_message_copier.py b/tests/kafkatest/services/transactional_message_copier.py
index 1a6a34c..dc972d7 100644
--- a/tests/kafkatest/services/transactional_message_copier.py
+++ b/tests/kafkatest/services/transactional_message_copier.py
@@ -46,13 +46,14 @@ class TransactionalMessageCopier(KafkaPathResolverMixin, BackgroundThreadService
     }
 
     def __init__(self, context, num_nodes, kafka, transactional_id, consumer_group,
-                 input_topic, input_partition, output_topic, max_messages = -1,
-                 transaction_size = 1000, enable_random_aborts=True):
+                 input_topic, input_partition, output_topic, max_messages=-1,
+                 transaction_size=1000, transaction_timeout=None, enable_random_aborts=True, use_group_metadata=False, group_mode=False):
         super(TransactionalMessageCopier, self).__init__(context, num_nodes)
         self.kafka = kafka
         self.transactional_id = transactional_id
         self.consumer_group = consumer_group
         self.transaction_size = transaction_size
+        self.transaction_timeout = transaction_timeout
         self.input_topic = input_topic
         self.input_partition = input_partition
         self.output_topic = output_topic
@@ -62,6 +63,8 @@ class TransactionalMessageCopier(KafkaPathResolverMixin, BackgroundThreadService
         self.remaining = -1
         self.stop_timeout_sec = 60
         self.enable_random_aborts = enable_random_aborts
+        self.use_group_metadata = use_group_metadata
+        self.group_mode = group_mode
         self.loggers = {
             "org.apache.kafka.clients.producer": "TRACE",
             "org.apache.kafka.clients.consumer": "TRACE"
@@ -120,9 +123,18 @@ class TransactionalMessageCopier(KafkaPathResolverMixin, BackgroundThreadService
         cmd += " --input-partition %s" % str(self.input_partition)
         cmd += " --transaction-size %s" % str(self.transaction_size)
 
+        if self.transaction_timeout is not None:
+            cmd += " --transaction-timeout %s" % str(self.transaction_timeout)
+
         if self.enable_random_aborts:
             cmd += " --enable-random-aborts"
 
+        if self.use_group_metadata:
+            cmd += " --use-group-metadata"
+
+        if self.group_mode:
+            cmd += " --group-mode"
+
         if self.max_messages > 0:
             cmd += " --max-messages %s" % str(self.max_messages)
         cmd += " 2>> %s | tee -a %s &" % (TransactionalMessageCopier.STDERR_CAPTURE, TransactionalMessageCopier.STDOUT_CAPTURE)
diff --git a/tests/kafkatest/services/verifiable_producer.py b/tests/kafkatest/services/verifiable_producer.py
index 893baa4..caa3961 100644
--- a/tests/kafkatest/services/verifiable_producer.py
+++ b/tests/kafkatest/services/verifiable_producer.py
@@ -91,6 +91,7 @@ class VerifiableProducer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
         for node in self.nodes:
             node.version = version
         self.acked_values = []
+        self.acked_values_by_partition = {}
         self._last_acked_offsets = {}
         self.not_acked_values = []
         self.produced_count = {}
@@ -178,7 +179,13 @@ class VerifiableProducer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
 
                     elif data["name"] == "producer_send_success":
                         partition = TopicPartition(data["topic"], data["partition"])
-                        self.acked_values.append(self.message_validator(data["value"]))
+                        value = self.message_validator(data["value"])
+                        self.acked_values.append(value)
+
+                        if partition not in self.acked_values_by_partition:
+                            self.acked_values_by_partition[partition] = []
+                        self.acked_values_by_partition[partition].append(value)
+
                         self._last_acked_offsets[partition] = data["offset"]
                         self.produced_count[idx] += 1
 
@@ -256,6 +263,11 @@ class VerifiableProducer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
             return self.acked_values
 
     @property
+    def acked_by_partition(self):
+        with self.lock:
+            return self.acked_values_by_partition
+
+    @property
     def not_acked(self):
         with self.lock:
             return self.not_acked_values
diff --git a/tests/kafkatest/tests/core/transactions_test.py b/tests/kafkatest/tests/core/group_mode_transactions_test.py
similarity index 60%
copy from tests/kafkatest/tests/core/transactions_test.py
copy to tests/kafkatest/tests/core/group_mode_transactions_test.py
index 4da5960..deaa26d 100644
--- a/tests/kafkatest/tests/core/transactions_test.py
+++ b/tests/kafkatest/tests/core/group_mode_transactions_test.py
@@ -26,16 +26,18 @@ from ducktape.mark.resource import cluster
 from ducktape.utils.util import wait_until
 
 
-class TransactionsTest(Test):
-    """Tests transactions by transactionally copying data from a source topic to
-    a destination topic and killing the copy process as well as the broker
-    randomly through the process. In the end we verify that the final output
-    topic contains exactly one committed copy of each message in the input
-    topic
+class GroupModeTransactionsTest(Test):
+    """Essentially testing the same functionality as TransactionsTest by transactionally copying data
+    from a source topic to a destination topic and killing the copy process as well as the broker
+    randomly through the process. The major difference is that we choose to work as a collaborated
+    group with same topic subscription instead of individual copiers.
+
+    In the end we verify that the final output topic contains exactly one committed copy of
+    each message from the original producer.
     """
     def __init__(self, test_context):
         """:type test_context: ducktape.tests.test.TestContext"""
-        super(TransactionsTest, self).__init__(test_context=test_context)
+        super(GroupModeTransactionsTest, self).__init__(test_context=test_context)
 
         self.input_topic = "input-topic"
         self.output_topic = "output-topic"
@@ -43,11 +45,13 @@ class TransactionsTest(Test):
         self.num_brokers = 3
 
         # Test parameters
-        self.num_input_partitions = 2
-        self.num_output_partitions = 3
+        self.num_input_partitions = 9
+        self.num_output_partitions = 9
+        self.num_copiers = 3
         self.num_seed_messages = 100000
         self.transaction_size = 750
-        self.consumer_group = "transactions-test-consumer-group"
+        self.transaction_timeout = 10000
+        self.consumer_group = "grouped-transactions-test-consumer-group"
 
         self.zk = ZookeeperService(test_context, num_nodes=1)
         self.kafka = KafkaService(test_context,
@@ -65,20 +69,21 @@ class TransactionsTest(Test):
                                            topic=topic,
                                            message_validator=is_int,
                                            max_messages=num_seed_messages,
-                                           enable_idempotence=True)
+                                           enable_idempotence=True,
+                                           repeating_keys=self.num_input_partitions)
         seed_producer.start()
         wait_until(lambda: seed_producer.num_acked >= num_seed_messages,
                    timeout_sec=seed_timeout_sec,
-                   err_msg="Producer failed to produce messages %d in  %ds." %\
-                   (self.num_seed_messages, seed_timeout_sec))
-        return seed_producer.acked
+                   err_msg="Producer failed to produce messages %d in  %ds." % \
+                           (self.num_seed_messages, seed_timeout_sec))
+        return seed_producer.acked_by_partition
 
     def get_messages_from_topic(self, topic, num_messages):
         consumer = self.start_consumer(topic, group_id="verifying_consumer")
         return self.drain_consumer(consumer, num_messages)
 
     def bounce_brokers(self, clean_shutdown):
-       for node in self.kafka.nodes:
+        for node in self.kafka.nodes:
             if clean_shutdown:
                 self.kafka.restart_node(node, clean_shutdown = True)
             else:
@@ -89,7 +94,7 @@ class TransactionsTest(Test):
                            hard-killed broker %s" % str(node.account))
                 self.kafka.start_node(node)
 
-    def create_and_start_message_copier(self, input_topic, input_partition, output_topic, transactional_id):
+    def create_and_start_message_copier(self, input_topic, output_topic, transactional_id):
         message_copier = TransactionalMessageCopier(
             context=self.test_context,
             num_nodes=1,
@@ -97,10 +102,13 @@ class TransactionsTest(Test):
             transactional_id=transactional_id,
             consumer_group=self.consumer_group,
             input_topic=input_topic,
-            input_partition=input_partition,
+            input_partition=-1,
             output_topic=output_topic,
             max_messages=-1,
-            transaction_size=self.transaction_size
+            transaction_size=self.transaction_size,
+            transaction_timeout=self.transaction_timeout,
+            use_group_metadata=True,
+            group_mode=True
         )
         message_copier.start()
         wait_until(lambda: message_copier.alive(message_copier.nodes[0]),
@@ -108,13 +116,13 @@ class TransactionsTest(Test):
                    err_msg="Message copier failed to start after 10 s")
         return message_copier
 
-    def bounce_copiers(self, copiers, clean_shutdown):
+    def bounce_copiers(self, copiers, clean_shutdown, timeout_sec=240):
         for _ in range(3):
             for copier in copiers:
                 wait_until(lambda: copier.progress_percent() >= 20.0,
-                           timeout_sec=30,
-                           err_msg="%s : Message copier didn't make enough progress in 30s. Current progress: %s" \
-                           % (copier.transactional_id, str(copier.progress_percent())))
+                           timeout_sec=timeout_sec,
+                           err_msg="%s : Message copier didn't make enough progress in %ds. Current progress: %s" \
+                                   % (copier.transactional_id, timeout_sec, str(copier.progress_percent())))
                 self.logger.info("%s - progress: %s" % (copier.transactional_id,
                                                         str(copier.progress_percent())))
                 copier.restart(clean_shutdown)
@@ -125,28 +133,54 @@ class TransactionsTest(Test):
             copiers.append(self.create_and_start_message_copier(
                 input_topic=input_topic,
                 output_topic=output_topic,
-                input_partition=i,
                 transactional_id="copier-" + str(i)
             ))
         return copiers
 
+    @staticmethod
+    def valid_value_and_partition(msg):
+        """Parse a consumed "value<TAB>partition" line into [value, partition].
+
+        Used as the ConsoleConsumer message validator (print_partition=True). Raises an
+        Exception carrying the raw message when it is not two tab-separated integers.
+        """
+        try:
+            fields = msg.split('\t')
+            return [int(fields[0]), int(fields[1])]
+
+        except (ValueError, IndexError):
+            # IndexError: no tab in the line, so the partition field is missing entirely.
+            raise Exception("Unexpected message format (expected a tab separated [value, partition] tuple). Message: %s" % (msg))
+
     def start_consumer(self, topic_to_read, group_id):
         consumer = ConsoleConsumer(context=self.test_context,
                                    num_nodes=1,
                                    kafka=self.kafka,
                                    topic=topic_to_read,
                                    group_id=group_id,
-                                   message_validator=is_int,
+                                   message_validator=self.valid_value_and_partition,
                                    from_beginning=True,
+                                   print_partition=True,
                                    isolation_level="read_committed")
         consumer.start()
         # ensure that the consumer is up.
         wait_until(lambda: (len(consumer.messages_consumed[1]) > 0) == True,
                    timeout_sec=60,
-                   err_msg="Consumer failed to consume any messages for %ds" %\
-                   60)
+                   err_msg="Consumer failed to consume any messages for %ds" % \
+                           60)
         return consumer
 
+    @staticmethod
+    def split_by_partition(messages_consumed):
+        """Group consumed [value, partition] pairs by partition.
+
+        Returns {partition: [value, ...]}; each list keeps the consumption order.
+        """
+        grouped = {}
+        for msg in messages_consumed:
+            grouped.setdefault(msg[1], []).append(msg[0])
+        return grouped
+
     def drain_consumer(self, consumer, num_messages):
         # wait until we read at least the expected number of messages.
         # This is a safe check because both failure modes will be caught:
@@ -157,10 +191,10 @@ class TransactionsTest(Test):
         #     test to fail.
         wait_until(lambda: len(consumer.messages_consumed[1]) >= num_messages,
                    timeout_sec=90,
-                   err_msg="Consumer consumed only %d out of %d messages in %ds" %\
-                   (len(consumer.messages_consumed[1]), num_messages, 90))
+                   err_msg="Consumer consumed only %d out of %d messages in %ds" % \
+                           (len(consumer.messages_consumed[1]), num_messages, 90))
         consumer.stop()
-        return consumer.messages_consumed[1]
+        return self.split_by_partition(consumer.messages_consumed[1])
 
     def copy_messages_transactionally(self, failure_mode, bounce_target,
                                       input_topic, output_topic,
@@ -188,11 +222,12 @@ class TransactionsTest(Test):
         elif bounce_target == "clients":
             self.bounce_copiers(copiers, clean_shutdown)
 
+        copier_timeout_sec = 240
         for copier in copiers:
             wait_until(lambda: copier.is_done,
-                       timeout_sec=120,
-                       err_msg="%s - Failed to copy all messages in  %ds." %\
-                       (copier.transactional_id, 120))
+                       timeout_sec=copier_timeout_sec,
+                       err_msg="%s - Failed to copy all messages in %ds." % \
+                               (copier.transactional_id, copier_timeout_sec))
         self.logger.info("finished copying messages")
 
         return self.drain_consumer(concurrent_consumer, num_messages_to_copy)
@@ -215,53 +250,63 @@ class TransactionsTest(Test):
             }
         }
 
-    @cluster(num_nodes=9)
+    @cluster(num_nodes=10)
     @matrix(failure_mode=["hard_bounce", "clean_bounce"],
-            bounce_target=["brokers", "clients"],
-            check_order=[True, False])
-    def test_transactions(self, failure_mode, bounce_target, check_order):
+            bounce_target=["brokers", "clients"])
+    def test_transactions(self, failure_mode, bounce_target):
         security_protocol = 'PLAINTEXT'
         self.kafka.security_protocol = security_protocol
         self.kafka.interbroker_security_protocol = security_protocol
         self.kafka.logs["kafka_data_1"]["collect_default"] = True
         self.kafka.logs["kafka_data_2"]["collect_default"] = True
         self.kafka.logs["kafka_operational_logs_debug"]["collect_default"] = True
-        if check_order:
-            # To check ordering, we simply create input and output topics
-            # with a single partition.
-            # We reduce the number of seed messages to copy to account for the fewer output
-            # partitions, and thus lower parallelism. This helps keep the test
-            # time shorter.
-            self.num_seed_messages = self.num_seed_messages / 3
-            self.num_input_partitions = 1
-            self.num_output_partitions = 1

         self.setup_topics()
         self.kafka.start()

-        input_messages = self.seed_messages(self.input_topic, self.num_seed_messages)
-        concurrently_consumed_messages = self.copy_messages_transactionally(
+        input_messages_by_partition = self.seed_messages(self.input_topic, self.num_seed_messages)
+        concurrently_consumed_message_by_partition = self.copy_messages_transactionally(
             failure_mode, bounce_target, input_topic=self.input_topic,
-            output_topic=self.output_topic, num_copiers=self.num_input_partitions,
+            output_topic=self.output_topic, num_copiers=self.num_copiers,
             num_messages_to_copy=self.num_seed_messages)
-        output_messages = self.get_messages_from_topic(self.output_topic, self.num_seed_messages)
+        output_messages_by_partition = self.get_messages_from_topic(self.output_topic, self.num_seed_messages)

-        concurrently_consumed_message_set = set(concurrently_consumed_messages)
-        output_message_set = set(output_messages)
-        input_message_set = set(input_messages)
+        assert len(input_messages_by_partition) == \
+               len(concurrently_consumed_message_by_partition), "The lengths of partition count doesn't match: " \
+                                                                "input partitions count %d, " \
+                                                                "concurrently consumed partitions count %d" % \
+                                                                (len(input_messages_by_partition), len(concurrently_consumed_message_by_partition))

-        num_dups = abs(len(output_messages) - len(output_message_set))
-        num_dups_in_concurrent_consumer = abs(len(concurrently_consumed_messages)
+        assert len(input_messages_by_partition) == \
+               len(output_messages_by_partition), "The lengths of partition count doesn't match: " \
+                                                  "input partitions count %d, " \
+                                                  "output partitions count %d" % \
+                                                  (len(input_messages_by_partition), len(output_messages_by_partition))
+
+        for p in range(self.num_input_partitions):  # ordering is only guaranteed within a partition, so compare per partition
+            if p not in input_messages_by_partition:
+                continue
+
+            assert p in output_messages_by_partition, "Partition %d not in output messages" % p
+            assert p in concurrently_consumed_message_by_partition, "Partition %d not in concurrently consumed messages" % p
+
+            output_message_set = set(output_messages_by_partition[p])
+            input_message_set = set(input_messages_by_partition[p])
+
+            concurrently_consumed_message_set = set(concurrently_consumed_message_by_partition[p])
+
+            num_dups = abs(len(output_messages_by_partition[p]) - len(output_message_set))
+            num_dups_in_concurrent_consumer = abs(len(concurrently_consumed_message_by_partition[p])
                                               - len(concurrently_consumed_message_set))
-        assert num_dups == 0, "Detected %d duplicates in the output stream" % num_dups
-        assert input_message_set == output_message_set, "Input and output message sets are not equal. Num input messages %d. Num output messages %d" %\
-            (len(input_message_set), len(output_message_set))
-
-        assert num_dups_in_concurrent_consumer == 0, "Detected %d dups in concurrently consumed messages" % num_dups_in_concurrent_consumer
-        assert input_message_set == concurrently_consumed_message_set, \
-            "Input and concurrently consumed output message sets are not equal. Num input messages: %d. Num concurrently_consumed_messages: %d" %\
-            (len(input_message_set), len(concurrently_consumed_message_set))
-        if check_order:
+            assert num_dups == 0, "Detected %d duplicates in the output stream" % num_dups
+            assert input_message_set == output_message_set, "Input and output message sets are not equal. Num input messages %d. Num output messages %d" % \
+                                                        (len(input_message_set), len(output_message_set))
+
+            assert num_dups_in_concurrent_consumer == 0, "Detected %d dups in concurrently consumed messages" % num_dups_in_concurrent_consumer
+            assert input_message_set == concurrently_consumed_message_set, \
+                "Input and concurrently consumed output message sets are not equal. Num input messages: %d. Num concurrently_consumed_messages: %d" % \
+                (len(input_message_set), len(concurrently_consumed_message_set))
+
-            assert input_messages == sorted(input_messages), "The seed messages themselves were not in order"
-            assert output_messages == input_messages, "Output messages are not in order"
-            assert concurrently_consumed_messages == output_messages, "Concurrently consumed messages are not in order"
+            assert input_messages_by_partition[p] == sorted(input_messages_by_partition[p]), "The seed messages themselves were not in order"
+            assert output_messages_by_partition[p] == input_messages_by_partition[p], "Output messages are not in order"
+            assert concurrently_consumed_message_by_partition[p] == output_messages_by_partition[p], "Concurrently consumed messages are not in order"
diff --git a/tests/kafkatest/tests/core/transactions_test.py b/tests/kafkatest/tests/core/transactions_test.py
index 4da5960..2889f84 100644
--- a/tests/kafkatest/tests/core/transactions_test.py
+++ b/tests/kafkatest/tests/core/transactions_test.py
@@ -31,7 +31,7 @@ class TransactionsTest(Test):
     a destination topic and killing the copy process as well as the broker
     randomly through the process. In the end we verify that the final output
     topic contains exactly one committed copy of each message in the input
-    topic
+    topic.
     """
     def __init__(self, test_context):
         """:type test_context: ducktape.tests.test.TestContext"""
@@ -47,6 +47,7 @@ class TransactionsTest(Test):
         self.num_output_partitions = 3
         self.num_seed_messages = 100000
         self.transaction_size = 750
+        self.transaction_timeout = 10000
         self.consumer_group = "transactions-test-consumer-group"
 
         self.zk = ZookeeperService(test_context, num_nodes=1)
@@ -69,7 +70,7 @@ class TransactionsTest(Test):
         seed_producer.start()
         wait_until(lambda: seed_producer.num_acked >= num_seed_messages,
                    timeout_sec=seed_timeout_sec,
-                   err_msg="Producer failed to produce messages %d in  %ds." %\
+                   err_msg="Producer failed to produce messages %d in %ds." %\
                    (self.num_seed_messages, seed_timeout_sec))
         return seed_producer.acked
 
@@ -89,7 +90,7 @@ class TransactionsTest(Test):
                            hard-killed broker %s" % str(node.account))
                 self.kafka.start_node(node)
 
-    def create_and_start_message_copier(self, input_topic, input_partition, output_topic, transactional_id):
+    def create_and_start_message_copier(self, input_topic, input_partition, output_topic, transactional_id, use_group_metadata):
         message_copier = TransactionalMessageCopier(
             context=self.test_context,
             num_nodes=1,
@@ -100,7 +101,9 @@ class TransactionsTest(Test):
             input_partition=input_partition,
             output_topic=output_topic,
             max_messages=-1,
-            transaction_size=self.transaction_size
+            transaction_size=self.transaction_size,
+            transaction_timeout=self.transaction_timeout,
+            use_group_metadata=use_group_metadata
         )
         message_copier.start()
         wait_until(lambda: message_copier.alive(message_copier.nodes[0]),
@@ -119,14 +122,15 @@ class TransactionsTest(Test):
                                                         str(copier.progress_percent())))
                 copier.restart(clean_shutdown)
 
-    def create_and_start_copiers(self, input_topic, output_topic, num_copiers):
+    def create_and_start_copiers(self, input_topic, output_topic, num_copiers, use_group_metadata):
         copiers = []
         for i in range(0, num_copiers):
             copiers.append(self.create_and_start_message_copier(
                 input_topic=input_topic,
                 output_topic=output_topic,
                 input_partition=i,
-                transactional_id="copier-" + str(i)
+                transactional_id="copier-" + str(i),
+                use_group_metadata=use_group_metadata
             ))
         return copiers
 
@@ -164,7 +168,8 @@ class TransactionsTest(Test):
 
     def copy_messages_transactionally(self, failure_mode, bounce_target,
                                       input_topic, output_topic,
-                                      num_copiers, num_messages_to_copy):
+                                      num_copiers, num_messages_to_copy,
+                                      use_group_metadata):
         """Copies messages transactionally from the seeded input topic to the
         output topic, either bouncing brokers or clients in a hard and soft
         way as it goes.
@@ -176,7 +181,8 @@ class TransactionsTest(Test):
         """
         copiers = self.create_and_start_copiers(input_topic=input_topic,
                                                 output_topic=output_topic,
-                                                num_copiers=num_copiers)
+                                                num_copiers=num_copiers,
+                                                use_group_metadata=use_group_metadata)
         concurrent_consumer = self.start_consumer(output_topic,
                                                   group_id="concurrent_consumer")
         clean_shutdown = False
@@ -188,11 +194,12 @@ class TransactionsTest(Test):
         elif bounce_target == "clients":
             self.bounce_copiers(copiers, clean_shutdown)
 
+        copier_timeout_sec = 120
         for copier in copiers:
             wait_until(lambda: copier.is_done,
-                       timeout_sec=120,
+                       timeout_sec=copier_timeout_sec,
                        err_msg="%s - Failed to copy all messages in  %ds." %\
-                       (copier.transactional_id, 120))
+                       (copier.transactional_id, copier_timeout_sec))
         self.logger.info("finished copying messages")
 
         return self.drain_consumer(concurrent_consumer, num_messages_to_copy)
@@ -218,8 +225,9 @@ class TransactionsTest(Test):
     @cluster(num_nodes=9)
     @matrix(failure_mode=["hard_bounce", "clean_bounce"],
             bounce_target=["brokers", "clients"],
-            check_order=[True, False])
-    def test_transactions(self, failure_mode, bounce_target, check_order):
+            check_order=[True, False],
+            use_group_metadata=[True, False])
+    def test_transactions(self, failure_mode, bounce_target, check_order, use_group_metadata):
         security_protocol = 'PLAINTEXT'
         self.kafka.security_protocol = security_protocol
         self.kafka.interbroker_security_protocol = security_protocol
@@ -243,7 +251,7 @@ class TransactionsTest(Test):
         concurrently_consumed_messages = self.copy_messages_transactionally(
             failure_mode, bounce_target, input_topic=self.input_topic,
             output_topic=self.output_topic, num_copiers=self.num_input_partitions,
-            num_messages_to_copy=self.num_seed_messages)
+            num_messages_to_copy=self.num_seed_messages, use_group_metadata=use_group_metadata)
         output_messages = self.get_messages_from_topic(self.output_topic, self.num_seed_messages)
 
         concurrently_consumed_message_set = set(concurrently_consumed_messages)
diff --git a/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java b/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java
index 1ea826d..13c18d0 100644
--- a/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java
+++ b/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java
@@ -22,6 +22,7 @@ import net.sourceforge.argparse4j.ArgumentParsers;
 import net.sourceforge.argparse4j.inf.ArgumentParser;
 import net.sourceforge.argparse4j.inf.Namespace;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
@@ -35,8 +36,12 @@ import org.apache.kafka.common.errors.OutOfOrderSequenceException;
 import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.utils.Exit;
 
-import java.io.IOException;
+import java.text.DateFormat;
+import java.text.SimpleDateFormat;
 import java.time.Duration;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Properties;
@@ -54,6 +59,8 @@ import static net.sourceforge.argparse4j.impl.Arguments.storeTrue;
  */
 public class TransactionalMessageCopier {
 
+    private static final DateFormat FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss:SSS");
+
     /** Get the command-line argument parser. */
     private static ArgumentParser argParser() {
         ArgumentParser parser = ArgumentParsers
@@ -122,6 +129,15 @@ public class TransactionalMessageCopier {
                 .dest("messagesPerTransaction")
                 .help("The number of messages to put in each transaction. Default is 200.");
 
+        parser.addArgument("--transaction-timeout")
+                .action(store())
+                .required(false)
+                .setDefault(60000)
+                .type(Integer.class)
+                .metavar("TRANSACTION-TIMEOUT")
+                .dest("transactionTimeout")
+                .help("The transaction timeout in milliseconds. Default is 60000(1 minute).");
+
         parser.addArgument("--transactional-id")
                 .action(store())
                 .required(true)
@@ -137,16 +153,28 @@ public class TransactionalMessageCopier {
                 .dest("enableRandomAborts")
                 .help("Whether or not to enable random transaction aborts (for system testing)");
 
+        parser.addArgument("--group-mode")
+                .action(storeTrue())
+                .type(Boolean.class)
+                .metavar("GROUP-MODE")
+                .dest("groupMode")
+                .help("Whether to let consumer subscribe to the input topic or do manual assign. If we do" +
+                          " subscription based consumption, the input partition shall be ignored");
+
+        parser.addArgument("--use-group-metadata")
+                .action(storeTrue())
+                .type(Boolean.class)
+                .metavar("USE-GROUP-METADATA")
+                .dest("useGroupMetadata")
+                .help("Whether to use the new transactional commit API with group metadata");
+
         return parser;
     }
 
     private static KafkaProducer<String, String> createProducer(Namespace parsedArgs) {
-        String transactionalId = parsedArgs.getString("transactionalId");
-        String brokerList = parsedArgs.getString("brokerList");
-
         Properties props = new Properties();
-        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList);
-        props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId);
+        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, parsedArgs.getString("brokerList"));
+        props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, parsedArgs.getString("transactionalId"));
         props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG,
                 "org.apache.kafka.common.serialization.StringSerializer");
         props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
@@ -156,6 +184,7 @@ public class TransactionalMessageCopier {
         // the case with multiple inflights.
         props.put(ProducerConfig.BATCH_SIZE_CONFIG, "512");
         props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "5");
+        props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, parsedArgs.getInt("transactionTimeout"));
 
         return new KafkaProducer<>(props);
     }
@@ -173,6 +202,7 @@ public class TransactionalMessageCopier {
         props.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, numMessagesPerTransaction.toString());
         props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false");
         props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "10000");
+        props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "180000");
         props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000");
         props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
         props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
@@ -184,7 +214,7 @@ public class TransactionalMessageCopier {
     }
 
     private static ProducerRecord<String, String> producerRecordFromConsumerRecord(String topic, ConsumerRecord<String, String> record) {
-        return new ProducerRecord<>(topic, record.key(), record.value());
+        return new ProducerRecord<>(topic, record.partition(), record.key(), record.value());
     }
 
     private static Map<TopicPartition, OffsetAndMetadata> consumerPositions(KafkaConsumer<String, String> consumer) {
@@ -226,45 +256,72 @@ public class TransactionalMessageCopier {
         return json;
     }
 
-    private static String statusAsJson(long consumed, long remaining, String transactionalId) {
+    private static synchronized String statusAsJson(long totalProcessed, long consumedSinceLastRebalanced, long remaining, String transactionalId, String stage) {
         Map<String, Object> statusData = new HashMap<>();
         statusData.put("progress", transactionalId);
-        statusData.put("consumed", consumed);
+        statusData.put("totalProcessed", totalProcessed);
+        statusData.put("consumed", consumedSinceLastRebalanced);
         statusData.put("remaining", remaining);
+        statusData.put("time", FORMAT.format(new Date()));
+        statusData.put("stage", stage);
         return toJsonString(statusData);
     }
 
-    private static String shutDownString(long consumed, long remaining, String transactionalId) {
+    private static synchronized String shutDownString(long totalProcessed, long consumedSinceLastRebalanced, long remaining, String transactionalId) {
         Map<String, Object> shutdownData = new HashMap<>();
-        shutdownData.put("remaining", remaining);
-        shutdownData.put("consumed", consumed);
         shutdownData.put("shutdown_complete", transactionalId);
+        shutdownData.put("totalProcessed", totalProcessed);
+        shutdownData.put("consumed", consumedSinceLastRebalanced);
+        shutdownData.put("remaining", remaining);
+        shutdownData.put("time", FORMAT.format(new Date()));
         return toJsonString(shutdownData);
     }
 
-    public static void main(String[] args) throws IOException {
+    public static void main(String[] args) {
         Namespace parsedArgs = argParser().parseArgsOrFail(args);
-        Integer numMessagesPerTransaction = parsedArgs.getInt("messagesPerTransaction");
         final String transactionalId = parsedArgs.getString("transactionalId");
         final String outputTopic = parsedArgs.getString("outputTopic");
 
         String consumerGroup = parsedArgs.getString("consumerGroup");
-        TopicPartition inputPartition = new TopicPartition(parsedArgs.getString("inputTopic"), parsedArgs.getInt("inputPartition"));
 
         final KafkaProducer<String, String> producer = createProducer(parsedArgs);
         final KafkaConsumer<String, String> consumer = createConsumer(parsedArgs);
 
-        consumer.assign(singleton(inputPartition));
+        final AtomicLong remainingMessages = new AtomicLong(
+            parsedArgs.getInt("maxMessages") == -1 ? Long.MAX_VALUE : parsedArgs.getInt("maxMessages"));
+
+        boolean groupMode = parsedArgs.getBoolean("groupMode");
+        String topicName = parsedArgs.getString("inputTopic");
+        final AtomicLong numMessagesProcessedSinceLastRebalance = new AtomicLong(0);
+        final AtomicLong totalMessageProcessed = new AtomicLong(0);
+        if (groupMode) {
+            consumer.subscribe(Collections.singleton(topicName), new ConsumerRebalanceListener() {
+                @Override
+                public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+                }
+
+                @Override
+                public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+                    remainingMessages.set(partitions.stream()
+                        .mapToLong(partition -> messagesRemaining(consumer, partition)).sum());
+                    numMessagesProcessedSinceLastRebalance.set(0);
+                    // We use message cap for remaining here as the remainingMessages are not set yet.
+                    System.out.println(statusAsJson(totalMessageProcessed.get(),
+                        numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId, "RebalanceComplete"));
+                }
+            });
+        } else {
+            TopicPartition inputPartition = new TopicPartition(topicName, parsedArgs.getInt("inputPartition"));
+            consumer.assign(singleton(inputPartition));
+            remainingMessages.set(Math.min(messagesRemaining(consumer, inputPartition), remainingMessages.get()));
+        }
 
-        long maxMessages = parsedArgs.getInt("maxMessages") == -1 ? Long.MAX_VALUE : parsedArgs.getInt("maxMessages");
-        maxMessages = Math.min(messagesRemaining(consumer, inputPartition), maxMessages);
         final boolean enableRandomAborts = parsedArgs.getBoolean("enableRandomAborts");
 
         producer.initTransactions();
 
         final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
-        final AtomicLong remainingMessages = new AtomicLong(maxMessages);
-        final AtomicLong numMessagesProcessed = new AtomicLong(0);
+
         Exit.addShutdownHook("transactional-message-copier-shutdown-hook", () -> {
             isShuttingDown.set(true);
             // Flush any remaining messages
@@ -272,41 +329,51 @@ public class TransactionalMessageCopier {
             synchronized (consumer) {
                 consumer.close();
             }
-            System.out.println(shutDownString(numMessagesProcessed.get(), remainingMessages.get(), transactionalId));
+            System.out.println(shutDownString(totalMessageProcessed.get(),
+                numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId));
         });
 
+        final boolean useGroupMetadata = parsedArgs.getBoolean("useGroupMetadata");
         try {
             Random random = new Random();
-            while (0 < remainingMessages.get()) {
-                System.out.println(statusAsJson(numMessagesProcessed.get(), remainingMessages.get(), transactionalId));
+            while (remainingMessages.get() > 0) {
+                System.out.println(statusAsJson(totalMessageProcessed.get(),
+                    numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId, "ProcessLoop"));
                 if (isShuttingDown.get())
                     break;
-                int messagesInCurrentTransaction = 0;
-                long numMessagesForNextTransaction = Math.min(numMessagesPerTransaction, remainingMessages.get());
 
-                try {
-                    producer.beginTransaction();
-                    while (messagesInCurrentTransaction < numMessagesForNextTransaction) {
-                        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(200));
+                ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(200));
+                if (records.count() > 0) {
+                    try {
+                        producer.beginTransaction();
+
                         for (ConsumerRecord<String, String> record : records) {
                             producer.send(producerRecordFromConsumerRecord(outputTopic, record));
-                            messagesInCurrentTransaction++;
                         }
-                    }
-                    producer.sendOffsetsToTransaction(consumerPositions(consumer), consumerGroup);
 
-                    if (enableRandomAborts && random.nextInt() % 3 == 0) {
-                        throw new KafkaException("Aborting transaction");
-                    } else {
-                        producer.commitTransaction();
-                        remainingMessages.set(maxMessages - numMessagesProcessed.addAndGet(messagesInCurrentTransaction));
+                        long messagesSentWithinCurrentTxn = records.count();
+
+                        if (useGroupMetadata) {
+                            producer.sendOffsetsToTransaction(consumerPositions(consumer), consumer.groupMetadata());
+                        } else {
+                            producer.sendOffsetsToTransaction(consumerPositions(consumer), consumerGroup);
+                        }
+
+                        if (enableRandomAborts && random.nextInt() % 3 == 0) {
+                            throw new KafkaException("Aborting transaction");
+                        } else {
+                            producer.commitTransaction();
+                            remainingMessages.getAndAdd(-messagesSentWithinCurrentTxn);
+                            numMessagesProcessedSinceLastRebalance.getAndAdd(messagesSentWithinCurrentTxn);
+                            totalMessageProcessed.getAndAdd(messagesSentWithinCurrentTxn);
+                        }
+                    } catch (ProducerFencedException | OutOfOrderSequenceException e) {
+                        // We cannot recover from these errors, so just rethrow them and let the process fail
+                        throw e;
+                    } catch (KafkaException e) {
+                        producer.abortTransaction();
+                        resetToLastCommittedPositions(consumer);
                     }
-                } catch (ProducerFencedException | OutOfOrderSequenceException e) {
-                    // We cannot recover from these errors, so just rethrow them and let the process fail
-                    throw e;
-                } catch (KafkaException e) {
-                    producer.abortTransaction();
-                    resetToLastCommittedPositions(consumer);
                 }
             }
         } finally {
diff --git a/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java b/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java
index 9cad90f..f34b9e2 100644
--- a/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java
+++ b/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java
@@ -159,8 +159,9 @@ public class VerifiableConsumer implements Closeable, OffsetCommitCallback, Cons
                     partitionRecords.size(), minOffset, maxOffset));
 
             if (verbose) {
-                for (ConsumerRecord<String, String> record : partitionRecords)
+                for (ConsumerRecord<String, String> record : partitionRecords) {
                     printJson(new RecordData(record));
+                }
             }
 
             consumedMessages += partitionRecords.size();
@@ -595,10 +596,7 @@ public class VerifiableConsumer implements Closeable, OffsetCommitCallback, Cons
     public static VerifiableConsumer createFromArgs(ArgumentParser parser, String[] args) throws ArgumentParserException {
         Namespace res = parser.parseArgs(args);
 
-        String topic = res.getString("topic");
         boolean useAutoCommit = res.getBoolean("useAutoCommit");
-        int maxMessages = res.getInt("maxMessages");
-        boolean verbose = res.getBoolean("verbose");
         String configFile = res.getString("consumer.config");
 
         Properties consumerProps = new Properties();
@@ -625,6 +623,10 @@ public class VerifiableConsumer implements Closeable, OffsetCommitCallback, Cons
         StringDeserializer deserializer = new StringDeserializer();
         KafkaConsumer<String, String> consumer = new KafkaConsumer<>(consumerProps, deserializer, deserializer);
 
+        String topic = res.getString("topic");
+        int maxMessages = res.getInt("maxMessages");
+        boolean verbose = res.getBoolean("verbose");
+
         return new VerifiableConsumer(
                 consumer,
                 System.out,


Mime
View raw message