kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [1/4] kafka git commit: KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)
Date Wed, 17 Aug 2016 18:50:08 GMT
Repository: kafka
Updated Branches:
  refs/heads/trunk 19997ede0 -> 40b1dd3f4


http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/MemberMetadata.scala b/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
index 19c9e8e..6149276 100644
--- a/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
@@ -55,6 +55,7 @@ private[coordinator] class MemberMetadata(val memberId: String,
                                           val groupId: String,
                                           val clientId: String,
                                           val clientHost: String,
+                                          val rebalanceTimeoutMs: Int,
                                           val sessionTimeoutMs: Int,
                                           val protocolType: String,
                                           var supportedProtocols: List[(String, Array[Byte])]) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/server/KafkaApis.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 6d38f85..bb219ca 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -890,8 +890,8 @@ class KafkaApis(val requestChannel: RequestChannel,
     // the callback for sending a join-group response
     def sendResponseCallback(joinResult: JoinGroupResult) {
       val members = joinResult.members map { case (memberId, metadataArray) => (memberId, ByteBuffer.wrap(metadataArray)) }
-      val responseBody = new JoinGroupResponse(joinResult.errorCode, joinResult.generationId, joinResult.subProtocol,
-        joinResult.memberId, joinResult.leaderId, members)
+      val responseBody = new JoinGroupResponse(request.header.apiVersion, joinResult.errorCode, joinResult.generationId,
+        joinResult.subProtocol, joinResult.memberId, joinResult.leaderId, members)
 
       trace("Sending join group response %s for correlation id %d to client %s."
         .format(responseBody, request.header.correlationId, request.header.clientId))
@@ -900,6 +900,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     if (!authorize(request.session, Read, new Resource(Group, joinGroupRequest.groupId()))) {
       val responseBody = new JoinGroupResponse(
+        request.header.apiVersion,
         Errors.GROUP_AUTHORIZATION_FAILED.code,
         JoinGroupResponse.UNKNOWN_GENERATION_ID,
         JoinGroupResponse.UNKNOWN_PROTOCOL,
@@ -916,6 +917,7 @@ class KafkaApis(val requestChannel: RequestChannel,
         joinGroupRequest.memberId,
         request.header.clientId,
         request.session.clientAddress.toString,
+        joinGroupRequest.rebalanceTimeout,
         joinGroupRequest.sessionTimeout,
         joinGroupRequest.protocolType,
         protocols,

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index 817cdf7..1a5f187 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -199,7 +199,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
   }
 
   private def createJoinGroupRequest = {
-    new JoinGroupRequest(group, 30000, "", "consumer",
+    new JoinGroupRequest(group, 10000, 60000, "", "consumer",
       List( new JoinGroupRequest.ProtocolMetadata("consumer-range",ByteBuffer.wrap("test".getBytes()))).asJava)
   }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
index f039750..c13bf58 100644
--- a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
@@ -13,6 +13,7 @@
 package kafka.api
 
 import java.util
+
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.record.TimestampType
@@ -22,11 +23,14 @@ import kafka.utils.{TestUtils, Logging, ShutdownableThread}
 import kafka.common.Topic
 import kafka.server.KafkaConfig
 import java.util.ArrayList
+
 import org.junit.Assert._
 import org.junit.{Before, Test}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.Buffer
 import org.apache.kafka.clients.producer.KafkaProducer
+import org.apache.kafka.common.errors.WakeupException
 
 /**
  * Integration tests for the new consumer that cover basic usage as well as server failures
@@ -82,112 +86,19 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
   }
 
   @Test
-  def testAutoCommitOnRebalance() {
-    val topic2 = "topic2"
-    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
-
-    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    val numRecords = 10000
-    sendRecords(numRecords)
-
-    val rebalanceListener = new ConsumerRebalanceListener {
-      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
-        // keep partitions paused in this test so that we can verify the commits based on specific seeks
-        consumer0.pause(partitions)
-      }
-
-      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {}
-    }
-
-    consumer0.subscribe(List(topic).asJava, rebalanceListener)
-
-    val assignment = Set(tp, tp2)
-    TestUtils.waitUntilTrue(() => {
-      consumer0.poll(50)
-      consumer0.assignment() == assignment.asJava
-    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer0.assignment()}")
-
-    consumer0.seek(tp, 300)
-    consumer0.seek(tp2, 500)
-
-    // change subscription to trigger rebalance
-    consumer0.subscribe(List(topic, topic2).asJava, rebalanceListener)
-
-    val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))
-    TestUtils.waitUntilTrue(() => {
-      val records = consumer0.poll(50)
-      consumer0.assignment() == newAssignment.asJava
-    }, s"Expected partitions ${newAssignment.asJava} but actually got ${consumer0.assignment()}")
-
-    // after rebalancing, we should have reset to the committed positions
-    assertEquals(300, consumer0.committed(tp).offset)
-    assertEquals(500, consumer0.committed(tp2).offset)
-  }
-
-  @Test
-  def testCommitSpecifiedOffsets() {
-    sendRecords(5, tp)
-    sendRecords(7, tp2)
-
-    this.consumers.head.assign(List(tp, tp2).asJava)
-
-    // Need to poll to join the group
-    this.consumers.head.poll(50)
-    val pos1 = this.consumers.head.position(tp)
-    val pos2 = this.consumers.head.position(tp2)
-    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava)
-    assertEquals(3, this.consumers.head.committed(tp).offset)
-    assertNull(this.consumers.head.committed(tp2))
-
-    // Positions should not change
-    assertEquals(pos1, this.consumers.head.position(tp))
-    assertEquals(pos2, this.consumers.head.position(tp2))
-    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(5L))).asJava)
-    assertEquals(3, this.consumers.head.committed(tp).offset)
-    assertEquals(5, this.consumers.head.committed(tp2).offset)
-
-    // Using async should pick up the committed changes after commit completes
-    val commitCallback = new CountConsumerCommitCallback()
-    this.consumers.head.commitAsync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(7L))).asJava, commitCallback)
-    awaitCommitCallback(this.consumers.head, commitCallback)
-    assertEquals(7, this.consumers.head.committed(tp2).offset)
-  }
-
-  @Test
-  def testListTopics() {
-    val numParts = 2
-    val topic1 = "part-test-topic-1"
-    val topic2 = "part-test-topic-2"
-    val topic3 = "part-test-topic-3"
-    TestUtils.createTopic(this.zkUtils, topic1, numParts, 1, this.servers)
-    TestUtils.createTopic(this.zkUtils, topic2, numParts, 1, this.servers)
-    TestUtils.createTopic(this.zkUtils, topic3, numParts, 1, this.servers)
-
-    val topics = this.consumers.head.listTopics()
-    assertNotNull(topics)
-    assertEquals(5, topics.size())
-    assertEquals(5, topics.keySet().size())
-    assertEquals(2, topics.get(topic1).size)
-    assertEquals(2, topics.get(topic2).size)
-    assertEquals(2, topics.get(topic3).size)
-  }
-
-  @Test
-  def testPartitionReassignmentCallback() {
+  def testCoordinatorFailover() {
     val listener = new TestConsumerReassignmentListener()
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "5000")
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "2000")
     val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
     consumers += consumer0
 
     consumer0.subscribe(List(topic).asJava, listener)
 
     // the initial subscription should cause a callback execution
-    while (listener.callsToAssigned == 0)
-      consumer0.poll(50)
+    consumer0.poll(2000)
+
+    assertEquals(1, listener.callsToAssigned)
 
     // get metadata for the topic
     var parts: Seq[PartitionInfo] = null
@@ -200,54 +111,13 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
     val coordinator = parts.head.leader().id()
     this.servers(coordinator).shutdown()
 
-    // this should cause another callback execution
-    while (listener.callsToAssigned < 2)
-      consumer0.poll(50)
-
-    assertEquals(2, listener.callsToAssigned)
-
-    // only expect one revocation since revoke is not invoked on initial membership
-    assertEquals(2, listener.callsToRevoked)
-  }
-
-  @Test
-  def testUnsubscribeTopic() {
-
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    val listener = new TestConsumerReassignmentListener()
-    consumer0.subscribe(List(topic).asJava, listener)
-
-    // the initial subscription should cause a callback execution
-    while (listener.callsToAssigned == 0)
-      consumer0.poll(50)
+    consumer0.poll(5000)
 
-    consumer0.subscribe(List[String]().asJava)
-    assertEquals(0, consumer0.assignment.size())
+    // the failover should not cause a rebalance
+    assertEquals(1, listener.callsToAssigned)
+    assertEquals(1, listener.callsToRevoked)
   }
 
-  @Test
-  def testPauseStateNotPreservedByRebalance() {
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    sendRecords(5)
-    consumer0.subscribe(List(topic).asJava)
-    consumeAndVerifyRecords(consumer = consumer0, numRecords = 5, startingOffset = 0)
-    consumer0.pause(List(tp).asJava)
-
-    // subscribe to a new topic to trigger a rebalance
-    consumer0.subscribe(List("topic2").asJava)
-
-    // after rebalance, our position should be reset and our pause state lost,
-    // so we should be able to consume from the beginning
-    consumeAndVerifyRecords(consumer = consumer0, numRecords = 0, startingOffset = 5)
-  }
 
   protected class TestConsumerReassignmentListener extends ConsumerRebalanceListener {
     var callsToAssigned = 0
@@ -394,12 +264,22 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
       !subscriptionChanged
     }
 
+    override def initiateShutdown(): Boolean = {
+      val res = super.initiateShutdown()
+      consumer.wakeup()
+      res
+    }
+
     override def doWork(): Unit = {
       if (subscriptionChanged) {
         consumer.subscribe(topicsSubscription.asJava, rebalanceListener)
         subscriptionChanged = false
       }
-      consumer.poll(50)
+      try {
+        consumer.poll(50)
+      } catch {
+        case e: WakeupException => // ignore for shutdown
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
index 7064052..0900d43 100644
--- a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
+++ b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
@@ -49,8 +49,8 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
   this.producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "all")
   this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "my-test")
   this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, 4096.toString)
-  this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100")
-  this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+  this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "10000")
+  this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000")
   this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
 
   override def generateConfigs() = {
@@ -81,14 +81,7 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
     var consumed = 0L
     val consumer = this.consumers.head
 
-    consumer.subscribe(List(topic), new ConsumerRebalanceListener {
-      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) {
-        // TODO: until KAFKA-2017 is merged, we have to handle the case in which
-        // the commit fails on prior to rebalancing on coordinator fail-over.
-        consumer.seek(tp, consumed)
-      }
-      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) {}
-    })
+    consumer.subscribe(List(topic))
 
     val scheduler = new BounceBrokerScheduler(numIters)
     scheduler.start()

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index b1e9676..243f913 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -58,6 +58,31 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   }
 
   @Test
+  def testMaxPollIntervalMs() {
+    this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 3000.toString)
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500.toString)
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 2000.toString)
+
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val listener = new TestConsumerReassignmentListener()
+    consumer0.subscribe(List(topic).asJava, listener)
+
+    // poll once to get the initial assignment
+    consumer0.poll(0)
+    assertEquals(1, listener.callsToAssigned)
+    assertEquals(1, listener.callsToRevoked)
+
+    Thread.sleep(3500)
+
+    // we should fall out of the group and need to rebalance
+    consumer0.poll(0)
+    assertEquals(2, listener.callsToAssigned)
+    assertEquals(2, listener.callsToRevoked)
+  }
+
+  @Test
   def testAutoCommitOnClose() {
     this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
     val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
@@ -593,16 +618,14 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     // create a group of consumers, subscribe the consumers to all the topics and start polling
     // for the topic partition assignment
     val (rrConsumers, consumerPollers) = createConsumerGroupAndWaitForAssignment(10, List(topic1, topic2), subscriptions)
+    try {
+      validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
 
-    // add one more consumer and validate re-assignment
-    addConsumersToGroupAndWaitForGroupAssignment(1, consumers, consumerPollers, List(topic1, topic2), subscriptions)
-
-    // done with pollers and consumers
-    for (poller <- consumerPollers)
-      poller.shutdown()
-
-    for (consumer <- consumers)
-      consumer.unsubscribe()
+      // add one more consumer and validate re-assignment
+      addConsumersToGroupAndWaitForGroupAssignment(1, consumers, consumerPollers, List(topic1, topic2), subscriptions)
+    } finally {
+      consumerPollers.foreach(_.shutdown())
+    }
   }
 
   /**
@@ -618,25 +641,25 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     val subscriptions = Set(tp, tp2) ++ createTopicAndSendRecords(topic1, 5, 100)
 
     // subscribe all consumers to all topics and validate the assignment
-    val consumerPollers = subscribeConsumersAndWaitForAssignment(consumers, List(topic, topic1), subscriptions)
+    val consumerPollers = subscribeConsumers(consumers, List(topic, topic1))
 
-    // add 2 more consumers and validate re-assignment
-    addConsumersToGroupAndWaitForGroupAssignment(2, consumers, consumerPollers, List(topic, topic1), subscriptions)
+    try {
+      validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
 
-    // add one more topic and validate partition re-assignment
-    val topic2 = "topic2"
-    val expandedSubscriptions = subscriptions ++ createTopicAndSendRecords(topic2, 3, 100)
-    changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1, topic2), expandedSubscriptions)
+      // add 2 more consumers and validate re-assignment
+      addConsumersToGroupAndWaitForGroupAssignment(2, consumers, consumerPollers, List(topic, topic1), subscriptions)
 
-    // remove the topic we just added and validate re-assignment
-    changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1), subscriptions)
+      // add one more topic and validate partition re-assignment
+      val topic2 = "topic2"
+      val expandedSubscriptions = subscriptions ++ createTopicAndSendRecords(topic2, 3, 100)
+      changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1, topic2), expandedSubscriptions)
 
-    // done with pollers and consumers
-    for (poller <- consumerPollers)
-      poller.shutdown()
+      // remove the topic we just added and validate re-assignment
+      changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1), subscriptions)
 
-    for (consumer <- consumers)
-      consumer.unsubscribe()
+    } finally {
+      consumerPollers.foreach(_.shutdown())
+    }
   }
 
   @Test
@@ -830,6 +853,138 @@ class PlaintextConsumerTest extends BaseConsumerTest {
       startingTimestamp = startTime, timestampType = TimestampType.LOG_APPEND_TIME)
   }
 
+  @Test
+  def testListTopics() {
+    val numParts = 2
+    val topic1 = "part-test-topic-1"
+    val topic2 = "part-test-topic-2"
+    val topic3 = "part-test-topic-3"
+    TestUtils.createTopic(this.zkUtils, topic1, numParts, 1, this.servers)
+    TestUtils.createTopic(this.zkUtils, topic2, numParts, 1, this.servers)
+    TestUtils.createTopic(this.zkUtils, topic3, numParts, 1, this.servers)
+
+    val topics = this.consumers.head.listTopics()
+    assertNotNull(topics)
+    assertEquals(5, topics.size())
+    assertEquals(5, topics.keySet().size())
+    assertEquals(2, topics.get(topic1).size)
+    assertEquals(2, topics.get(topic2).size)
+    assertEquals(2, topics.get(topic3).size)
+  }
+
+  @Test
+  def testUnsubscribeTopic() {
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val listener = new TestConsumerReassignmentListener()
+    consumer0.subscribe(List(topic).asJava, listener)
+
+    // the initial subscription should cause a callback execution
+    while (listener.callsToAssigned == 0)
+      consumer0.poll(50)
+
+    consumer0.subscribe(List[String]().asJava)
+    assertEquals(0, consumer0.assignment.size())
+  }
+
+  @Test
+  def testPauseStateNotPreservedByRebalance() {
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    sendRecords(5)
+    consumer0.subscribe(List(topic).asJava)
+    consumeAndVerifyRecords(consumer = consumer0, numRecords = 5, startingOffset = 0)
+    consumer0.pause(List(tp).asJava)
+
+    // subscribe to a new topic to trigger a rebalance
+    consumer0.subscribe(List("topic2").asJava)
+
+    // after rebalance, our position should be reset and our pause state lost,
+    // so we should be able to consume from the beginning
+    consumeAndVerifyRecords(consumer = consumer0, numRecords = 0, startingOffset = 5)
+  }
+
+  @Test
+  def testCommitSpecifiedOffsets() {
+    sendRecords(5, tp)
+    sendRecords(7, tp2)
+
+    this.consumers.head.assign(List(tp, tp2).asJava)
+
+    // Need to poll to join the group
+    this.consumers.head.poll(50)
+    val pos1 = this.consumers.head.position(tp)
+    val pos2 = this.consumers.head.position(tp2)
+    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava)
+    assertEquals(3, this.consumers.head.committed(tp).offset)
+    assertNull(this.consumers.head.committed(tp2))
+
+    // Positions should not change
+    assertEquals(pos1, this.consumers.head.position(tp))
+    assertEquals(pos2, this.consumers.head.position(tp2))
+    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(5L))).asJava)
+    assertEquals(3, this.consumers.head.committed(tp).offset)
+    assertEquals(5, this.consumers.head.committed(tp2).offset)
+
+    // Using async should pick up the committed changes after commit completes
+    val commitCallback = new CountConsumerCommitCallback()
+    this.consumers.head.commitAsync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(7L))).asJava, commitCallback)
+    awaitCommitCallback(this.consumers.head, commitCallback)
+    assertEquals(7, this.consumers.head.committed(tp2).offset)
+  }
+
+  @Test
+  def testAutoCommitOnRebalance() {
+    val topic2 = "topic2"
+    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
+
+    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val numRecords = 10000
+    sendRecords(numRecords)
+
+    val rebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
+        // keep partitions paused in this test so that we can verify the commits based on specific seeks
+        consumer0.pause(partitions)
+      }
+
+      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {}
+    }
+
+    consumer0.subscribe(List(topic).asJava, rebalanceListener)
+
+    val assignment = Set(tp, tp2)
+    TestUtils.waitUntilTrue(() => {
+      consumer0.poll(50)
+      consumer0.assignment() == assignment.asJava
+    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer0.assignment()}")
+
+    consumer0.seek(tp, 300)
+    consumer0.seek(tp2, 500)
+
+    // change subscription to trigger rebalance
+    consumer0.subscribe(List(topic, topic2).asJava, rebalanceListener)
+
+    val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))
+    TestUtils.waitUntilTrue(() => {
+      val records = consumer0.poll(50)
+      consumer0.assignment() == newAssignment.asJava
+    }, s"Expected partitions ${newAssignment.asJava} but actually got ${consumer0.assignment()}")
+
+    // after rebalancing, we should have reset to the committed positions
+    assertEquals(300, consumer0.committed(tp).offset)
+    assertEquals(500, consumer0.committed(tp2).offset)
+  }
+
   def runMultiConsumerSessionTimeoutTest(closeConsumer: Boolean): Unit = {
     // use consumers defined in this class plus one additional consumer
     // Use topic defined in this class + one additional topic
@@ -887,7 +1042,8 @@ class PlaintextConsumerTest extends BaseConsumerTest {
    * Subscribes consumer 'consumer' to a given list of topics 'topicsToSubscribe', creates
    * consumer poller and starts polling.
    * Assumes that the consumer is not subscribed to any topics yet
-   * @param consumer consumer
+   *
+   * @param consumer consumer
    * @param topicsToSubscribe topics that this consumer will subscribe to
    * @return consumer poller for the given consumer
    */
@@ -901,34 +1057,25 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
   /**
    * Creates consumer pollers corresponding to a given consumer group, one per consumer; subscribes consumers to
-   * 'topicsToSubscribe' topics, waits until consumers get topics assignment, and validates the assignment
-   * Currently, assignment validation requires that total number of partitions is greater or equal to
-   * number of consumers (i.e. subscriptions.size >= consumerGroup.size)
-   * Assumes that topics are already created with partitions corresponding to a given set of topic partitions ('subscriptions')
+   * 'topicsToSubscribe' topics and starts polling. The assignment is not validated here; callers do that via validateGroupAssignment.
    *
    * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller.
    *
    * @param consumerGroup consumer group
    * @param topicsToSubscribe topics to which consumers will subscribe to
-   * @param subscriptions set of all topic partitions
    * @return collection of consumer pollers
    */
-  def subscribeConsumersAndWaitForAssignment(consumerGroup: Buffer[KafkaConsumer[Array[Byte], Array[Byte]]],
-                                             topicsToSubscribe: List[String],
-                                             subscriptions: Set[TopicPartition]): Buffer[ConsumerAssignmentPoller] = {
+  def subscribeConsumers(consumerGroup: Buffer[KafkaConsumer[Array[Byte], Array[Byte]]],
+                         topicsToSubscribe: List[String]): Buffer[ConsumerAssignmentPoller] = {
     val consumerPollers = Buffer[ConsumerAssignmentPoller]()
     for (consumer <- consumerGroup)
       consumerPollers += subscribeConsumerAndStartPolling(consumer, topicsToSubscribe)
-    validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
     consumerPollers
   }
 
   /**
    * Creates 'consumerCount' consumers and consumer pollers, one per consumer; subscribes consumers to
-   * 'topicsToSubscribe' topics, waits until consumers get topics assignment, and validates the assignment
-   * Currently, assignment validation requires that total number of partitions is greater or equal to
-   * number of consumers (i.e. subscriptions.size >= consumerCount)
-   * Assumes that topics are already created with partitions corresponding to a given set of topic partitions ('subscriptions')
+   * 'topicsToSubscribe' topics and starts polling. The assignment is not validated here; callers do that via validateGroupAssignment.
    *
    * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller.
    *
@@ -947,7 +1094,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumers ++= consumerGroup
 
     // create consumer pollers, wait for assignment and validate it
-    val consumerPollers = subscribeConsumersAndWaitForAssignment(consumerGroup, topicsToSubscribe, subscriptions)
+    val consumerPollers = subscribeConsumers(consumerGroup, topicsToSubscribe)
 
     (consumerGroup, consumerPollers)
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
index 63636c0..591479e 100644
--- a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
@@ -16,7 +16,6 @@
   */
 package kafka.api
 
-import kafka.server.KafkaConfig
 import org.apache.kafka.common.protocol.SecurityProtocol
 
 class SaslPlainSslEndToEndAuthorizationTest extends EndToEndAuthorizationTest {

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
index c917ca4..a981e68 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
@@ -54,6 +54,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   val ClientHost = "localhost"
   val ConsumerMinSessionTimeout = 10
   val ConsumerMaxSessionTimeout = 1000
+  val DefaultRebalanceTimeout = 500
   val DefaultSessionTimeout = 500
   var timer: MockTimer = null
   var groupCoordinator: GroupCoordinator = null
@@ -113,7 +114,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupWrongCoordinator() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(otherGroupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(otherGroupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NOT_COORDINATOR_FOR_GROUP.code, joinGroupErrorCode)
   }
@@ -122,7 +123,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupSessionTimeoutTooSmall() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, ConsumerMinSessionTimeout - 1, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = ConsumerMinSessionTimeout - 1)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.INVALID_SESSION_TIMEOUT.code, joinGroupErrorCode)
   }
@@ -131,14 +132,14 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupSessionTimeoutTooLarge() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, ConsumerMaxSessionTimeout + 1, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = ConsumerMaxSessionTimeout + 1)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.INVALID_SESSION_TIMEOUT.code, joinGroupErrorCode)
   }
 
   @Test
   def testJoinGroupUnknownConsumerNewGroup() {
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.UNKNOWN_MEMBER_ID.code, joinGroupErrorCode)
   }
@@ -148,7 +149,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val groupId = ""
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.INVALID_GROUP_ID.code, joinGroupResult.errorCode)
   }
 
@@ -156,8 +157,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidJoinGroup() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType,
-      protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
   }
@@ -167,12 +167,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType,
-      protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, "connect", protocols)
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, "connect", protocols)
     assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL.code, otherJoinGroupResult.errorCode)
   }
 
@@ -182,12 +181,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, List(("range", metadata)))
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, List(("range", metadata)))
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, protocolType,
-      List(("roundrobin", metadata)))
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, protocolType, List(("roundrobin", metadata)))
     assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL.code, otherJoinGroupResult.errorCode)
   }
 
@@ -196,11 +194,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "memberId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, protocolType, protocols)
     assertEquals(Errors.UNKNOWN_MEMBER_ID.code, otherJoinGroupResult.errorCode)
   }
 
@@ -223,7 +221,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "memberId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -242,7 +240,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testHeartbeatRebalanceInProgress() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -256,7 +254,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testHeartbeatIllegalGeneration() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -275,7 +273,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidHeartbeat() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -295,7 +293,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSessionTimeout() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -322,7 +320,8 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val sessionTimeout = 1000
 
-    val joinGroupResult = joinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols,
+      rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -352,7 +351,8 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val tp = new TopicPartition("topic", 0)
     val offset = OffsetAndMetadata(0)
 
-    val joinGroupResult = joinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols,
+      rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -376,10 +376,82 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   }
 
   @Test
+  def testSessionTimeoutDuringRebalance() {
+    // create a group with a single member
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols,
+      rebalanceTimeout = 2000, sessionTimeout = 1000)
+    val firstMemberId = firstJoinResult.memberId
+    val firstGenerationId = firstJoinResult.generationId
+    assertEquals(firstMemberId, firstJoinResult.leaderId)
+    assertEquals(Errors.NONE.code, firstJoinResult.errorCode)
+
+    EasyMock.reset(replicaManager)
+    val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]()))
+    assertEquals(Errors.NONE.code, firstSyncResult._2)
+
+    // now have a new member join to trigger a rebalance
+    EasyMock.reset(replicaManager)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+
+    timer.advanceClock(500)
+
+    EasyMock.reset(replicaManager)
+    var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    // letting the session expire should make the member fall out of the group
+    timer.advanceClock(1100)
+
+    EasyMock.reset(replicaManager)
+    heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.UNKNOWN_MEMBER_ID.code, heartbeatResult)
+
+    // and the rebalance should complete with only the new member
+    val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
+    assertEquals(Errors.NONE.code, otherJoinResult.errorCode)
+  }
+
+  @Test
+  def testRebalanceCompletesBeforeMemberJoins() {
+    // create a group with a single member
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols,
+      rebalanceTimeout = 1200, sessionTimeout = 1000)
+    val firstMemberId = firstJoinResult.memberId
+    val firstGenerationId = firstJoinResult.generationId
+    assertEquals(firstMemberId, firstJoinResult.leaderId)
+    assertEquals(Errors.NONE.code, firstJoinResult.errorCode)
+
+    EasyMock.reset(replicaManager)
+    val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]()))
+    assertEquals(Errors.NONE.code, firstSyncResult._2)
+
+    // now have a new member join to trigger a rebalance
+    EasyMock.reset(replicaManager)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+
+    // send a couple of heartbeats to keep the member alive while the rebalance finishes
+    timer.advanceClock(500)
+    EasyMock.reset(replicaManager)
+    var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    timer.advanceClock(500)
+    EasyMock.reset(replicaManager)
+    heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    // now time out the rebalance, which should kick the unjoined member out of the group
+    // and let the rebalance finish with only the new member
+    timer.advanceClock(500)
+    val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
+    assertEquals(Errors.NONE.code, otherJoinResult.errorCode)
+  }
+
+  @Test
   def testSyncGroupEmptyAssignment() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -416,7 +488,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSyncGroupFromUnknownMember() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -436,7 +508,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSyncGroupFromIllegalGeneration() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -453,8 +525,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -465,11 +536,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -484,7 +554,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     // this shouldn't cause a rebalance since protocol information hasn't changed
     EasyMock.reset(replicaManager)
-    val followerJoinResult = joinGroup(groupId, otherJoinResult.memberId, DefaultSessionTimeout, protocolType, protocols)
+    val followerJoinResult = joinGroup(groupId, otherJoinResult.memberId, protocolType, protocols)
 
     assertEquals(Errors.NONE.code, followerJoinResult.errorCode)
     assertEquals(nextGenerationId, followerJoinResult.generationId)
@@ -492,8 +562,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   @Test
   def testJoinGroupFromUnchangedLeaderShouldRebalance() {
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -507,7 +576,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // leader to push new assignments when local metadata changes
 
     EasyMock.reset(replicaManager)
-    val secondJoinResult = joinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val secondJoinResult = joinGroup(groupId, firstMemberId, protocolType, protocols)
 
     assertEquals(Errors.NONE.code, secondJoinResult.errorCode)
     assertNotEquals(firstGenerationId, secondJoinResult.generationId)
@@ -519,8 +588,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -531,11 +599,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -565,8 +632,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -577,11 +643,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -616,8 +681,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = joinGroupResult.memberId
     val firstGenerationId = joinGroupResult.generationId
     assertEquals(firstMemberId, joinGroupResult.leaderId)
@@ -629,11 +693,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -690,7 +753,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val tp = new TopicPartition("topic", 0)
     val offset = OffsetAndMetadata(0)
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -704,8 +767,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testHeartbeatDuringRebalanceCausesRebalanceInProgress() {
     // First start up a group (with a slightly larger timeout to give us time to heartbeat when the rebalance starts)
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val initialGenerationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -713,7 +775,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     // Then join with a new consumer to trigger a rebalance
     EasyMock.reset(replicaManager)
-    sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout, protocolType, protocols)
+    sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     // We should be in the middle of a rebalance, so the heartbeat should return rebalance in progress
     EasyMock.reset(replicaManager)
@@ -723,7 +785,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   @Test
   def testGenerationIdIncrementsOnRebalance() {
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val initialGenerationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
     val memberId = joinGroupResult.memberId
@@ -736,7 +798,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val otherJoinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val nextGenerationId = otherJoinGroupResult.generationId
     val otherJoinGroupErrorCode = otherJoinGroupResult.errorCode
     assertEquals(2, nextGenerationId)
@@ -763,7 +825,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "consumerId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
@@ -776,7 +838,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidLeaveGroup() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -789,7 +851,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testListGroupsIncludesStableGroups() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -808,7 +870,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testListGroupsIncludesRebalancingGroups() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     val (error, groups) = groupCoordinator.handleListGroups()
@@ -835,14 +897,15 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testDescribeGroupStable() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId,  Map(assignedMemberId -> Array[Byte]()))
+    val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]()))
+
     val syncGroupErrorCode = syncGroupResult._2
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
@@ -857,7 +920,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testDescribeGroupRebalancing() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
@@ -903,14 +966,15 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   private def sendJoinGroup(groupId: String,
                             memberId: String,
-                            sessionTimeout: Int,
                             protocolType: String,
-                            protocols: List[(String, Array[Byte])]): Future[JoinGroupResult] = {
+                            protocols: List[(String, Array[Byte])],
+                            rebalanceTimeout: Int = DefaultRebalanceTimeout,
+                            sessionTimeout: Int = DefaultSessionTimeout): Future[JoinGroupResult] = {
     val (responseFuture, responseCallback) = setupJoinGroupCallback
 
     EasyMock.replay(replicaManager)
 
-    groupCoordinator.handleJoinGroup(groupId, memberId, "clientId", "clientHost", sessionTimeout,
+    groupCoordinator.handleJoinGroup(groupId, memberId, "clientId", "clientHost", rebalanceTimeout, sessionTimeout,
       protocolType, protocols, responseCallback)
     responseFuture
   }
@@ -954,29 +1018,32 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   private def joinGroup(groupId: String,
                         memberId: String,
-                        sessionTimeout: Int,
                         protocolType: String,
-                        protocols: List[(String, Array[Byte])]): JoinGroupResult = {
-    val responseFuture = sendJoinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+                        protocols: List[(String, Array[Byte])],
+                        sessionTimeout: Int = DefaultSessionTimeout,
+                        rebalanceTimeout: Int = DefaultRebalanceTimeout): JoinGroupResult = {
+    val responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, rebalanceTimeout, sessionTimeout)
     timer.advanceClock(10)
     // should only have to wait as long as session timeout, but allow some extra time in case of an unexpected delay
-    Await.result(responseFuture, Duration(sessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
 
   private def syncGroupFollower(groupId: String,
                                 generationId: Int,
-                                memberId: String): SyncGroupCallbackParams = {
+                                memberId: String,
+                                sessionTimeout: Int = DefaultSessionTimeout): SyncGroupCallbackParams = {
     val responseFuture = sendSyncGroupFollower(groupId, generationId, memberId)
-    Await.result(responseFuture, Duration(DefaultSessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
   private def syncGroupLeader(groupId: String,
                               generationId: Int,
                               memberId: String,
-                              assignment: Map[String, Array[Byte]]): SyncGroupCallbackParams = {
+                              assignment: Map[String, Array[Byte]],
+                              sessionTimeout: Int = DefaultSessionTimeout): SyncGroupCallbackParams = {
     val responseFuture = sendSyncGroupLeader(groupId, generationId, memberId, assignment)
-    Await.result(responseFuture, Duration(DefaultSessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
   private def heartbeat(groupId: String,

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
index b9569ca..b4f9ba3 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
@@ -17,6 +17,7 @@
 
 package kafka.coordinator
 
+import kafka.api.ApiVersion
 import kafka.cluster.Partition
 import kafka.common.{OffsetAndMetadata, Topic}
 import kafka.log.LogAppendInfo
@@ -46,7 +47,8 @@ class GroupMetadataManagerTest {
   val groupId = "foo"
   val groupPartitionId = 0
   val protocolType = "protocolType"
-  val sessionTimeout = 30000
+  val rebalanceTimeout = 60000
+  val sessionTimeout = 10000
 
 
   @Before
@@ -74,9 +76,8 @@ class GroupMetadataManagerTest {
 
     time = new MockTime
     replicaManager = EasyMock.createNiceMock(classOf[ReplicaManager])
-    groupMetadataManager = new GroupMetadataManager(0, offsetConfig, replicaManager, zkUtils, time)
+    groupMetadataManager = new GroupMetadataManager(0, ApiVersion.latestVersion, offsetConfig, replicaManager, zkUtils, time)
     partition = EasyMock.niceMock(classOf[Partition])
-
   }
 
   @After
@@ -119,7 +120,7 @@ class GroupMetadataManagerTest {
     val group = new GroupMetadata(groupId)
     groupMetadataManager.addGroup(group)
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeout,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeout, sessionTimeout,
       protocolType, List(("protocol", Array[Byte]())))
     member.awaitingJoinCallback = (joinGroupResult: JoinGroupResult) => {}
     group.add(memberId, member)
@@ -337,7 +338,7 @@ class GroupMetadataManagerTest {
     val group = new GroupMetadata(groupId)
     groupMetadataManager.addGroup(group)
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeout,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeout, sessionTimeout,
       protocolType, List(("protocol", Array[Byte]())))
     member.awaitingJoinCallback = (joinGroupResult: JoinGroupResult) => {}
     group.add(memberId, member)

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
index 18dd143..8539340 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
@@ -27,7 +27,14 @@ import org.scalatest.junit.JUnitSuite
  * Test group state transitions and other GroupMetadata functionality
  */
 class GroupMetadataTest extends JUnitSuite {
-  var group: GroupMetadata = null
+  private val protocolType = "consumer"
+  private val groupId = "groupId"
+  private val clientId = "clientId"
+  private val clientHost = "clientHost"
+  private val rebalanceTimeoutMs = 60000
+  private val sessionTimeoutMs = 10000
+
+  private var group: GroupMetadata = null
 
   @Before
   def setUp() {
@@ -169,30 +176,24 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSelectProtocol() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     group.add(memberId, member)
     assertEquals("range", group.selectProtocol)
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
 
     group.add(otherMemberId, otherMember)
     // now could be either range or robin since there is no majority preference
     assertTrue(Set("range", "roundrobin")(group.selectProtocol))
 
     val lastMemberId = "lastMemberId"
-    val lastMember = new MemberMetadata(lastMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
+    val lastMember = new MemberMetadata(lastMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
 
     group.add(lastMemberId, lastMember)
     // now we should prefer 'roundrobin'
@@ -207,19 +208,13 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSelectProtocolChoosesCompatibleProtocol() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
 
     group.add(memberId, member)
     group.add(otherMemberId, otherMember)
@@ -228,18 +223,12 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSupportsProtocols() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     // by default, the group supports everything
     assertTrue(group.supportsProtocols(Set("roundrobin", "range")))
 
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     group.add(memberId, member)
     assertTrue(group.supportsProtocols(Set("roundrobin", "foo")))
@@ -247,8 +236,8 @@ class GroupMetadataTest extends JUnitSuite {
     assertFalse(group.supportsProtocols(Set("foo", "bar")))
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
 
     group.add(otherMemberId, otherMember)
 
@@ -258,14 +247,8 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testInitNextGeneration() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
     val memberId = "memberId"
-
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("roundrobin", Array.empty[Byte])))
 
     group.transitionTo(PreparingRebalance)

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
index 0688424..257dde7 100644
--- a/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
@@ -28,6 +28,7 @@ class MemberMetadataTest extends JUnitSuite {
   val clientHost = "clientHost"
   val memberId = "memberId"
   val protocolType = "consumer"
+  val rebalanceTimeoutMs = 60000
   val sessionTimeoutMs = 10000
 
 
@@ -35,7 +36,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMatchesSupportedProtocols {
     val protocols = List(("range", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertTrue(member.matches(protocols))
     assertFalse(member.matches(List(("range", Array[Byte](0)))))
     assertFalse(member.matches(List(("roundrobin", Array.empty[Byte]))))
@@ -46,7 +48,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testVoteForPreferredProtocol {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertEquals("range", member.vote(Set("range", "roundrobin")))
     assertEquals("roundrobin", member.vote(Set("blah", "roundrobin")))
   }
@@ -55,7 +58,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMetadata {
     val protocols = List(("range", Array[Byte](0)), ("roundrobin", Array[Byte](1)))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertTrue(util.Arrays.equals(Array[Byte](0), member.metadata("range")))
     assertTrue(util.Arrays.equals(Array[Byte](1), member.metadata("roundrobin")))
   }
@@ -64,7 +68,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMetadataRaisesOnUnsupportedProtocol {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     member.metadata("blah")
     fail()
   }
@@ -73,7 +78,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testVoteRaisesOnNoSupportedProtocols {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     member.vote(Set("blah"))
     fail()
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
index d18a060..e4ac4fa 100644
--- a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
+++ b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 class MockTimer extends Timer {
 
   val time = new MockTime
-  private val taskQueue = mutable.PriorityQueue[TimerTaskEntry]()
+  private val taskQueue = mutable.PriorityQueue[TimerTaskEntry]()(Ordering[TimerTaskEntry].reverse)
 
   def add(timerTask: TimerTask) {
     if (timerTask.delayMs <= 0)


Mime
View raw message